Merge branch 'main' into fix-blockstored-kvevent

This commit is contained in:
Nick Hill 2025-12-17 15:29:26 -08:00 committed by GitHub
commit a3a2971c96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
168 changed files with 3340 additions and 1538 deletions

View File

@ -141,7 +141,6 @@ if [[ $commands == *" entrypoints/openai "* ]]; then
--ignore=entrypoints/openai/test_audio.py \ --ignore=entrypoints/openai/test_audio.py \
--ignore=entrypoints/openai/test_shutdown.py \ --ignore=entrypoints/openai/test_shutdown.py \
--ignore=entrypoints/openai/test_completion.py \ --ignore=entrypoints/openai/test_completion.py \
--ignore=entrypoints/openai/test_sleep.py \
--ignore=entrypoints/openai/test_models.py \ --ignore=entrypoints/openai/test_models.py \
--ignore=entrypoints/openai/test_lora_adapters.py \ --ignore=entrypoints/openai/test_lora_adapters.py \
--ignore=entrypoints/openai/test_return_tokens_as_ids.py \ --ignore=entrypoints/openai/test_return_tokens_as_ids.py \

View File

@ -39,7 +39,7 @@ docker run \
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager
VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
cd tests cd tests
pytest -v -s v1/core pytest -v -s v1/core
pytest -v -s v1/engine pytest -v -s v1/engine

View File

@ -128,7 +128,7 @@ steps:
- tests/entrypoints/ - tests/entrypoints/
commands: commands:
- pytest -v -s entrypoints/openai/tool_parsers - pytest -v -s entrypoints/openai/tool_parsers
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling - pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/rpc --ignore=entrypoints/sleep --ignore=entrypoints/instrumentator --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
- label: Entrypoints Integration Test (LLM) # 30min - label: Entrypoints Integration Test (LLM) # 30min
timeout_in_minutes: 40 timeout_in_minutes: 40
@ -148,7 +148,7 @@ steps:
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
- label: Entrypoints Integration Test (API Server) # 100min - label: Entrypoints Integration Test (API Server 1) # 100min
timeout_in_minutes: 130 timeout_in_minutes: 130
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
agent_pool: mi325_1 agent_pool: mi325_1
@ -162,10 +162,28 @@ steps:
- tests/entrypoints/test_chat_utils - tests/entrypoints/test_chat_utils
commands: commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/
- pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/test_chat_utils.py
- label: Entrypoints Integration Test (API Server 2)
timeout_in_minutes: 50
mirror_hardwares: [amdexperimental]
agent_pool: mi325_1
# grade: Blocking
working_dir: "/vllm-workspace/tests"
fast_check: true
torch_nightly: true
source_file_dependencies:
- vllm/
- tests/entrypoints/sleep
- tests/entrypoints/rpc
- tests/tool_use
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/sleep
- pytest -v -s tool_use
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
- label: Entrypoints Integration Test (Pooling) - label: Entrypoints Integration Test (Pooling)
timeout_in_minutes: 50 timeout_in_minutes: 50
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
@ -722,7 +740,7 @@ steps:
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
# we can only upgrade after this is resolved # we can only upgrade after this is resolved
# TODO(jerryzh168): resolve the above comment # TODO(jerryzh168): resolve the above comment
- uv pip install --system torchao==0.13.0 - uv pip install --system torchao==0.14.1
- uv pip install --system conch-triton-kernels - uv pip install --system conch-triton-kernels
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
@ -751,17 +769,6 @@ steps:
# Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442 # Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442
- pytest -s entrypoints/openai/correctness/ - pytest -s entrypoints/openai/correctness/
- label: OpenAI-Compatible Tool Use # 23 min
timeout_in_minutes: 35
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
fast_check: false
source_file_dependencies:
- vllm/
- tests/tool_use
commands:
- pytest -v -s tool_use
##### models test ##### ##### models test #####

View File

@ -114,7 +114,7 @@ steps:
- tests/entrypoints/ - tests/entrypoints/
commands: commands:
- pytest -v -s entrypoints/openai/tool_parsers - pytest -v -s entrypoints/openai/tool_parsers
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling - pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/rpc --ignore=entrypoints/sleep --ignore=entrypoints/instrumentator --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
- label: Entrypoints Integration Test (LLM) # 30min - label: Entrypoints Integration Test (LLM) # 30min
timeout_in_minutes: 40 timeout_in_minutes: 40
@ -132,7 +132,7 @@ steps:
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
- label: Entrypoints Integration Test (API Server) # 100min - label: Entrypoints Integration Test (API Server 1) # 100min
timeout_in_minutes: 130 timeout_in_minutes: 130
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
@ -144,10 +144,26 @@ steps:
- tests/entrypoints/test_chat_utils - tests/entrypoints/test_chat_utils
commands: commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/
- pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/test_chat_utils.py
- label: Entrypoints Integration Test (API Server 2)
timeout_in_minutes: 50
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
fast_check: true
torch_nightly: true
source_file_dependencies:
- vllm/
- tests/entrypoints/sleep
- tests/entrypoints/rpc
- tests/tool_use
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/sleep
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
- pytest -v -s tool_use
- label: Entrypoints Integration Test (Pooling) - label: Entrypoints Integration Test (Pooling)
timeout_in_minutes: 50 timeout_in_minutes: 50
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
@ -642,7 +658,7 @@ steps:
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
# we can only upgrade after this is resolved # we can only upgrade after this is resolved
# TODO(jerryzh168): resolve the above comment # TODO(jerryzh168): resolve the above comment
- uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129 - uv pip install --system torchao==0.14.1 --index-url https://download.pytorch.org/whl/cu129
- uv pip install --system conch-triton-kernels - uv pip install --system conch-triton-kernels
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
@ -654,7 +670,7 @@ steps:
- vllm/model_executor/layers/quantization - vllm/model_executor/layers/quantization
autorun_on_main: true autorun_on_main: true
commands: commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
- label: OpenAI API correctness # 22min - label: OpenAI API correctness # 22min
timeout_in_minutes: 30 timeout_in_minutes: 30
@ -666,16 +682,6 @@ steps:
commands: # LMEval+Transcription WER check commands: # LMEval+Transcription WER check
- pytest -s entrypoints/openai/correctness/ - pytest -s entrypoints/openai/correctness/
- label: OpenAI-Compatible Tool Use # 23 min
timeout_in_minutes: 35
mirror_hardwares: [amdexperimental]
fast_check: false
source_file_dependencies:
- vllm/
- tests/tool_use
commands:
- pytest -v -s tool_use
##### models test ##### ##### models test #####
- label: Basic Models Tests (Initialization) - label: Basic Models Tests (Initialization)
@ -1064,7 +1070,7 @@ steps:
- csrc/ - csrc/
- vllm/model_executor/layers/quantization - vllm/model_executor/layers/quantization
commands: commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1 - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt
##### 1 GPU test ##### ##### 1 GPU test #####
##### multi gpus test ##### ##### multi gpus test #####

View File

@ -32,6 +32,7 @@ steps:
- label: Prime-RL Integration (2 GPUs) - label: Prime-RL Integration (2 GPUs)
timeout_in_minutes: 30 timeout_in_minutes: 30
optional: true optional: true
soft_fail: true
num_gpus: 2 num_gpus: 2
working_dir: "/vllm-workspace" working_dir: "/vllm-workspace"
source_file_dependencies: source_file_dependencies:
@ -39,21 +40,3 @@ steps:
- .buildkite/scripts/run-prime-rl-test.sh - .buildkite/scripts/run-prime-rl-test.sh
commands: commands:
- bash .buildkite/scripts/run-prime-rl-test.sh - bash .buildkite/scripts/run-prime-rl-test.sh
- label: DeepSeek V2-Lite Async EPLB Accuracy
timeout_in_minutes: 60
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
timeout_in_minutes: 60
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040

View File

@ -10,7 +10,7 @@ steps:
- tests/entrypoints/ - tests/entrypoints/
commands: commands:
- pytest -v -s entrypoints/openai/tool_parsers - pytest -v -s entrypoints/openai/tool_parsers
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling - pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/rpc --ignore=entrypoints/sleep --ignore=entrypoints/instrumentator --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
- label: Entrypoints Integration (LLM) - label: Entrypoints Integration (LLM)
timeout_in_minutes: 40 timeout_in_minutes: 40
@ -25,7 +25,7 @@ steps:
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
- label: Entrypoints Integration (API Server) - label: Entrypoints Integration (API Server 1)
timeout_in_minutes: 130 timeout_in_minutes: 130
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
source_file_dependencies: source_file_dependencies:
@ -34,11 +34,26 @@ steps:
- tests/entrypoints/test_chat_utils - tests/entrypoints/test_chat_utils
commands: commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/
- pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/test_chat_utils.py
- label: Entrypoints Integration (API Server 2)
timeout_in_minutes: 130
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/
- tests/tool_use
- tests/entrypoints/sleep
- tests/entrypoints/instrumentator
- tests/entrypoints/rpc
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
- pytest -v -s entrypoints/instrumentator
- pytest -v -s entrypoints/sleep
- pytest -v -s tool_use
- label: Entrypoints Integration (Pooling) - label: Entrypoints Integration (Pooling)
timeout_in_minutes: 50 timeout_in_minutes: 50
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"

View File

@ -9,7 +9,7 @@ steps:
- vllm/model_executor/layers/quantization - vllm/model_executor/layers/quantization
autorun_on_main: true autorun_on_main: true
commands: commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
- label: LM Eval Large Models (4 GPUs)(A100) - label: LM Eval Large Models (4 GPUs)(A100)
gpu: a100 gpu: a100
@ -43,4 +43,4 @@ steps:
- csrc/ - csrc/
- vllm/model_executor/layers/quantization - vllm/model_executor/layers/quantization
commands: commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1 - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt

View File

@ -22,6 +22,8 @@ steps:
# FIXIT: find out which code initializes CUDA before running the test # FIXIT: find out which code initializes CUDA before running the test
# before the fix, we need to use spawn to test it # before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
# A lot of these tests are on the edge of OOMing
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# There is some Tensor Parallelism related processing logic in LoRA that # There is some Tensor Parallelism related processing logic in LoRA that
# requires multi-GPU testing for validation. # requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_chatglm3_tp.py

View File

@ -9,6 +9,7 @@ steps:
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/models/test_initialization.py - tests/models/test_initialization.py
- tests/models/registry.py
commands: commands:
# Run a subset of model initialization tests # Run a subset of model initialization tests
- pytest -v -s models/test_initialization.py::test_can_initialize_small_subset - pytest -v -s models/test_initialization.py::test_can_initialize_small_subset
@ -20,6 +21,7 @@ steps:
source_file_dependencies: source_file_dependencies:
- vllm/model_executor/models/ - vllm/model_executor/models/
- tests/models/test_initialization.py - tests/models/test_initialization.py
- tests/models/registry.py
commands: commands:
# Only when vLLM model source is modified - test initialization of a large # Only when vLLM model source is modified - test initialization of a large
# subset of supported models (the complement of the small subset in the above # subset of supported models (the complement of the small subset in the above

View File

@ -13,7 +13,9 @@ steps:
# tests covered elsewhere. # tests covered elsewhere.
# Use `find` to launch multiple instances of pytest so that # Use `find` to launch multiple instances of pytest so that
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965 # they do not suffer from https://github.com/vllm-project/vllm/issues/28965
- "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\;" # However, find does not normally propagate error codes, so we combine it with xargs
# (using -0 for proper path handling)
- "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
- label: PyTorch Fullgraph Smoke Test - label: PyTorch Fullgraph Smoke Test
timeout_in_minutes: 30 timeout_in_minutes: 30

View File

@ -1,13 +0,0 @@
group: Tool use
depends_on:
- image-build
steps:
- label: OpenAI-Compatible Tool Use
timeout_in_minutes: 35
mirror_hardwares: [amdexperimental]
fast_check: false
source_file_dependencies:
- vllm/
- tests/tool_use
commands:
- pytest -v -s tool_use

14
.github/mergify.yml vendored
View File

@ -235,6 +235,20 @@ pull_request_rules:
add: add:
- rocm - rocm
- name: label-cpu
description: Automatically apply cpu label
conditions:
- label != stale
- files~=^(?!.*kv_offload)(?!.*cpu_offload).*\bcpu.*
actions:
label:
add:
- cpu
assign:
users:
- "fadara01"
- "aditew01"
- name: label-structured-output - name: label-structured-output
description: Automatically apply structured-output label description: Automatically apply structured-output label
conditions: conditions:

View File

@ -56,8 +56,8 @@ endif()
# requirements.txt files and should be kept consistent. The ROCm torch # requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from docker/Dockerfile.rocm # versions are derived from docker/Dockerfile.rocm
# #
set(TORCH_SUPPORTED_VERSION_CUDA "2.9.0") set(TORCH_SUPPORTED_VERSION_CUDA "2.9.1")
set(TORCH_SUPPORTED_VERSION_ROCM "2.9.0") set(TORCH_SUPPORTED_VERSION_ROCM "2.9.1")
# #
# Try to find python package with an executable that exactly matches # Try to find python package with an executable that exactly matches
@ -357,6 +357,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# marlin arches for fp16 output # marlin arches for fp16 output
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# marlin has limited support for turing
cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX) # marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
# marlin arches for fp8 input # marlin arches for fp8 input
@ -364,8 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
# marlin arches for other files
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_ARCHS) if (MARLIN_OTHER_ARCHS)
# #
# For the Marlin kernels we automatically generate sources for various # For the Marlin kernels we automatically generate sources for various
@ -406,25 +410,39 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin generation script has not changed, skipping generation.") message(STATUS "Marlin generation script has not changed, skipping generation.")
endif() endif()
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu") if (MARLIN_ARCHS)
set_gencode_flags_for_srcs( file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" set_gencode_flags_for_srcs(
CUDA_ARCHS "${MARLIN_ARCHS}") SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) CUDA_ARCHS "${MARLIN_ARCHS}")
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
endif() PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu") file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}" SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_BF16_ARCHS}") CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC} set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
endif()
if (MARLIN_SM75_ARCHS)
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_SM75_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_SM75_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_SM75_KERNEL_SRC})
endif() endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
if (MARLIN_FP8_ARCHS) if (MARLIN_FP8_ARCHS)
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu") file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
@ -446,14 +464,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu") "csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_SRCS}" SRCS "${MARLIN_SRCS}"
CUDA_ARCHS "${MARLIN_ARCHS}") CUDA_ARCHS "${MARLIN_OTHER_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu" set_source_files_properties(${MARLIN_SRCS}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif() endif()
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") message(STATUS "Building Marlin kernels for archs: ${MARLIN_OTHER_ARCHS}")
else() else()
message(STATUS "Not building Marlin kernels as no compatible archs found" message(STATUS "Not building Marlin kernels as no compatible archs found"
" in CUDA target architectures") " in CUDA target architectures")
@ -980,12 +998,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# note that we always set `use_atomic_add=False` for moe marlin now, # note that we always set `use_atomic_add=False` for moe marlin now,
# so we don't need 9.0 for bf16 atomicAdd PTX # so we don't need 9.0 for bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# moe marlin has limited support for turing
cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
# moe marlin arches for fp8 input # moe marlin arches for fp8 input
# - sm80 doesn't support fp8 computation # - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS) # moe marlin arches for other files
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_MOE_OTHER_ARCHS)
# #
# For the Marlin MOE kernels we automatically generate sources for various # For the Marlin MOE kernels we automatically generate sources for various
@ -1026,16 +1048,29 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin MOE generation script has not changed, skipping generation.") message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
endif() endif()
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu") if (MARLIN_MOE_ARCHS)
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu") file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_SRC}" SRCS "${MARLIN_MOE_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}") CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_SRC} set_source_files_properties(${MARLIN_MOE_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
endif()
if (MARLIN_MOE_SM75_ARCHS)
file(GLOB MARLIN_MOE_SM75_SRC "csrc/moe/marlin_moe_wna16/sm75_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_SM75_SRC}"
CUDA_ARCHS "${MARLIN_MOE_SM75_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_SM75_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SM75_SRC})
endif() endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
if (MARLIN_MOE_FP8_ARCHS) if (MARLIN_MOE_FP8_ARCHS)
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu") file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
@ -1049,7 +1084,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC}) list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
endif() endif()
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") set(MARLIN_MOE_OTHER_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_OTHER_SRC}"
CUDA_ARCHS "${MARLIN_MOE_OTHER_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_OTHER_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_OTHER_SRC}")
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_OTHER_ARCHS}")
else() else()
message(STATUS "Not building Marlin MOE kernels as no compatible archs found" message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
" in CUDA target architectures") " in CUDA target architectures")

View File

@ -13,8 +13,8 @@ from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
batch_size_range = [1, 16, 32, 64, 128] batch_size_range = [1, 16, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] seq_len_range = [1, 16, 64, 1024, 4096]
intermediate_size = [3072, 9728, 12288] intermediate_size = [3072, 9728, 12288]
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size)) configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))

View File

@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
const scalar_t& y) { const scalar_t& y) {
return act_first ? ACT_FN(x) * y : x * ACT_FN(y); return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
} }
// Activation and gating kernel template.
// Returns true when `ptr` is 16-byte aligned (its low four address bits are
// zero), which is the alignment needed to access it through an int4 (128-bit)
// vector. Only the single pointer passed in is tested; callers that touch
// several buffers must check each pointer separately.
__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
}
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first> bool act_first>
__global__ void act_and_mul_kernel( __global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t* y_ptr = x_ptr + d;
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); scalar_t* out_ptr = out + token_idx * d;
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
// Check alignment for 128-bit vectorized access.
// All three pointers must be 16-byte aligned for safe int4 operations.
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
auto* xp = reinterpret_cast<scalar_t*>(&x);
auto* yp = reinterpret_cast<scalar_t*>(&y);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
VLLM_LDG(&y_ptr[i]));
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
}
} }
} }
@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param( __global__ void act_and_mul_kernel_with_param(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d, scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
const float param) { const float param) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t* y_ptr = x_ptr + d;
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); scalar_t* out_ptr = out + token_idx * d;
out[token_idx * d + idx] = ACT_FN(x, param) * y;
// Check alignment for 128-bit vectorized access
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
auto* xp = reinterpret_cast<scalar_t*>(&x);
auto* yp = reinterpret_cast<scalar_t*>(&y);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = ACT_FN(xp[j], param) * yp[j];
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] = ACT_FN(x, param) * y;
}
} }
} }
template <typename T> template <typename T>
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up, __device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
float alpha, float limit) { float alpha, float limit) {
// clamp gate: min=None, max=limit // Clamp gate to (-inf, limit] and up to [-limit, limit]
const float gate_f = (float)gate; const float g = fminf((float)gate, limit);
const float clamped_gate = gate_f > limit ? limit : gate_f; const float u = fmaxf(fminf((float)up, limit), -limit);
// glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
// clamp up: min=-limit, max=limit return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha)));
const float up_f = (float)up;
const float clamped_up =
up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
// glu = gate * sigmoid(gate * alpha)
const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
const float glu = clamped_gate * sigmoid_val;
// (up + 1) * glu
return (T)((clamped_up + 1.0f) * glu);
} }
// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
template <typename scalar_t, template <typename scalar_t,
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float, scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
const float)> const float)>
__global__ void swigluoai_and_mul_kernel( __global__ void swigluoai_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2 * d] (interleaved)
const int d, const float alpha, const float limit) { const int d, const float alpha, const float limit) {
// For interleaved data: input has 2*d elements per token (gate/up pairs)
// output has d elements per token
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
constexpr int PAIRS = VEC_SIZE / 2; // Number of gate/up pairs per int4 load
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
// TODO: Vectorize loads and stores. const scalar_t* in_ptr = input + token_idx * 2 * d;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { scalar_t* out_ptr = out + token_idx * d;
// gate = x[..., ::2] (even indices)
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
// up = x[..., 1::2] (odd indices)
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit); // Check alignment for 128-bit vectorized access on input.
// For output we use int2 (64-bit) which has 8-byte alignment requirement.
const bool in_aligned = is_16byte_aligned(in_ptr);
const bool out_aligned =
(reinterpret_cast<uintptr_t>(out_ptr) & 7) == 0; // 8-byte for int2
if (in_aligned && out_aligned && d >= PAIRS) {
// Fast path: vectorized loop
// Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs
// Each int2 store writes PAIRS output elements
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
int2* out_vec = reinterpret_cast<int2*>(out_ptr);
const int num_vecs = d / PAIRS;
const int vec_end = num_vecs * PAIRS;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 v = VLLM_LDG(&in_vec[i]);
int2 r;
auto* vp = reinterpret_cast<scalar_t*>(&v);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < PAIRS; j++) {
rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]),
VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit);
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
// gate = x[..., ::2] (even indices)
const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]);
// up = x[..., 1::2] (odd indices)
const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]);
out_ptr[idx] = ACT_FN(gate, up, alpha, limit);
}
} }
} }
@ -217,10 +324,41 @@ __global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d] const scalar_t* __restrict__ input, // [..., d]
const int d) { const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t* in_ptr = input + token_idx * d;
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); scalar_t* out_ptr = out + token_idx * d;
out[token_idx * d + idx] = ACT_FN(x);
// Check alignment for 128-bit vectorized access
const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 v = VLLM_LDG(&in_vec[i]), r;
auto* vp = reinterpret_cast<scalar_t*>(&v);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = ACT_FN(vp[j]);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&in_ptr[idx]);
out_ptr[idx] = ACT_FN(x);
}
} }
} }

View File

@ -107,6 +107,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
prop.location.id = device; prop.location.id = device;
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
#ifndef USE_ROCM
int flag = 0;
CUDA_CHECK(cuDeviceGetAttribute(
&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED,
device));
if (flag) { // support GPUDirect RDMA if possible
prop.allocFlags.gpuDirectRDMACapable = 1;
}
#endif
#ifndef USE_ROCM #ifndef USE_ROCM
// Allocate memory using cuMemCreate // Allocate memory using cuMemCreate
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0)); CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));

View File

@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) {
template <ScoringFunc SF, typename T> template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) { __device__ inline T apply_scoring(T val) {
if constexpr (SF == SCORING_SIGMOID) { if constexpr (SF == SCORING_NONE) {
return val;
} else if constexpr (SF == SCORING_SIGMOID) {
return apply_sigmoid(val); return apply_sigmoid(val);
} else { } else {
static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
"Unsupported ScoringFunc in apply_scoring");
return val; return val;
} }
} }
@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel(
if (case_id < num_tokens) { if (case_id < num_tokens) {
if (if_proceed_next_topk) { if (if_proceed_next_topk) {
float scale = routed_scaling_factor;
if (renormalize) {
scale /= topk_sum;
}
for (int i = lane_id; i < topk; i += WARP_SIZE) { for (int i = lane_id; i < topk; i += WARP_SIZE) {
float base = cuda_cast<float, T>(s_topk_value[i]); float base = cuda_cast<float, T>(s_topk_value[i]);
float value = renormalize ? (base / topk_sum * routed_scaling_factor) float value = base * scale;
: (base * routed_scaling_factor);
topk_indices[i] = s_topk_idx[i]; topk_indices[i] = s_topk_idx[i];
topk_values[i] = value; topk_values[i] = value;
} }

View File

@ -1,2 +1,3 @@
sm*_kernel_*.cu sm*_kernel_*.cu
kernel_selector.h kernel_selector.h
kernel_*.cu

View File

@ -10,6 +10,8 @@ import jinja2
ARCHS = [] ARCHS = []
SUPPORT_FP8 = False SUPPORT_FP8 = False
SUPPORT_SM75 = False
SUPPORT_SM80 = False
for arch in sys.argv[1].split(","): for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "") arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch) arch = int(arch)
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
# with FP16 MMA, so it cannot achieve any acceleration. # with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]: if arch in [89, 120]:
SUPPORT_FP8 = True SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
if arch == 75:
SUPPORT_SM75 = True
FILE_HEAD_COMMENT = """ FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py // auto generated by generate_kernels.py
@ -157,6 +163,7 @@ def remove_old_kernels():
def generate_new_kernels(): def generate_new_kernels():
result_dict = {} result_dict = {}
sm_75_result_dict = {}
for quant_config in QUANT_CONFIGS: for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
@ -174,6 +181,8 @@ def generate_new_kernels():
s_type = quant_config.get("s_type", c_type) s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict: if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = [] result_dict[(a_type, b_type, c_type)] = []
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
sm_75_result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product( for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs all_group_blocks, all_m_blocks, all_thread_configs
@ -197,78 +206,89 @@ def generate_new_kernels():
"thread_k_blocks": thread_k // 16, "thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16, "thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false", "m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages", "stages": 4,
"group_blocks": group_blocks, "group_blocks": group_blocks,
"is_zp_float": "false", "is_zp_float": "false",
} }
result_dict[(a_type, b_type, c_type)].append(config) if SUPPORT_SM80:
result_dict[(a_type, b_type, c_type)].append(config)
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
config_sm75 = config.copy()
config_sm75["stages"] = 2
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
kernel_selector_str = FILE_HEAD_COMMENT kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items(): for result_dict_tmp in [result_dict, sm_75_result_dict]:
all_template_str_list = [] for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
for config in config_list: all_template_str_list = []
s_type = config["s_type"] if not config_list:
template_str = jinja2.Template(TEMPLATE).render( continue
a_type_id=f"vllm::{a_type}.id()", for config in config_list:
b_type_id=f"vllm::{b_type}.id()", s_type = config["s_type"]
c_type_id=f"vllm::{c_type}.id()", template_str = jinja2.Template(TEMPLATE).render(
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()", a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()", b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()", c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()", s_type_id=f"vllm::{s_type}.id()",
**config, **config,
) )
+ "\n" all_template_str_list.append(template_str)
)
file_content = FILE_HEAD + "\n\n" conditions = [
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" f"a_type == vllm::{a_type}",
if a_type == "kFE4M3fn": f"b_type == vllm::{b_type}",
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" f"c_type == vllm::{c_type}",
else: f"s_type == vllm::{s_type}",
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"stages == {config['stages']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
filename = filename.lower() if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: kernel_template2 = (
f.write(file_content) "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
elif result_dict_tmp is sm_75_result_dict:
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += ( kernel_selector_str += (

View File

@ -26,6 +26,7 @@
#include "quantization/gptq_marlin/marlin.cuh" #include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h" #include "quantization/gptq_marlin/dequant.h"
#include "quantization/gptq_marlin/marlin_mma.h"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@ -35,7 +36,7 @@
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
template <typename scalar_t, // compute dtype, half or nv_float16 template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
@ -84,146 +85,6 @@ __global__ void Marlin(
#else #else
// Tensor core mma.sync wrapper: frag_c += a_frag * frag_b.
//
// The instruction is selected entirely at compile time from the compute dtype
// of `type_id` (MarlinScalarType<type_id>::scalar_t) and from `k_size`:
//
//   k_size == 16
//     half / nv_bfloat16 : m16n8k16, fp32 accumulation, four A registers and
//                          two B registers.
//     __nv_fp8_e4m3      : m16n8k16 form with fp32 accumulation; consumes two
//                          A registers and one B register, picked by `idx`.
//     int8_t             : m16n8k16 form with int32 accumulation (satfinite);
//                          same `idx`-based register selection as fp8.
//   k_size == 32
//     __nv_fp8_e4m3      : m16n8k32, fp32 accumulation.
//     int8_t             : m16n8k32, int32 accumulation (satfinite).
//
// Parameters:
//   a_frag, frag_b : operand fragments, reinterpreted as packed 32-bit
//                    registers. Their layout is defined by MarlinScalarType
//                    (declared elsewhere).
//   frag_c         : accumulator fragment, read and updated in place.
//   idx            : only used by the 8-bit k_size == 16 variants, where it
//                    selects registers a[2*idx], a[2*idx+1] and b[idx].
//
// NOTE: any (dtype, k_size) combination not listed above matches no branch
// and the call compiles to nothing, leaving frag_c untouched.
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma(
    const typename MarlinScalarType<type_id>::FragA& a_frag,
    const typename MarlinScalarType<type_id>::FragB& frag_b,
    typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
  if constexpr (k_size == 16) {
    if constexpr (std::is_same<scalar_t, half>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
            "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, int8_t>::value) {
      int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
          : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
            "r"(c[1]), "r"(c[2]), "r"(c[3]));
    }
  } else if constexpr (k_size == 32) {
    // k_size is a template parameter, so dispatch on it at compile time like
    // the k_size == 16 branch above instead of through a runtime `if`.
    if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, int8_t>::value) {
      int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
          : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
            "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
    }
  }
}
// "Transposed" variant of mma(): the roles of the two inputs are swapped when
// they are handed to the mma.sync instruction. Registers from frag_b and
// frag_b2 are interleaved into the instruction's first (A) operand, and
// a_frag supplies the second (B) operand. frag_c is accumulated in place.
//
// Instruction selection is done at compile time:
//   k_size == 16
//     half / nv_bfloat16 : m16n8k16, fp32 accumulation;
//                          A = {b[0], b2[0], b[1], b2[1]}, B = {a[0], a[1]}.
//     __nv_fp8_e4m3      : m16n8k16 form, fp32 accumulation;
//                          A = {b[0], b2[0]}, B = {a[0]}.
//     int8_t             : m16n8k16 form, int32 accumulation (satfinite);
//                          A = {b[0], b2[0]}, B = {a[0]}.
//   otherwise (any other k_size; the instructions used are m16n8k32)
//     __nv_fp8_e4m3      : fp32 accumulation; uses the `kind::f8f6f4`
//                          spelling when compiling for __CUDA_ARCH__ == 1200.
//     int8_t             : int32 accumulation (satfinite).
//
// NOTE: a dtype that matches no branch compiles to nothing and leaves frag_c
// untouched.
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma_trans(
    const typename MarlinScalarType<type_id>::FragA& a_frag,
    const typename MarlinScalarType<type_id>::FragB& frag_b,
    const typename MarlinScalarType<type_id>::FragB& frag_b2,
    typename MarlinScalarType<type_id>::FragC& frag_c) {
  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
  // The accumulator pointer is declared inside each branch because its
  // element type (float vs int32_t) depends on the compute dtype.
  using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
  if constexpr (k_size == 16) {
    if constexpr (std::is_same<scalar_t, half>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
            "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, int8_t>::value) {
      int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
            "r"(c[3]));
    }
  } else {
    if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
      asm volatile(
          "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#else
      asm volatile(
          "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
    } else if constexpr (std::is_same<scalar_t, int8_t>::value) {
      int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
            "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
    }
  }
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared // Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout. // memory, directly in tensor core layout.
template <int count, vllm::ScalarTypeId type_id> template <int count, vllm::ScalarTypeId type_id>
@ -439,9 +300,20 @@ __global__ void Marlin(
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return; if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
#endif #endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
// Turing TensorCore only supports fp16 and int8
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
return;
#endif
int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
#else
constexpr bool use_fp16_accum = false;
#endif
using Adtype = MarlinScalarType<a_type_id>; using Adtype = MarlinScalarType<a_type_id>;
using Cdtype = MarlinScalarType<c_type_id>; using Cdtype = MarlinScalarType<c_type_id>;
@ -618,7 +490,22 @@ __global__ void Marlin(
} }
} }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
if constexpr (moe_block_size >= 16)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16);
if constexpr (moe_block_size >= 8)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8);
if constexpr (moe_block_size >= 4)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4);
if constexpr (moe_block_size >= 2)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2);
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1);
block_num_valid_tokens = local_count;
#else
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count); block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
#endif
if (lane_id == 0) if (lane_id == 0)
reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens; reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
@ -1018,10 +905,6 @@ __global__ void Marlin(
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage); : (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage); int4* sh_s = sh_zp + (stages * zp_sh_stage);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size; int4* sh_a = sh_s + sh_s_size;
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
@ -1545,11 +1428,13 @@ __global__ void Marlin(
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1, mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]); frag_c[i][j][0]);
} else { } else {
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]); mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]); frag_c[i][j][0]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
frag_c[i][j][1]);
} }
} }
} }
@ -1583,10 +1468,12 @@ __global__ void Marlin(
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0], mma<a_type_id, false, 32>(
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); frag_a[k2][i], frag_b[0],
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1], (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]); mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
} }
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
@ -2132,6 +2019,21 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing // While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worsen performance after compilation. // the loop seemed to noticeably worsen performance after compilation.
if (slice_iters == 0) { if (slice_iters == 0) {
// convert fp16 accum to fp32 for reduction
if constexpr (use_fp16_accum) {
#pragma unroll
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
scalar_t* frag_c_part_half =
reinterpret_cast<scalar_t*>(frag_c_part_float);
#pragma unroll
for (int i = 3; i >= 0; i--) {
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
}
}
}
if constexpr (is_a_8bit) { if constexpr (is_a_8bit) {
float frag_a_s[2 * thread_m_blocks]; float frag_a_s[2 * thread_m_blocks];

View File

@ -142,7 +142,7 @@ typedef struct {
int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size, int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) { bool has_act_order, bool is_k_full, int stages) {
bool cache_scales_chunk = has_act_order && !is_k_full; bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n; int tb_n = th_config.thread_n;
@ -160,13 +160,13 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
if (cache_scales_chunk) { if (cache_scales_chunk) {
int load_groups = int load_groups =
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2; return load_groups * tb_n * 2;
} else { } else {
int tb_scales = tb_groups * tb_n * 2; int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages; return tb_scales * stages;
} }
} }
@ -174,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp, bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float, bool is_a_8bit) { int is_zp_float, bool is_a_8bit, int stages) {
int pack_factor = 32 / num_bits; int pack_factor = 32 / num_bits;
// Get B size // Get B size
@ -185,8 +185,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 16; int sh_block_meta_size = tb_m * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2); int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2; int sh_bias_size = tb_n * 2;
int tmp_size = int tmp_size =
@ -195,8 +195,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int sh_s_size = int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full); group_size, has_act_order, is_k_full, stages);
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
int sh_zp_size = 0; int sh_zp_size = 0;
if (has_zp) { if (has_zp) {
if (is_zp_float) if (is_zp_float)
@ -217,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k, int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order, int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float, bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem, bool is_a_8bit) { bool is_a_8bit, int stages, int max_shared_mem) {
// Sanity // Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 || if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) { th_config.num_threads == -1) {
@ -243,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int cache_size = int cache_size =
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m, get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order, prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit); is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
return cache_size <= max_shared_mem; return cache_size <= max_shared_mem;
} }
@ -252,7 +252,7 @@ MarlinFuncPtr get_marlin_kernel(
const vllm::ScalarType c_type, const vllm::ScalarType s_type, const vllm::ScalarType c_type, const vllm::ScalarType s_type,
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
int threads, bool is_zp_float) { int threads, bool is_zp_float, int stages) {
int num_bits = b_type.size_bits(); int num_bits = b_type.size_bits();
auto kernel = MarlinDefault; auto kernel = MarlinDefault;
@ -266,8 +266,8 @@ exec_config_t determine_exec_config(
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m, const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks, int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
bool m_block_size_8, int num_bits, int group_size, bool has_act_order, bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms, bool is_k_full, bool has_zp, bool is_zp_float, bool is_a_8bit, int stages,
bool is_a_8bit) { int max_shared_mem, int sms) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1 thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs ? large_batch_thread_configs
@ -284,15 +284,15 @@ exec_config_t determine_exec_config(
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m, if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order, prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem - 512, is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
is_a_8bit)) { max_shared_mem - 512)) {
continue; continue;
} }
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
is_a_8bit); is_a_8bit, stages);
int group_blocks = 0; int group_blocks = 0;
if (!has_act_order) { if (!has_act_order) {
@ -303,7 +303,7 @@ exec_config_t determine_exec_config(
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks, get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_n / 16, th_config.thread_k / 16, th_config.thread_n / 16, th_config.thread_k / 16,
m_block_size_8, has_act_order, has_zp, group_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float); th_config.num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) continue; if (kernel == MarlinDefault) continue;
@ -433,8 +433,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
dev); dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev); dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80, TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
"marlin kernel only support Ampere or newer GPUs."); "marlin kernel only support Turing or newer GPUs.");
int stages = 4;
if (major_capability == 7 && minor_capability == 5) {
stages = 2;
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
"Turing only support FP16 or INT8 activation.");
}
if (a_type == vllm::kFE4M3fn) { if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(major_capability * 10 + minor_capability >= 89, TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
"FP8 only support Ada Lovelace or newer GPUs."); "FP8 only support Ada Lovelace or newer GPUs.");
@ -461,8 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
exec_cfg = determine_exec_config( exec_cfg = determine_exec_config(
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts, a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size, top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms, has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
is_a_8bit); max_shared_mem, sms);
thread_tfg = exec_cfg.tb_cfg; thread_tfg = exec_cfg.tb_cfg;
} }
@ -479,7 +485,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float, has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem, is_a_8bit), is_a_8bit, stages, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks, "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k, ", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n, ", thread_n = ", thread_tfg.thread_n,
@ -493,12 +499,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int sh_cache_size = int sh_cache_size =
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order, prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit); is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
auto kernel = get_marlin_kernel( auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks, a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float); num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) { if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,

View File

@ -1,2 +1,3 @@
sm*_kernel_*.cu sm*_kernel_*.cu
kernel_selector.h kernel_selector.h
kernel_*.cu

View File

@ -67,7 +67,7 @@ where `scale_factor * multiplier` can be computed at weight loading.
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750
// Lookup-table based 3-input logical operation; explicitly used for // Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in // dequantization as the compiler does not seem to automatically recognize it in
// all cases. // all cases.

View File

@ -10,6 +10,8 @@ import jinja2
ARCHS = [] ARCHS = []
SUPPORT_FP8 = False SUPPORT_FP8 = False
SUPPORT_SM75 = False
SUPPORT_SM80 = False
for arch in sys.argv[1].split(","): for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "") arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch) arch = int(arch)
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
# with FP16 MMA, so it cannot achieve any acceleration. # with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]: if arch in [89, 120]:
SUPPORT_FP8 = True SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
if arch == 75:
SUPPORT_SM75 = True
FILE_HEAD_COMMENT = """ FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py // auto generated by generate_kernels.py
@ -166,6 +172,7 @@ def remove_old_kernels():
def generate_new_kernels(): def generate_new_kernels():
result_dict = {} result_dict = {}
sm_75_result_dict = {}
for quant_config in QUANT_CONFIGS: for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
@ -184,6 +191,8 @@ def generate_new_kernels():
s_type = quant_config.get("s_type", c_type) s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict: if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = [] result_dict[(a_type, b_type, c_type)] = []
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
sm_75_result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product( for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs all_group_blocks, all_m_blocks, all_thread_configs
@ -207,78 +216,89 @@ def generate_new_kernels():
"thread_k_blocks": thread_k // 16, "thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16, "thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false", "m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages", "stages": 4,
"group_blocks": group_blocks, "group_blocks": group_blocks,
"is_zp_float": "true" if is_zp_float else "false", "is_zp_float": "true" if is_zp_float else "false",
} }
result_dict[(a_type, b_type, c_type)].append(config) if SUPPORT_SM80:
result_dict[(a_type, b_type, c_type)].append(config)
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
config_sm75 = config.copy()
config_sm75["stages"] = 2
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
kernel_selector_str = FILE_HEAD_COMMENT kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items(): for result_dict_tmp in [result_dict, sm_75_result_dict]:
all_template_str_list = [] for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
for config in config_list: all_template_str_list = []
s_type = config["s_type"] if not config_list:
template_str = jinja2.Template(TEMPLATE).render( continue
a_type_id=f"vllm::{a_type}.id()", for config in config_list:
b_type_id=f"vllm::{b_type}.id()", s_type = config["s_type"]
c_type_id=f"vllm::{c_type}.id()", template_str = jinja2.Template(TEMPLATE).render(
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()", a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()", b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()", c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()", s_type_id=f"vllm::{s_type}.id()",
**config, **config,
) )
+ "\n" all_template_str_list.append(template_str)
)
file_content = FILE_HEAD + "\n\n" conditions = [
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" f"a_type == vllm::{a_type}",
if a_type == "kFE4M3fn": f"b_type == vllm::{b_type}",
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" f"c_type == vllm::{c_type}",
else: f"s_type == vllm::{s_type}",
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"stages == {config['stages']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
filename = filename.lower() if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: kernel_template2 = (
f.write(file_content) "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
elif result_dict_tmp is sm_75_result_dict:
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += ( kernel_selector_str += (

View File

@ -37,7 +37,7 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr, int const* __restrict__ perm_int_ptr,
@ -148,7 +148,7 @@ typedef struct {
int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size, int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) { bool has_act_order, bool is_k_full, int stages) {
bool cache_scales_chunk = has_act_order && !is_k_full; bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n; int tb_n = th_config.thread_n;
@ -166,28 +166,29 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
if (cache_scales_chunk) { if (cache_scales_chunk) {
int load_groups = int load_groups =
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2; return load_groups * tb_n * 2;
} else { } else {
int tb_scales = tb_groups * tb_n * 2; int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages; return tb_scales * stages;
} }
} }
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits, int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full, int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float) { int has_zp, bool is_zp_float, bool is_a_8bit,
int stages) {
int pack_factor = 32 / num_bits; int pack_factor = 32 / num_bits;
// Get B size // Get B size
int tb_k = th_config.thread_k; int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n; int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * 16; int tb_m = thread_m_blocks * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2; int sh_bias_size = tb_n * 2;
int tmp_size = int tmp_size =
@ -196,8 +197,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int sh_s_size = int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full); group_size, has_act_order, is_k_full, stages);
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
int sh_zp_size = 0; int sh_zp_size = 0;
if (has_zp) { if (has_zp) {
if (is_zp_float) if (is_zp_float)
@ -217,7 +218,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits, int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full, int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float, int max_shared_mem) { int has_zp, bool is_zp_float, bool is_a_8bit, int stages,
int max_shared_mem) {
// Sanity // Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 || if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) { th_config.num_threads == -1) {
@ -242,7 +244,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// Check that pipeline fits into cache // Check that pipeline fits into cache
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float); has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
return cache_size <= max_shared_mem; return cache_size <= max_shared_mem;
} }
@ -251,7 +253,7 @@ MarlinFuncPtr get_marlin_kernel(
const vllm::ScalarType c_type, const vllm::ScalarType s_type, const vllm::ScalarType c_type, const vllm::ScalarType s_type,
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
int threads, bool is_zp_float) { int threads, bool is_zp_float, int stages) {
int num_bits = b_type.size_bits(); int num_bits = b_type.size_bits();
auto kernel = MarlinDefault; auto kernel = MarlinDefault;
@ -265,7 +267,8 @@ exec_config_t determine_exec_config(
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m, const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8, int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
int num_bits, int group_size, bool has_act_order, bool is_k_full, int num_bits, int group_size, bool has_act_order, bool is_k_full,
bool has_zp, bool is_zp_float, int max_shared_mem, int sms) { bool has_zp, bool is_zp_float, int is_a_8bit, int stages,
int max_shared_mem, int sms) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1 thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs ? large_batch_thread_configs
@ -280,13 +283,15 @@ exec_config_t determine_exec_config(
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, num_bits, group_size, has_act_order, is_k_full, has_zp,
is_zp_float, max_shared_mem - 512)) { is_zp_float, is_a_8bit, stages,
max_shared_mem - 512)) {
continue; continue;
} }
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(th_config, thread_m_blocks, prob_m,
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, prob_n, prob_k, num_bits, group_size,
group_size, has_act_order, is_k_full, has_zp, is_zp_float); has_act_order, is_k_full, has_zp,
is_zp_float, is_a_8bit, stages);
int group_blocks = 0; int group_blocks = 0;
if (!has_act_order) { if (!has_act_order) {
@ -297,14 +302,10 @@ exec_config_t determine_exec_config(
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks, get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_n / 16, th_config.thread_k / 16, th_config.thread_n / 16, th_config.thread_k / 16,
m_block_size_8, has_act_order, has_zp, group_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float); th_config.num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) continue; if (kernel == MarlinDefault) continue;
// int m_tiles = div_ceil(prob_m, thread_m_blocks * 16);
// int n_tiles = prob_n / th_config.thread_n;
// int k_tiles = prob_k / th_config.thread_k;
return {1, th_config}; return {1, th_config};
} }
@ -321,6 +322,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int group_size, int dev, cudaStream_t stream, int thread_k_init, int group_size, int dev, cudaStream_t stream, int thread_k_init,
int thread_n_init, int sms, bool use_atomic_add, int thread_n_init, int sms, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) { bool use_fp32_reduce, bool is_zp_float) {
bool is_a_8bit = a_type.size_bits() == 8;
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
@ -389,8 +391,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
dev); dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev); dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80, TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
"marlin kernel only support Ampere or newer GPUs."); "marlin kernel only support Turing or newer GPUs.");
int stages = 4;
if (major_capability == 7 && minor_capability == 5) {
stages = 2;
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
"Turing only support FP16 or INT8 activation.");
}
if (a_type == vllm::kFE4M3fn) { if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK( TORCH_CHECK(
major_capability * 10 + minor_capability == 89 || major_capability * 10 + minor_capability == 89 ||
@ -431,7 +439,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
exec_cfg = determine_exec_config( exec_cfg = determine_exec_config(
a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k, a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order, thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem, sms); is_k_full, has_zp, is_zp_float, is_a_8bit, stages, max_shared_mem,
sms);
thread_tfg = exec_cfg.tb_cfg; thread_tfg = exec_cfg.tb_cfg;
if (thread_tfg.thread_n != -1) { if (thread_tfg.thread_n != -1) {
if (prob_n / thread_tfg.thread_n * if (prob_n / thread_tfg.thread_n *
@ -440,7 +449,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split, if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
prob_n, prob_k, num_bits, group_size, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float, has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem_new)) { is_a_8bit, stages, max_shared_mem_new)) {
thread_tfg = {128, 64, 128}; thread_tfg = {128, 64, 128};
exec_cfg = {1, thread_tfg}; exec_cfg = {1, thread_tfg};
} }
@ -466,7 +475,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
TORCH_CHECK( TORCH_CHECK(
is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n, is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n,
prob_k, num_bits, group_size, has_act_order, is_k_full, prob_k, num_bits, group_size, has_act_order, is_k_full,
has_zp, is_zp_float, max_shared_mem_new), has_zp, is_zp_float, is_a_8bit, stages,
max_shared_mem_new),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks, "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k, ", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n, ", thread_n = ", thread_tfg.thread_n,
@ -475,12 +485,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
", prob_m_split = ", prob_m_split, ", group_size = ", group_size, ", prob_m_split = ", prob_m_split, ", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem_new = ", max_shared_mem_new); ", stages = ", stages, ", max_shared_mem_new = ", max_shared_mem_new);
auto kernel = get_marlin_kernel( auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks, a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float); num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) { if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,

View File

@ -1,17 +1,19 @@
#pragma once #pragma once
#include <torch/all.h> #ifndef _marlin_cuh
#define _marlin_cuh
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <iostream> #include <iostream>
#ifndef MARLIN_NAMESPACE_NAME #ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin #define MARLIN_NAMESPACE_NAME marlin
#endif #endif
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
@ -51,9 +53,51 @@ using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else __device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int32_t*>(smem_ptr)[0] =
reinterpret_cast<const int32_t*>(glob_ptr)[0];
}
}
// Fallback selected by the enclosing `__CUDA_ARCH__ < 800` guard: when `pred`
// is true, copies one int64_t (8 bytes) from `glob_ptr` to `smem_ptr` with an
// ordinary load/store; when `pred` is false nothing is copied.
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  if (!pred) return;
  const int64_t* src = reinterpret_cast<const int64_t*>(glob_ptr);
  int64_t* dst = reinterpret_cast<int64_t*>(smem_ptr);
  *dst = *src;
}
// Fallback selected by the enclosing `__CUDA_ARCH__ < 800` guard: when `pred`
// is true, copies one int4 from `glob_ptr` to `smem_ptr` with an ordinary
// load/store; when `pred` is false nothing is copied.
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  if (!pred) return;
  const int4* src = reinterpret_cast<const int4*>(glob_ptr);
  int4* dst = reinterpret_cast<int4*>(smem_ptr);
  *dst = *src;
}
// Fallback selected by the enclosing `__CUDA_ARCH__ < 800` guard: when `pred`
// is true, copies one int4 from `glob_ptr` to `smem_ptr` with an ordinary
// load/store; when `pred` is false nothing is copied. In this fallback the
// body is the same as cp_async4_ca_pred.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  if (!pred) return;
  const int4* src = reinterpret_cast<const int4*>(glob_ptr);
  int4* dst = reinterpret_cast<int4*>(smem_ptr);
  *dst = *src;
}
// Fallback selected by the enclosing `__CUDA_ARCH__ < 800` guard:
// unconditionally copies one int4 from `glob_ptr` to `smem_ptr` with an
// ordinary load/store.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int4* src = reinterpret_cast<const int4*>(glob_ptr);
  int4* dst = reinterpret_cast<int4*>(smem_ptr);
  *dst = *src;
}
// Fallback for the `__CUDA_ARCH__ < 800` branch: intentionally empty. The
// copy helpers in this branch are plain loads/stores, so this function has
// nothing to do; it exists so callers compile unchanged.
__device__ inline void cp_async_fence() {}
// Fallback for the `__CUDA_ARCH__ < 800` branch: intentionally empty, and the
// template argument `n` is unused. The `#else` branch implements the same
// name with `cp.async.wait_group`.
template <int n>
__device__ inline void cp_async_wait() {}
#else
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr, __device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) { bool pred = true) {
@ -126,6 +170,8 @@ __device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
} }
#endif #endif
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
#endif

View File

@ -0,0 +1,269 @@
#include "marlin_dtypes.cuh"
namespace MARLIN_NAMESPACE_NAME {
// Tensor-core multiply-accumulate: frag_c += a_frag * frag_b, issued as inline
// PTX `mma.sync` and dispatched at compile time on the scalar type of
// `type_id` and on `k_size`.
//
// Template parameters:
//   type_id        - selects MarlinScalarType<type_id> (scalar_t, FragA/B/C).
//   use_fp16_accum - accumulate in fp16 (frag_c is then read/written as two
//                    32-bit registers instead of four floats). Only allowed
//                    for scalar_t == half with k_size == 16 (static_assert).
//   k_size         - K extent handled by one call: 16 (default) or 32.
//
// Parameters:
//   a_frag, frag_b - operand fragments, reinterpreted as 32-bit registers.
//   frag_c         - accumulator fragment, updated in place.
//   idx            - only read by the 8-bit (fp8 / int8) k_size == 16 paths,
//                    where it selects registers a[2*idx], a[2*idx+1], b[idx].
//
// On __CUDA_ARCH__ == 750 the fp16 k=16 shapes are issued as two m16n8k8
// instructions and the int8 k=32 shape as four m8n8k16 instructions.
// Combinations not listed below compile to nothing.
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
__device__ inline void mma(
    const typename MarlinScalarType<type_id>::FragA& a_frag,
    const typename MarlinScalarType<type_id>::FragB& frag_b,
    typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
  if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
    static_assert(!use_fp16_accum);
  }

  if constexpr (k_size == 16) {
    if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
      // Two k=8 steps: (a[0], a[1]) x b[0], then (a[2], a[3]) x b[1].
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(a[0]), "r"(a[1]), "r"(b[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
            "f"(c[3]));
      asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(a[2]), "r"(a[3]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
            "f"(c[3]));
  #else
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  #endif
    } else if constexpr (std::is_same<scalar_t, half>::value &&
                         use_fp16_accum) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
      // fp16 accumulation: the accumulator occupies two 32-bit registers.
      uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
          "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
          : "=r"(c[0]), "=r"(c[1])
          : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
      asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
          "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
          : "=r"(c[0]), "=r"(c[1])
          : "r"(a[2]), "r"(a[3]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
  #else
      uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
          "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
          : "=r"(c[0]), "=r"(c[1])
          : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
            "r"(c[0]), "r"(c[1]));
  #endif
    } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
      // 8-bit operands: `idx` picks which register pair of a / register of b
      // feeds this k=16 step.
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
            "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, int8_t>::value) {
      int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
          : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
            "r"(c[1]), "r"(c[2]), "r"(c[3]));
    }
  } else if constexpr (k_size == 32) {
    // Compile-time dispatch (was a runtime `if`), so the k=32 instructions are
    // no longer instantiated for k_size == 16 callers.
    if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, int8_t>::value) {
      int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
      // Four m8n8k16 steps: a[0]/a[1] with b[0], then a[2]/a[3] with b[1];
      // even-indexed a registers accumulate into c[0..1], odd into c[2..3].
      asm volatile(
          "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
          : "=r"(c[0]), "=r"(c[1])
          : "r"(a[0]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
      asm volatile(
          "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
          : "=r"(c[2]), "=r"(c[3])
          : "r"(a[1]), "r"(b[0]), "r"(c[2]), "r"(c[3]));
      asm volatile(
          "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
          : "=r"(c[0]), "=r"(c[1])
          : "r"(a[2]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
      asm volatile(
          "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
          : "=r"(c[2]), "=r"(c[3])
          : "r"(a[3]), "r"(b[1]), "r"(c[2]), "r"(c[3]));
  #else
      asm volatile(
          "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
          : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
            "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
  #endif
    }
  }
}
// Transposed tensor-core multiply-accumulate, issued as inline PTX
// `mma.sync`. Compared with `mma`, the operand roles are swapped: registers
// from frag_b / frag_b2 are passed in the instruction's first-matrix slot and
// registers from a_frag in the second-matrix slot. frag_c is accumulated in
// place.
//
// Template parameters:
//   type_id        - selects MarlinScalarType<type_id> (scalar_t, FragA/B/C).
//   use_fp16_accum - accumulate in fp16 (frag_c is then read/written as two
//                    32-bit registers). Only allowed for scalar_t == half with
//                    k_size == 16 (static_assert).
//   k_size         - 16 (default) selects the k=16 instructions; any other
//                    value selects the k=32 instructions.
//
// For the full-width instructions the first-matrix registers are ordered
// (b[0], b2[0], b[1], b2[1]) and the second-matrix registers (a[0], a[1]).
// The __CUDA_ARCH__ == 750 paths split this into narrower instructions and
// must consume the registers in that same pairing.
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
__device__ inline void mma_trans(
    const typename MarlinScalarType<type_id>::FragA& a_frag,
    const typename MarlinScalarType<type_id>::FragB& frag_b,
    const typename MarlinScalarType<type_id>::FragB& frag_b2,
    typename MarlinScalarType<type_id>::FragC& frag_c) {
  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
  using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
  if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
    static_assert(!use_fp16_accum);
  }

  if constexpr (k_size == 16) {
    if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
      // Two k=8 steps: (b[0], b2[0]) x a[0], then (b[1], b2[1]) x a[1].
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
            "f"(c[3]));
      asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(b[1]), "r"(b2[1]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
            "f"(c[3]));
  #else
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  #endif
    } else if constexpr (std::is_same<scalar_t, half>::value &&
                         use_fp16_accum) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
      // fp16 accumulation: the accumulator occupies two 32-bit registers.
      uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
          "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
          : "=r"(c[0]), "=r"(c[1])
          : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
      asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
          "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
          : "=r"(c[0]), "=r"(c[1])
          : "r"(b[1]), "r"(b2[1]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
  #else
      uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
          "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
          : "=r"(c[0]), "=r"(c[1])
          : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
            "r"(c[0]), "r"(c[1]));
  #endif
    } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
            "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, int8_t>::value) {
      int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
          : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
            "r"(c[3]));
    }
  } else {
    if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
      float* c = reinterpret_cast<float*>(&frag_c);
      asm volatile(
          "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else if constexpr (std::is_same<scalar_t, int8_t>::value) {
      int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
      // Four m8n8k16 steps that together must equal the m16n8k32 instruction
      // in the #else branch, whose first-matrix registers are
      // (b[0], b2[0], b[1], b2[1]) and second-matrix registers (a[0], a[1]):
      //   c[0..1] += b[0]  * a[0]     c[2..3] += b2[0] * a[0]
      //   c[0..1] += b[1]  * a[1]     c[2..3] += b2[1] * a[1]
      // (Previously the 2nd and 3rd steps used b2[1] and b[0], which counted
      // two registers twice and never read b2[0] / b[1].)
      asm volatile(
          "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
          : "=r"(c[0]), "=r"(c[1])
          : "r"(b[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
      asm volatile(
          "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
          : "=r"(c[2]), "=r"(c[3])
          : "r"(b2[0]), "r"(a[0]), "r"(c[2]), "r"(c[3]));
      asm volatile(
          "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
          : "=r"(c[0]), "=r"(c[1])
          : "r"(b[1]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
      asm volatile(
          "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
          : "=r"(c[2]), "=r"(c[3])
          : "r"(b2[1]), "r"(a[1]), "r"(c[2]), "r"(c[3]));
  #else
      asm volatile(
          "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
          : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
          : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
            "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
  #endif
    }
  }
}
} // namespace MARLIN_NAMESPACE_NAME

View File

@ -26,6 +26,7 @@
#include "marlin.cuh" #include "marlin.cuh"
#include "marlin_dtypes.cuh" #include "marlin_dtypes.cuh"
#include "dequant.h" #include "dequant.h"
#include "marlin_mma.h"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@ -35,7 +36,7 @@
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
template <typename scalar_t, // compute dtype, half or nv_float16 template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
@ -75,137 +76,6 @@ __global__ void Marlin(
#else #else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma_trans(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared // Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout. // memory, directly in tensor core layout.
template <int count, vllm::ScalarTypeId type_id> template <int count, vllm::ScalarTypeId type_id>
@ -415,6 +285,17 @@ __global__ void Marlin(
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return; if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
#endif #endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
// Turing TensorCore only supports fp16 and int8
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
return;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
#else
constexpr bool use_fp16_accum = false;
#endif
using Adtype = MarlinScalarType<a_type_id>; using Adtype = MarlinScalarType<a_type_id>;
using Cdtype = MarlinScalarType<c_type_id>; using Cdtype = MarlinScalarType<c_type_id>;
const int4* A = A0; const int4* A = A0;
@ -873,10 +754,6 @@ __global__ void Marlin(
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage); : (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage); int4* sh_s = sh_zp + (stages * zp_sh_stage);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size; int4* sh_a = sh_s + sh_s_size;
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
@ -1395,11 +1272,13 @@ __global__ void Marlin(
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1, mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]); frag_c[i][j][0]);
} else { } else {
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]); mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]); frag_c[i][j][0]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
frag_c[i][j][1]);
} }
} }
} }
@ -1433,10 +1312,12 @@ __global__ void Marlin(
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0], mma<a_type_id, false, 32>(
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); frag_a[k2][i], frag_b[0],
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1], (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]); mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
} }
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
@ -1956,6 +1837,21 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing // While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation. // the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) { if (slice_iters == 0) {
// convert fp16 accum to fp32 for reduction
if constexpr (use_fp16_accum) {
#pragma unroll
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
scalar_t* frag_c_part_half =
reinterpret_cast<scalar_t*>(frag_c_part_float);
#pragma unroll
for (int i = 3; i >= 0; i--) {
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
}
}
}
if constexpr (is_a_8bit) { if constexpr (is_a_8bit) {
float frag_a_s[2 * thread_m_blocks]; float frag_a_s[2 * thread_m_blocks];

View File

@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
int rowEnd = rowEnds[rowIdx]; int rowEnd = rowEnds[rowIdx];
// Local pointers to this block // Local pointers to this block
outIndices += rowIdx * topK; outIndices += static_cast<int64_t>(rowIdx) * topK;
logits += rowIdx * stride0; logits += static_cast<int64_t>(rowIdx) * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>( topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK); nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
// Local pointers to this block // Local pointers to this block
if constexpr (!multipleBlocksPerRow && !mergeBlocks) { if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
outIndices += rowIdx * topK; outIndices += static_cast<int64_t>(rowIdx) * topK;
} else if constexpr (multipleBlocksPerRow) { } else if constexpr (multipleBlocksPerRow) {
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192 const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192 rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize; rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK; outIndices +=
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK; static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
outLogits +=
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
} else if constexpr (mergeBlocks) { } else if constexpr (mergeBlocks) {
rowEnd = numBlocksToMerge * topK; rowEnd = numBlocksToMerge * topK;
indices += rowIdx * numBlocksToMerge * topK; indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
outIndices += rowIdx * topK; outIndices += static_cast<int64_t>(rowIdx) * topK;
} }
logits += rowIdx * stride0; logits += static_cast<int64_t>(rowIdx) * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort, topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
multipleBlocksPerRow, mergeBlocks>( multipleBlocksPerRow, mergeBlocks>(

View File

@ -621,7 +621,7 @@ ENV UV_HTTP_TIMEOUT=500
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \ --mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \ if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
uv pip install --system -r /tmp/kv_connectors.txt; \ uv pip install --system -r /tmp/kv_connectors.txt || true; \
fi fi
ENV VLLM_USAGE_SOURCE production-docker-image ENV VLLM_USAGE_SOURCE production-docker-image

View File

@ -109,7 +109,7 @@ Every plugin has three parts:
- `init_device`: This function is called to set up the device for the worker. - `init_device`: This function is called to set up the device for the worker.
- `initialize_cache`: This function is called to set cache config for the worker. - `initialize_cache`: This function is called to set cache config for the worker.
- `load_model`: This function is called to load the model weights to device. - `load_model`: This function is called to load the model weights to device.
- `get_kv_cache_spaces`: This function is called to generate the kv cache spaces for the model. - `get_kv_cache_spec`: This function is called to generate the kv cache spec for the model.
- `determine_available_memory`: This function is called to profile the peak memory usage of the model to determine how much memory can be used for KV cache without OOMs. - `determine_available_memory`: This function is called to profile the peak memory usage of the model to determine how much memory can be used for KV cache without OOMs.
- `initialize_from_config`: This function is called to allocate device KV cache with the specified kv_cache_config - `initialize_from_config`: This function is called to allocate device KV cache with the specified kv_cache_config
- `execute_model`: This function is called every step to run inference on the model. - `execute_model`: This function is called every step to run inference on the model.

View File

@ -181,3 +181,4 @@ If you have PRs touching the area, please feel free to ping the area owner for r
- Ascend NPU: [@wangxiyuan](https://github.com/wangxiyuan) and [see more details](https://vllm-ascend.readthedocs.io/en/latest/community/contributors.html#maintainers) - Ascend NPU: [@wangxiyuan](https://github.com/wangxiyuan) and [see more details](https://vllm-ascend.readthedocs.io/en/latest/community/contributors.html#maintainers)
- Intel Gaudi HPU [@xuechendi](https://github.com/xuechendi) and [@kzawora-intel](https://github.com/kzawora-intel) - Intel Gaudi HPU [@xuechendi](https://github.com/xuechendi) and [@kzawora-intel](https://github.com/kzawora-intel)
- Semantic Router: [@xunzhuo](https://github.com/xunzhuo), [@rootfs](https://github.com/rootfs) and [see more details](https://vllm-semantic-router.com/community/team)

View File

@ -47,6 +47,8 @@ We currently support the following OpenAI APIs:
- [Completions API](#completions-api) (`/v1/completions`) - [Completions API](#completions-api) (`/v1/completions`)
- Only applicable to [text generation models](../models/generative_models.md). - Only applicable to [text generation models](../models/generative_models.md).
- *Note: `suffix` parameter is not supported.* - *Note: `suffix` parameter is not supported.*
- [Responses API](#responses-api) (`/v1/responses`)
- Only applicable to [text generation models](../models/generative_models.md).
- [Chat Completions API](#chat-api) (`/v1/chat/completions`) - [Chat Completions API](#chat-api) (`/v1/chat/completions`)
- Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template). - Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template).
- *Note: `user` parameter is ignored.* - *Note: `user` parameter is ignored.*
@ -229,6 +231,31 @@ The following extra parameters are supported:
--8<-- "vllm/entrypoints/openai/protocol.py:chat-completion-extra-params" --8<-- "vllm/entrypoints/openai/protocol.py:chat-completion-extra-params"
``` ```
### Responses API
Our Responses API is compatible with [OpenAI's Responses API](https://platform.openai.com/docs/api-reference/responses);
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
Code example: [examples/online_serving/openai_responses_client_with_tools.py](../../examples/online_serving/openai_responses_client_with_tools.py)
#### Extra parameters
The following extra parameters in the request object are supported:
??? code
```python
--8<-- "vllm/entrypoints/openai/protocol.py:responses-extra-params"
```
The following extra parameters in the response object are supported:
??? code
```python
--8<-- "vllm/entrypoints/openai/protocol.py:responses-response-extra-params"
```
### Embeddings API ### Embeddings API
Our Embeddings API is compatible with [OpenAI's Embeddings API](https://platform.openai.com/docs/api-reference/embeddings); Our Embeddings API is compatible with [OpenAI's Embeddings API](https://platform.openai.com/docs/api-reference/embeddings);

View File

@ -6,7 +6,7 @@ requires = [
"packaging>=24.2", "packaging>=24.2",
"setuptools>=77.0.3,<81.0.0", "setuptools>=77.0.3,<81.0.0",
"setuptools-scm>=8.0", "setuptools-scm>=8.0",
"torch == 2.9.0", "torch == 2.9.1",
"wheel", "wheel",
"jinja2", "jinja2",
] ]

View File

@ -4,7 +4,7 @@ ninja
packaging>=24.2 packaging>=24.2
setuptools>=77.0.3,<81.0.0 setuptools>=77.0.3,<81.0.0
setuptools-scm>=8 setuptools-scm>=8
torch==2.9.0 torch==2.9.1
wheel wheel
jinja2>=3.1.6 jinja2>=3.1.6
regex regex

View File

@ -37,7 +37,7 @@ pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL. einops # Required for Qwen2-VL.
compressed-tensors == 0.12.2 # required for compressed-tensors compressed-tensors == 0.13.0 # required for compressed-tensors
depyf==0.20.0 # required for profiling and debugging with compilation config depyf==0.20.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files watchfiles # required for http server to monitor the updates of TLS files
@ -50,5 +50,5 @@ ijson # Required for mistral streaming tool parser
setproctitle # Used to set process names for better debugging and monitoring setproctitle # Used to set process names for better debugging and monitoring
openai-harmony >= 0.0.3 # Required for gpt-oss openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic == 0.71.0 anthropic == 0.71.0
model-hosting-container-standards >= 0.1.9, < 1.0.0 model-hosting-container-standards >= 0.1.10, < 1.0.0
mcp mcp

View File

@ -5,9 +5,9 @@ numba == 0.61.2 # Required for N-gram speculative decoding
# Dependencies for NVIDIA GPUs # Dependencies for NVIDIA GPUs
ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1.
torch==2.9.0 torch==2.9.1
torchaudio==2.9.0 torchaudio==2.9.1
# These must be updated alongside torch # These must be updated alongside torch
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version torchvision==0.24.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile # FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.5.3 flashinfer-python==0.5.3

View File

@ -2,11 +2,11 @@
-r common.txt -r common.txt
--extra-index-url https://download.pytorch.org/whl/rocm6.4 --extra-index-url https://download.pytorch.org/whl/rocm6.4
torch==2.9.0 torch==2.9.1
torchvision==0.24.0 torchvision==0.24.1
torchaudio==2.9.0 torchaudio==2.9.1
triton==3.5.0 triton==3.5.1
cmake>=3.26.1,<4 cmake>=3.26.1,<4
packaging>=24.2 packaging>=24.2
setuptools>=77.0.3,<80.0.0 setuptools>=77.0.3,<80.0.0

View File

@ -24,9 +24,9 @@ soundfile # required for audio tests
jiwer # required for audio tests jiwer # required for audio tests
tblib # for pickling test exceptions tblib # for pickling test exceptions
timm >=1.0.17 # required for internvl and gemma3n-mm test timm >=1.0.17 # required for internvl and gemma3n-mm test
torch==2.9.0 torch==2.9.1
torchaudio==2.9.0 torchaudio==2.9.1
torchvision==0.24.0 torchvision==0.24.1
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.5 # required for voxtral test mistral_common[image,audio] >= 1.8.5 # required for voxtral test

View File

@ -1123,7 +1123,7 @@ tomli==2.2.1
# via schemathesis # via schemathesis
tomli-w==1.2.0 tomli-w==1.2.0
# via schemathesis # via schemathesis
torch==2.9.0+cu129 torch==2.9.1+cu129
# via # via
# -r requirements/test.in # -r requirements/test.in
# accelerate # accelerate
@ -1152,7 +1152,7 @@ torch==2.9.0+cu129
# torchvision # torchvision
# vector-quantize-pytorch # vector-quantize-pytorch
# vocos # vocos
torchaudio==2.9.0+cu129 torchaudio==2.9.1+cu129
# via # via
# -r requirements/test.in # -r requirements/test.in
# encodec # encodec
@ -1165,7 +1165,7 @@ torchmetrics==1.7.4
# pytorch-lightning # pytorch-lightning
# terratorch # terratorch
# torchgeo # torchgeo
torchvision==0.24.0+cu129 torchvision==0.24.1+cu129
# via # via
# -r requirements/test.in # -r requirements/test.in
# lightly # lightly
@ -1206,7 +1206,7 @@ transformers==4.57.3
# transformers-stream-generator # transformers-stream-generator
transformers-stream-generator==0.0.5 transformers-stream-generator==0.0.5
# via -r requirements/test.in # via -r requirements/test.in
triton==3.5.0 triton==3.5.1
# via torch # via torch
tritonclient==2.51.0 tritonclient==2.51.0
# via # via

View File

@ -67,7 +67,6 @@ def _fix_prompt_embed_outputs(
@pytest.mark.parametrize("model_executor", ["uni", "mp"]) @pytest.mark.parametrize("model_executor", ["uni", "mp"])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models( def test_models(
monkeypatch: pytest.MonkeyPatch,
hf_runner, hf_runner,
model: str, model: str,
backend: str, backend: str,
@ -77,48 +76,46 @@ def test_models(
model_executor: str, model_executor: str,
enable_prompt_embeds: bool, enable_prompt_embeds: bool,
) -> None: ) -> None:
with monkeypatch.context() as m: # 5042 tokens for gemma2
m.setenv("VLLM_ATTENTION_BACKEND", backend) # gemma2 has alternating sliding window size of 4096
# we need a prompt with more than 4096 tokens to test the sliding window
prompt = (
"The following numbers of the sequence "
+ ", ".join(str(i) for i in range(1024))
+ " are:"
)
example_prompts = [prompt]
# 5042 tokens for gemma2 with hf_runner(model) as hf_model:
# gemma2 has alternating sliding window size of 4096 hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
# we need a prompt with more than 4096 tokens to test the sliding window if enable_prompt_embeds:
prompt = ( with torch.no_grad():
"The following numbers of the sequence " prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
+ ", ".join(str(i) for i in range(1024))
+ " are:"
)
example_prompts = [prompt]
with hf_runner(model) as hf_model: with VllmRunner(
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) model,
if enable_prompt_embeds: max_model_len=8192,
with torch.no_grad(): enforce_eager=enforce_eager,
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts) enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7,
async_scheduling=async_scheduling,
distributed_executor_backend=model_executor,
attention_config={"backend": backend},
) as vllm_model:
if enable_prompt_embeds:
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts
)
else:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
with VllmRunner( check_outputs_equal(
model, outputs_0_lst=hf_outputs,
max_model_len=8192, outputs_1_lst=vllm_outputs,
enforce_eager=enforce_eager, name_0="hf",
enable_prompt_embeds=enable_prompt_embeds, name_1="vllm",
gpu_memory_utilization=0.7, )
async_scheduling=async_scheduling,
distributed_executor_backend=model_executor,
) as vllm_model:
if enable_prompt_embeds:
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts
)
else:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@ -161,12 +158,6 @@ def test_models_distributed(
): # noqa ): # noqa
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.") pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
if attention_backend:
monkeypatch_context.setenv(
"VLLM_ATTENTION_BACKEND",
attention_backend,
)
for k, v in extra_env.items(): for k, v in extra_env.items():
monkeypatch_context.setenv(k, v) monkeypatch_context.setenv(k, v)
@ -178,6 +169,7 @@ def test_models_distributed(
# if we run HF first, the cuda initialization will be done and it # if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method # will hurt multiprocessing backend with fork method
# (the default method). # (the default method).
attention_config = {"backend": attention_backend} if attention_backend else None
with vllm_runner( with vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
@ -185,6 +177,7 @@ def test_models_distributed(
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
enable_prompt_embeds=enable_prompt_embeds, enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
attention_config=attention_config,
) as vllm_model: ) as vllm_model:
if enable_prompt_embeds: if enable_prompt_embeds:
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:

View File

@ -19,21 +19,18 @@ def server():
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_serve(server): def test_bench_serve(server):
# Test default model detection and input/output len
command = [ command = [
"vllm", "vllm",
"bench", "bench",
"serve", "serve",
"--model",
MODEL_NAME,
"--host", "--host",
server.host, server.host,
"--port", "--port",
str(server.port), str(server.port),
"--dataset-name", "--input-len",
"random",
"--random-input-len",
"32", "32",
"--random-output-len", "--output-len",
"4", "4",
"--num-prompts", "--num-prompts",
"5", "5",

View File

@ -208,7 +208,8 @@ def test_attn_quant(
# To capture subprocess logs, we need to know whether spawn or fork is used. # To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general. # Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
model_kwargs["attention_config"] = {"backend": backend.name}
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
# Testing properties # Testing properties
@ -297,7 +298,8 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
# To capture subprocess logs, we need to know whether spawn or fork is used. # To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general. # Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
model_kwargs["attention_config"] = {"backend": backend.name}
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
# Testing properties # Testing properties
@ -409,7 +411,8 @@ def test_tp2_attn_quant_async_tp(
# To capture subprocess logs, we need to know whether spawn or fork is used. # To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general. # Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
model_kwargs["attention_config"] = {"backend": backend.name}
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
# Testing properties # Testing properties
@ -523,6 +526,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)), list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
) )
@pytest.mark.parametrize("inductor_graph_partition", [True, False]) @pytest.mark.parametrize("inductor_graph_partition", [True, False])
# TODO: remove skip after we fix the fusion thoroughly
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
def test_rms_group_quant( def test_rms_group_quant(
model_name: str, model_name: str,
model_kwargs: dict[str, Any], model_kwargs: dict[str, Any],
@ -562,7 +567,9 @@ def test_rms_group_quant(
splitting_ops=splitting_ops, splitting_ops=splitting_ops,
# Common # Common
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True), pass_config=PassConfig(
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
),
# Inductor caches custom passes by default as well via uuid # Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True}, inductor_compile_config={"force_disable_caches": True},
) )

View File

@ -89,7 +89,6 @@ class TestSetting:
], ],
) )
def test_compile_correctness( def test_compile_correctness(
monkeypatch: pytest.MonkeyPatch,
test_setting: TestSetting, test_setting: TestSetting,
): ):
# this test is run under multiple suites, with different GPUs. # this test is run under multiple suites, with different GPUs.
@ -107,49 +106,48 @@ def test_compile_correctness(
f"{cuda_device_count_stateless()}" f"{cuda_device_count_stateless()}"
) )
with monkeypatch.context() as m: final_args = [
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) *model_args,
final_args = [ "-pp",
*model_args, str(pp_size),
"-pp", "-tp",
str(pp_size), str(tp_size),
"-tp", "-cc.cudagraph_mode=none",
str(tp_size), f"--attention-backend={attn_backend}",
"-cc.cudagraph_mode=none", ]
]
all_args: list[list[str]] = [] all_args: list[list[str]] = []
all_envs: list[dict[str, str] | None] = [] all_envs: list[dict[str, str] | None] = []
for comp_mode in [ for comp_mode in [
CompilationMode.STOCK_TORCH_COMPILE, CompilationMode.STOCK_TORCH_COMPILE,
CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE, CompilationMode.VLLM_COMPILE,
]: ]:
for mode in [CompilationMode.NONE, comp_mode]: for mode in [CompilationMode.NONE, comp_mode]:
all_args.append( all_args.append(
final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"] final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
)
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
compare_all_settings(
model,
all_args,
all_envs,
method=method if method != "generate" else "generate_close",
) )
all_envs.clear()
all_args.clear()
for mode in [ # inductor will change the output, so we only compare if the output
CompilationMode.NONE, # is close, not exactly the same.
CompilationMode.STOCK_TORCH_COMPILE, compare_all_settings(
CompilationMode.DYNAMO_TRACE_ONCE, model,
CompilationMode.VLLM_COMPILE, all_args,
]: all_envs,
all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"]) method=method if method != "generate" else "generate_close",
all_envs.append({}) )
all_envs.append({}) all_envs.clear()
all_args.clear()
compare_all_settings(model, all_args * 3, all_envs, method=method) for mode in [
CompilationMode.NONE,
CompilationMode.STOCK_TORCH_COMPILE,
CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE,
]:
all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
all_envs.append({})
all_envs.append({})
compare_all_settings(model, all_args * 3, all_envs, method=method)

View File

@ -74,7 +74,6 @@ def llm_pair(request):
# Force native sampler to avoid potential nondeterminism in FlashInfer # Force native sampler to avoid potential nondeterminism in FlashInfer
# when per-request generators are not used in V1. # when per-request generators are not used in V1.
"VLLM_USE_FLASHINFER_SAMPLER": "0", "VLLM_USE_FLASHINFER_SAMPLER": "0",
**backend_config.env_vars,
} }
with temporary_environ(env_vars): with temporary_environ(env_vars):
full = LLM( full = LLM(
@ -170,16 +169,10 @@ class TestFullCUDAGraph:
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend(): def test_full_cudagraph_with_invalid_backend():
with ( # Flex_Attention is not supported with full cuda graph
temporary_environ( with pytest.raises(RuntimeError):
{
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
# Flex_Attention is not supported with full cuda graph
}
),
pytest.raises(RuntimeError),
):
LLM( LLM(
model="Qwen/Qwen2-1.5B-Instruct", model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(cudagraph_mode="FULL"), compilation_config=CompilationConfig(cudagraph_mode="FULL"),
attention_config={"backend": "FLEX_ATTENTION"},
) )

View File

@ -197,20 +197,19 @@ def test_custom_compile_config(
], ],
) )
def test_fp8_kv_scale_compile( def test_fp8_kv_scale_compile(
monkeypatch: pytest.MonkeyPatch,
compilation_mode: int, compilation_mode: int,
model: str, model: str,
backend: AttentionBackendEnum | None, backend: AttentionBackendEnum | None,
): ):
if backend:
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
model_kwargs = { model_kwargs = {
"quantization": "fp8", "quantization": "fp8",
"kv_cache_dtype": "fp8_e4m3", "kv_cache_dtype": "fp8_e4m3",
"calculate_kv_scales": True, "calculate_kv_scales": True,
"max_model_len": 512, "max_model_len": 512,
} }
if backend:
model_kwargs["attention_config"] = {"backend": backend.name}
run_model(compilation_mode, model, **model_kwargs) run_model(compilation_mode, model, **model_kwargs)

View File

@ -219,14 +219,12 @@ def _test_cp_gsm8k(
] ]
) )
server_env = {}
if attn_backend: if attn_backend:
server_env["VLLM_ATTENTION_BACKEND"] = attn_backend server_args.append(f"--attention-backend={attn_backend}")
with RemoteOpenAIServer( with RemoteOpenAIServer(
model_id, model_id,
server_args, server_args,
env_dict=server_env,
max_wait_seconds=720, max_wait_seconds=720,
) as remote_server: ) as remote_server:
host = f"http://{remote_server.host}" host = f"http://{remote_server.host}"

View File

@ -20,23 +20,21 @@ from ..utils import compare_two_settings, create_new_process_for_each_test
) )
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_pp_cudagraph( def test_pp_cudagraph(
monkeypatch: pytest.MonkeyPatch,
PP_SIZE: int, PP_SIZE: int,
MODEL_NAME: str, MODEL_NAME: str,
ATTN_BACKEND: LiteralString, ATTN_BACKEND: LiteralString,
): ):
with monkeypatch.context() as m: cudagraph_args = [
cudagraph_args = [ # use half precision for speed and memory savings in CI environment
# use half precision for speed and memory savings in CI environment "--dtype",
"--dtype", "float16",
"float16", "--pipeline-parallel-size",
"--pipeline-parallel-size", str(PP_SIZE),
str(PP_SIZE), "--distributed-executor-backend",
"--distributed-executor-backend", "mp",
"mp", f"--attention-backend={ATTN_BACKEND}",
] ]
m.setenv("VLLM_ATTENTION_BACKEND", ATTN_BACKEND)
eager_args = cudagraph_args + ["--enforce-eager"] eager_args = cudagraph_args + ["--enforce-eager"]
compare_two_settings(MODEL_NAME, eager_args, cudagraph_args) compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)

View File

@ -9,7 +9,7 @@ from typing import Annotated, Literal
import pytest import pytest
from vllm.config import CompilationConfig, config from vllm.config import AttentionConfig, CompilationConfig, config
from vllm.engine.arg_utils import ( from vllm.engine.arg_utils import (
EngineArgs, EngineArgs,
contains_type, contains_type,
@ -298,6 +298,139 @@ def test_compilation_config():
) )
def test_attention_config():
    """Check each CLI route for configuring attention through ``EngineArgs``.

    Covers, in order: the default config, ``--attention-config.<field>`` dot
    notation, the ``--attention-backend`` shorthand, a JSON-string
    ``--attention-config`` value, propagation of the backend into the config
    returned by ``create_engine_config()``, and the ``ValueError`` raised when
    the shorthand and ``--attention-config.backend`` are both given.
    """
    from vllm.attention.backends.registry import AttentionBackendEnum

    cli = EngineArgs.add_cli_args(FlexibleArgumentParser())

    def build_engine_args(argv):
        # Parse argv and convert the resulting namespace into EngineArgs.
        namespace = cli.parse_args(argv)
        assert namespace is not None
        return EngineArgs.from_cli_args(namespace)

    # default value
    assert build_engine_args([]).attention_config == AttentionConfig()

    # set backend via dot notation
    attn = build_engine_args(
        ["--attention-config.backend", "FLASH_ATTN"]
    ).attention_config
    assert attn.backend is not None
    assert attn.backend.name == "FLASH_ATTN"

    # set backend via --attention-backend shorthand
    shorthand = build_engine_args(["--attention-backend", "FLASHINFER"])
    assert shorthand.attention_backend is not None
    assert shorthand.attention_backend == "FLASHINFER"

    # set all fields via dot notation
    dotted_pairs = [
        ("--attention-config.backend", "FLASH_ATTN"),
        ("--attention-config.flash_attn_version", "3"),
        ("--attention-config.use_prefill_decode_attention", "true"),
        ("--attention-config.flash_attn_max_num_splits_for_cuda_graph", "16"),
        ("--attention-config.use_cudnn_prefill", "true"),
        ("--attention-config.use_trtllm_ragged_deepseek_prefill", "true"),
        ("--attention-config.use_trtllm_attention", "true"),
        ("--attention-config.disable_flashinfer_prefill", "true"),
        ("--attention-config.disable_flashinfer_q_quantization", "true"),
    ]
    # Flatten (flag, value) pairs into the flat argv list argparse expects.
    dotted_argv = [token for pair in dotted_pairs for token in pair]
    attn = build_engine_args(dotted_argv).attention_config
    assert attn.backend is not None
    assert attn.backend.name == "FLASH_ATTN"
    assert attn.flash_attn_version == 3
    assert attn.use_prefill_decode_attention is True
    assert attn.flash_attn_max_num_splits_for_cuda_graph == 16
    assert attn.use_cudnn_prefill is True
    assert attn.use_trtllm_ragged_deepseek_prefill is True
    assert attn.use_trtllm_attention is True
    assert attn.disable_flashinfer_prefill is True
    assert attn.disable_flashinfer_q_quantization is True

    # set to string form of a dict with all fields
    json_flag = (
        "--attention-config="
        '{"backend": "FLASHINFER", "flash_attn_version": 2, '
        '"use_prefill_decode_attention": false, '
        '"flash_attn_max_num_splits_for_cuda_graph": 8, '
        '"use_cudnn_prefill": false, '
        '"use_trtllm_ragged_deepseek_prefill": false, '
        '"use_trtllm_attention": false, '
        '"disable_flashinfer_prefill": false, '
        '"disable_flashinfer_q_quantization": false}'
    )
    attn = build_engine_args([json_flag]).attention_config
    assert attn.backend is not None
    assert attn.backend.name == "FLASHINFER"
    assert attn.flash_attn_version == 2
    assert attn.use_prefill_decode_attention is False
    assert attn.flash_attn_max_num_splits_for_cuda_graph == 8
    assert attn.use_cudnn_prefill is False
    assert attn.use_trtllm_ragged_deepseek_prefill is False
    assert attn.use_trtllm_attention is False
    assert attn.disable_flashinfer_prefill is False
    assert attn.disable_flashinfer_q_quantization is False

    model_argv = ["--model", "facebook/opt-125m"]

    # test --attention-backend flows into VllmConfig.attention_config
    vllm_config = build_engine_args(
        model_argv + ["--attention-backend", "FLASH_ATTN"]
    ).create_engine_config()
    assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASH_ATTN

    # test --attention-config.backend flows into VllmConfig.attention_config
    vllm_config = build_engine_args(
        model_argv + ["--attention-config.backend", "FLASHINFER"]
    ).create_engine_config()
    assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASHINFER

    # test --attention-backend and --attention-config.backend are mutually exclusive
    conflicting = build_engine_args(
        model_argv
        + [
            "--attention-backend",
            "FLASH_ATTN",
            "--attention-config.backend",
            "FLASHINFER",
        ]
    )
    with pytest.raises(ValueError, match="mutually exclusive"):
        conflicting.create_engine_config()


def test_prefix_cache_default(): def test_prefix_cache_default():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([]) args = parser.parse_args([])

View File

@ -14,11 +14,10 @@ import requests
from prometheus_client.parser import text_string_to_metric_families from prometheus_client.parser import text_string_to_metric_families
from transformers import AutoTokenizer from transformers import AutoTokenizer
from tests.conftest import LocalAssetServer
from tests.utils import RemoteOpenAIServer
from vllm import version from vllm import version
from ...conftest import LocalAssetServer
from ...utils import RemoteOpenAIServer
MODELS = { MODELS = {
"text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct", "multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct",

View File

@ -254,7 +254,9 @@ async def test_single_chat_session_input_audio(
async def test_chat_streaming_audio( async def test_chat_streaming_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str client: openai.AsyncOpenAI, model_name: str, audio_url: str
): ):
messages = dummy_messages_from_audio_url(audio_url) messages = dummy_messages_from_audio_url(
audio_url, "What's a short title for this audio?"
)
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(

View File

@ -76,15 +76,10 @@ def default_server_args(with_tool_parser: bool):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def gptoss_server( def gptoss_server(default_server_args: list[str]):
monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str] server_args = default_server_args + ["--attention-backend=TRITON_ATTN"]
): with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server:
with monkeypatch_module.context() as m: yield remote_server
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
with RemoteOpenAIServer(
GPT_OSS_MODEL_NAME, default_server_args
) as remote_server:
yield remote_server
@pytest_asyncio.fixture @pytest_asyncio.fixture

View File

@ -244,3 +244,35 @@ async def test_audio_with_timestamp(mary_had_lamb, whisper_client):
) )
assert transcription.segments is not None assert transcription.segments is not None
assert len(transcription.segments) > 0 assert len(transcription.segments) > 0
@pytest.mark.asyncio
async def test_audio_with_max_tokens(whisper_client, mary_had_lamb):
    """Check that ``max_completion_tokens`` bounds the transcription length.

    A budget of 1 must produce exactly one output token; a budget far above
    the model's limit must still produce fewer than 450 tokens.
    """

    async def transcribe_text(token_budget):
        # Issue one transcription request and return the decoded "text" field.
        raw_response = await whisper_client.audio.transcriptions.create(
            model=MODEL_NAME,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            temperature=0.0,
            extra_body={"max_completion_tokens": token_budget},
        )
        return json.loads(raw_response)["text"]

    def count_tokens(tokenizer, text):
        # Re-tokenize the output without special tokens to count what was generated.
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    single_token_text = await transcribe_text(1)

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    assert count_tokens(tok, single_token_text) == 1

    # max_completion_tokens > max_model_len
    oversized_text = await transcribe_text(int(1e6))
    assert count_tokens(tok, oversized_text) < 450  # ~Whisper max output len

View File

@ -227,3 +227,36 @@ async def test_long_audio_request(foscolo, client_and_model):
) )
out = json.loads(translation)["text"].strip().lower() out = json.loads(translation)["text"].strip().lower()
assert out.count("greek sea") == 2 assert out.count("greek sea") == 2
@pytest.mark.asyncio
async def test_audio_with_max_tokens(mary_had_lamb, client_and_model):
    """Check that ``max_completion_tokens`` bounds the translation length.

    A budget of 1 must produce exactly one output token; a budget far above
    the model's limit must still produce fewer than 450 tokens. Both requests
    go to the translations endpoint, which is the API under test here.
    """
    client, model_name = client_and_model
    translation = await client.audio.translations.create(
        model=model_name,
        file=mary_had_lamb,
        response_format="text",
        temperature=0.0,
        extra_body={"max_completion_tokens": 1},
    )
    out = json.loads(translation)
    out_text = out["text"]

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_name)
    # Re-tokenize without special tokens so only generated tokens are counted.
    out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
    assert len(out_tokens) == 1

    # max_completion_tokens > max_model_len
    # Use the translations endpoint here as well; the original second request
    # went to the transcriptions endpoint instead.
    translation = await client.audio.translations.create(
        model=model_name,
        file=mary_had_lamb,
        response_format="text",
        temperature=0.0,
        extra_body={"max_completion_tokens": int(1e6)},
    )
    out = json.loads(translation)
    out_text = out["text"]
    out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
    assert len(out_tokens) < 450  # ~Whisper max output len

View File

View File

@ -37,7 +37,7 @@ def server():
"--max-num-seqs", "--max-num-seqs",
"128", "128",
"--worker-extension-cls", "--worker-extension-cls",
"tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension", "tests.entrypoints.rpc.test_collective_rpc.TestWorkerExtension",
] ]
with RemoteOpenAIServer( with RemoteOpenAIServer(
MODEL_NAME, MODEL_NAME,

View File

View File

@ -4,7 +4,7 @@
import requests import requests
from prometheus_client.parser import text_string_to_metric_families from prometheus_client.parser import text_string_to_metric_families
from ...utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
MODEL_NAME = "meta-llama/Llama-3.2-1B" MODEL_NAME = "meta-llama/Llama-3.2-1B"

View File

@ -7,9 +7,8 @@ This directory contains a replacement for the lm-eval-harness GSM8K evaluation,
### Run tests with pytest (like buildkite) ### Run tests with pytest (like buildkite)
```bash ```bash
pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \ pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
--config-list-file=configs/models-small.txt \ --config-list-file=configs/models-small.txt
--tp-size=1
``` ```
### Run standalone evaluation script ### Run standalone evaluation script
@ -31,5 +30,11 @@ model_name: "Qwen/Qwen2.5-1.5B-Instruct"
accuracy_threshold: 0.54 # Minimum expected accuracy accuracy_threshold: 0.54 # Minimum expected accuracy
num_questions: 1319 # Number of questions (default: full test set) num_questions: 1319 # Number of questions (default: full test set)
num_fewshot: 5 # Few-shot examples from train set num_fewshot: 5 # Few-shot examples from train set
max_model_len: 4096 # Model context length server_args: "--max-model-len 4096 --tensor-parallel-size 2" # Server arguments
env: # Environment variables (optional)
VLLM_USE_FLASHINFER_MOE_FP4: "1"
``` ```
The `server_args` field accepts any arguments that can be passed to `vllm serve`.
The `env` field accepts a dictionary of environment variables to set for the server process.

View File

@ -2,5 +2,4 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
accuracy_threshold: 0.72 accuracy_threshold: 0.72
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
max_model_len: 4096 server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,4 +2,4 @@ model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
accuracy_threshold: 0.74 accuracy_threshold: 0.74
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
max_model_len: 4096 server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,4 +2,4 @@ model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
accuracy_threshold: 0.31 accuracy_threshold: 0.31
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
max_model_len: 4096 server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,4 +2,4 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
accuracy_threshold: 0.45 accuracy_threshold: 0.45
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
max_model_len: 4096 server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,4 +2,4 @@ model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
accuracy_threshold: 0.60 accuracy_threshold: 0.60
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
max_model_len: 4096 server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-0.6B-FP8"
accuracy_threshold: 0.375 accuracy_threshold: 0.375
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
max_model_len: 4096 server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,5 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-FP4"
accuracy_threshold: 0.89 accuracy_threshold: 0.89
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
max_model_len: 4096 server_args: "--enforce-eager --max-model-len 4096"

View File

@ -0,0 +1,12 @@
model_name: "nm-testing/Qwen3-Next-80B-A3B-Instruct-NVFP4"
accuracy_threshold: 0.75
num_questions: 1319
num_fewshot: 5
server_args: >-
--enforce-eager
--max-model-len 4096
--tensor-parallel-size 2
--enable-expert-parallel
--speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}'
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"

View File

@ -3,3 +3,4 @@ Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Qwen1.5-MoE-W4A16-CT.yaml Qwen1.5-MoE-W4A16-CT.yaml
DeepSeek-V2-Lite-Instruct-FP8.yaml DeepSeek-V2-Lite-Instruct-FP8.yaml
Qwen3-30B-A3B-NVFP4.yaml Qwen3-30B-A3B-NVFP4.yaml
Qwen3-Next-80B-A3B-NVFP4-EP2.yaml

View File

@ -11,14 +11,12 @@ def pytest_addoption(parser):
default="configs/models-small.txt", default="configs/models-small.txt",
help="File containing list of config files to test", help="File containing list of config files to test",
) )
parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size")
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
"""Generate test parameters from config files.""" """Generate test parameters from config files."""
if "config_filename" in metafunc.fixturenames: if "config_filename" in metafunc.fixturenames:
config_list_file = metafunc.config.getoption("--config-list-file") config_list_file = metafunc.config.getoption("--config-list-file")
tp_size = metafunc.config.getoption("--tp-size")
# Handle both relative and absolute paths # Handle both relative and absolute paths
config_list_path = Path(config_list_file) config_list_path = Path(config_list_file)
@ -55,9 +53,9 @@ def pytest_generate_tests(metafunc):
# Generate test parameters # Generate test parameters
if config_files: if config_files:
metafunc.parametrize( metafunc.parametrize(
["config_filename", "tp_size"], "config_filename",
[(config_file, int(tp_size)) for config_file in config_files], config_files,
ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files], ids=[config_file.stem for config_file in config_files],
) )
else: else:
print("No config files found, test will be skipped") print("No config files found, test will be skipped")

View File

@ -5,30 +5,31 @@ GSM8K evaluation using vLLM server and isolated GSM8K script.
Replacement for lm-eval-harness with better performance and control. Replacement for lm-eval-harness with better performance and control.
Usage: Usage:
pytest -s -v test_gsm8k_correctness.py \ pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
--config-list-file=configs/models-small.txt \ --config-list-file=configs/models-small.txt
--tp-size=1
""" """
import shlex
import yaml import yaml
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from .gsm8k_eval import evaluate_gsm8k from .gsm8k_eval import evaluate_gsm8k
RTOL = 0.08 # Relative tolerance for accuracy comparison TOL = 0.08 # Absolute tolerance for accuracy comparison
def launch_gsm8k_eval(eval_config, server_url, tp_size): def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
"""Launch GSM8K evaluation using our isolated script.""" """Run GSM8K evaluation using our isolated script."""
# Extract host and port from server URL # Extract host and port from server URL
if "://" in server_url: if "://" in server_url:
server_url = server_url.split("://")[1] server_url = server_url.split("://")[1]
host_port = server_url.split("/")[0] # Remove path if present host_port = server_url.split("/")[0] # Remove path if present
if ":" in host_port: if ":" in host_port:
host, port = host_port.split(":") host, p = host_port.split(":")
port = int(port) port = int(p)
else: else:
host = host_port host = host_port
port = 8000 port = 8000
@ -48,46 +49,57 @@ def launch_gsm8k_eval(eval_config, server_url, tp_size):
return results return results
def test_gsm8k_correctness_param(config_filename, tp_size): def test_gsm8k_correctness(config_filename):
"""Test GSM8K correctness for a given model configuration.""" """Test GSM8K correctness for a given model configuration."""
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8")) eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
# Server arguments # Parse server arguments from config (use shlex to handle quoted strings)
server_args = [ server_args_str = eval_config.get("server_args", "")
"--max-model-len", server_args = shlex.split(server_args_str) if server_args_str else []
str(eval_config.get("max_model_len", 4096)),
"--enforce-eager", # Add standard server arguments
"--trust-remote-code", server_args.extend(
"--tensor-parallel-size", [
str(tp_size), "--trust-remote-code",
] ]
)
env_dict = eval_config.get("env", None) env_dict = eval_config.get("env", None)
print(f"Starting GSM8K evaluation for model: {eval_config['model_name']}")
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
print(f"Number of questions: {eval_config['num_questions']}")
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
print(f"Server args: {' '.join(server_args)}")
# Launch server and run evaluation # Launch server and run evaluation
with RemoteOpenAIServer( with RemoteOpenAIServer(
eval_config["model_name"], server_args, env_dict=env_dict, max_wait_seconds=480 eval_config["model_name"],
server_args,
env_dict=env_dict,
max_wait_seconds=600,
) as remote_server: ) as remote_server:
server_url = remote_server.url_for("v1") server_url = remote_server.url_for("v1")
print(f"Server started at: {server_url}")
results = launch_gsm8k_eval(eval_config, server_url, tp_size) results = run_gsm8k_eval(eval_config, server_url)
# Check accuracy against threshold measured_metric = results["accuracy"]
measured_accuracy = results["accuracy"] expected_metric = eval_config["accuracy_threshold"]
expected_accuracy = eval_config["accuracy_threshold"]
print(f"GSM8K Results for {eval_config['model_name']}:") print(f"GSM8K Results for {eval_config['model_name']}:")
print(f" Accuracy: {measured_accuracy:.3f}") print(f" Measured metric: {measured_metric:.4f}")
print(f" Expected: {expected_accuracy:.3f}") print(f" Expected metric: {expected_metric:.4f}")
print(f" Tolerance: {TOL:.4f}")
print(f" Questions: {results['num_questions']}") print(f" Questions: {results['num_questions']}")
print(f" Invalid rate: {results['invalid_rate']:.3f}") print(f" Invalid rate: {results['invalid_rate']:.3f}")
print(f" Latency: {results['latency']:.1f}s") print(f" Latency: {results['latency']:.1f}s")
print(f" QPS: {results['questions_per_second']:.1f}") print(f" QPS: {results['questions_per_second']:.1f}")
# Verify accuracy is within tolerance # Verify metric is within tolerance
assert measured_accuracy >= expected_accuracy - RTOL, ( assert measured_metric >= expected_metric - TOL, (
f"Accuracy too low: {measured_accuracy:.3f} < " f"GSM8K metric too low: {measured_metric:.4f} < "
f"{expected_accuracy:.3f} - {RTOL:.3f}" f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}"
) )
print(f"✅ GSM8K test passed for {eval_config['model_name']}") print(f"✅ GSM8K test passed for {eval_config['model_name']}")

View File

@ -6,7 +6,9 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
@ -73,18 +75,18 @@ def generate_params():
@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params()) @pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_env( def test_backend_selection(
device: str, device: str,
name: str, name: str,
use_mla: bool, use_mla: bool,
block_size: int, block_size: int,
monkeypatch: pytest.MonkeyPatch,
): ):
"""Test attention backend selection with valid device-backend pairs.""" """Test attention backend selection with valid device-backend pairs."""
with monkeypatch.context() as m: # Create AttentionConfig with the specified backend
m.setenv("VLLM_ATTENTION_BACKEND", name) attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
if device == "cpu": if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size) backend = get_attn_backend(16, torch.float16, None, block_size)
@ -217,27 +219,32 @@ def test_env(
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(device: str): def test_fp32_fallback(device: str):
"""Test attention backend selection with fp32.""" """Test attention backend selection with fp32."""
if device == "cpu": # Use default config (no backend specified)
with patch("vllm.platforms.current_platform", CpuPlatform()): vllm_config = VllmConfig()
backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "CPU_ATTN"
elif device == "cuda": with set_current_vllm_config(vllm_config):
with patch("vllm.platforms.current_platform", CudaPlatform()): if device == "cpu":
backend = get_attn_backend(16, torch.float32, None, 16) with patch("vllm.platforms.current_platform", CpuPlatform()):
assert backend.get_name() == "FLEX_ATTENTION" backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "CPU_ATTN"
elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "FLEX_ATTENTION"
def test_flash_attn(monkeypatch: pytest.MonkeyPatch): def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation.""" """Test FlashAttn validation."""
pytest.skip( pytest.skip(
"Skipping as current backend selector does not " "Skipping as current backend selector does not "
"handle fallbacks when a backend is set via env var." "handle fallbacks when a backend is explicitly set."
) )
with monkeypatch.context() as m: attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN") vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
# Unsupported CUDA arch # Unsupported CUDA arch
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5)) monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
backend = get_attn_backend(16, torch.float16, None, 16) backend = get_attn_backend(16, torch.float16, None, 16)
@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
assert backend.get_name() != "FLASH_ATTN" assert backend.get_name() != "FLASH_ATTN"
def test_invalid_env(monkeypatch: pytest.MonkeyPatch): def test_invalid_backend():
"""Test that invalid attention backend names raise ValueError.""" """Test that invalid attention backend names raise ValueError."""
with ( with (
monkeypatch.context() as m, pytest.raises(ValueError),
patch("vllm.platforms.current_platform", CudaPlatform()),
): ):
m.setenv("VLLM_ATTENTION_BACKEND", "INVALID") # Invalid backend name should raise ValueError when creating enum
AttentionConfig(backend=AttentionBackendEnum["INVALID"])
# Should raise ValueError for invalid backend
with pytest.raises(ValueError) as exc_info:
get_attn_backend(32, torch.float16, None, 16)
assert "Invalid value 'INVALID'" in str(exc_info.value)

View File

@ -455,3 +455,38 @@ def test_flashinfer_trtllm_prefill_with_baseline(
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - output_trtllm))}", f"{torch.max(torch.abs(output - output_trtllm))}",
) )
def test_trtllm_attention_rejects_num_kv_heads_1() -> None:
"""Test that TRTLLM attention correctly rejects num_kv_heads=1.
When num_kv_heads=1 (MQA), the KV cache strides become degenerate
(stride_heads == stride_batch), which causes CUDA's cuTensorMapEncodeTiled
to fail because TMA descriptors cannot handle degenerate 4D tensors with
singleton dimensions.
This test verifies that can_use_trtllm_attention returns False for
num_kv_heads=1 configurations.
"""
from vllm.utils.flashinfer import can_use_trtllm_attention
# num_kv_heads=1 should be rejected
assert not can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1), (
"can_use_trtllm_attention should return False for num_kv_heads=1"
)
assert not can_use_trtllm_attention(num_qo_heads=32, num_kv_heads=1), (
"can_use_trtllm_attention should return False for num_kv_heads=1"
)
# num_kv_heads > 1 should be accepted (if platform supports it)
# Note: This may return False on non-Blackwell platforms, which is fine
result_kv8 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=8)
result_kv1 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1)
# Even if platform doesn't support TRTLLM, num_kv_heads=1 should never
# return True when num_kv_heads > 1 returns True
if result_kv8:
assert not result_kv1, (
"If TRTLLM is supported for num_kv_heads=8, "
"it must be rejected for num_kv_heads=1"
)

View File

@ -4,7 +4,9 @@
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
@ -16,40 +18,56 @@ def clear_cache():
@pytest.mark.skip(reason="Skipped for now. Should be revisited.") @pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch): def test_selector(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: # Set the current platform to ROCm using monkeypatch
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN") monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
# Set the current platform to ROCm using monkeypatch # Test standard ROCm attention
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform()) attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_ATTN)
vllm_config = VllmConfig(attention_config=attention_config)
# Test standard ROCm attention with set_current_vllm_config(vllm_config):
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN" assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
# MLA test for deepseek related # MLA test for deepseek related
# Change the attention backend to triton MLA
attention_config = AttentionConfig(backend=AttentionBackendEnum.TRITON_MLA)
vllm_config = VllmConfig(attention_config=attention_config)
# change the attention backend to triton MLA with set_current_vllm_config(vllm_config):
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True) backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
assert backend.get_name() == "TRITON_MLA" assert backend.get_name() == "TRITON_MLA"
# If attention backend is None # If attention backend is None
# If use_mla is true # If use_mla is true
# The selected backend is triton MLA # The selected backend is triton MLA
m.setenv("VLLM_ATTENTION_BACKEND", "") attention_config = AttentionConfig(backend=None)
vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True) backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
assert backend.get_name() == "TRITON_MLA" assert backend.get_name() == "TRITON_MLA"
# change the attention backend to AITER MLA # Change the attention backend to AITER MLA
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA") attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_AITER_MLA)
vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True) backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
assert backend.get_name() == "ROCM_AITER_MLA" assert backend.get_name() == "ROCM_AITER_MLA"
# If attention backend is None # If attention backend is None
# If use_mla is true # If use_mla is true
# If VLLM_ROCM_USE_AITER is enabled # If VLLM_ROCM_USE_AITER is enabled
# The selected backend is ROCM_AITER_MLA # The selected backend is ROCM_AITER_MLA
m.setenv("VLLM_ATTENTION_BACKEND", "") with monkeypatch.context() as m:
m.setenv("VLLM_ROCM_USE_AITER", "1") m.setenv("VLLM_ROCM_USE_AITER", "1")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
assert backend.get_name() == "ROCM_AITER_MLA" attention_config = AttentionConfig(backend=None)
vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
backend = get_attn_backend(
576, torch.bfloat16, "auto", 1, False, use_mla=True
)
assert backend.get_name() == "ROCM_AITER_MLA"

View File

@ -9,8 +9,8 @@ import pytest
import torch import torch
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
GroupedTopk,
fused_grouped_topk, fused_grouped_topk,
grouped_topk,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -50,15 +50,17 @@ def test_grouped_topk(
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
baseline_topk_weights, baseline_topk_ids = grouped_topk( grouped_topk = GroupedTopk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk, topk=topk,
renormalize=renormalize, renormalize=renormalize,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group, topk_group=topk_group,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
)
baseline_topk_weights, baseline_topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
) )

View File

@ -37,7 +37,7 @@ def set_seed(seed):
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION, not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7", reason="CUDA not available or PyTorch version < 2.7",
) )
def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): def test_flex_attention_vs_default_backend(vllm_runner):
"""Test that FlexAttention produces the same outputs as the default backend. """Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with This test compares the outputs from the FlexAttention backend with
@ -54,35 +54,32 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
] ]
# Run with flex attention # Run with flex attention
with monkeypatch.context() as m: set_seed(seed)
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") with vllm_runner(
model_name,
set_seed(seed) runner="generate",
with vllm_runner( tensor_parallel_size=1,
model_name, num_gpu_blocks_override=128,
runner="generate", enforce_eager=True,
tensor_parallel_size=1, attention_config={"backend": "FLEX_ATTENTION"},
num_gpu_blocks_override=128, ) as llm_flex:
enforce_eager=True, output_flex = llm_flex.generate_greedy_logprobs(
) as llm_flex: prompts, max_tokens, num_logprobs
output_flex = llm_flex.generate_greedy_logprobs( )
prompts, max_tokens, num_logprobs
)
# Run with default backend # Run with default backend
with monkeypatch.context() as m: set_seed(seed)
set_seed(seed) with vllm_runner(
with vllm_runner( model_name,
model_name, runner="generate",
runner="generate", tensor_parallel_size=1,
tensor_parallel_size=1, num_gpu_blocks_override=128,
num_gpu_blocks_override=128, enforce_eager=True,
enforce_eager=True, gpu_memory_utilization=0.85,
gpu_memory_utilization=0.85, ) as llm_default:
) as llm_default: output_default = llm_default.generate_greedy_logprobs(
output_default = llm_default.generate_greedy_logprobs( prompts, max_tokens, num_logprobs
prompts, max_tokens, num_logprobs )
)
check_logprobs_close( check_logprobs_close(
outputs_0_lst=output_flex, outputs_0_lst=output_flex,
@ -96,7 +93,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION, not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7", reason="CUDA not available or PyTorch version < 2.7",
) )
def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): def test_encoder_flex_attention_vs_default_backend(vllm_runner):
"""Test that FlexAttention produces the same outputs as the default backend. """Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with This test compares the outputs from the FlexAttention backend with
@ -110,30 +107,26 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
] ]
# Run with flex attention # Run with flex attention
with monkeypatch.context() as m: with vllm_runner(
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") model_name,
with vllm_runner( runner="pooling",
model_name, dtype=torch.bfloat16,
runner="pooling", tensor_parallel_size=1,
dtype=torch.bfloat16, max_model_len=100,
tensor_parallel_size=1, enforce_eager=True,
max_model_len=100, attention_config={"backend": "FLEX_ATTENTION"},
enforce_eager=True, ) as llm_flex:
) as llm_flex: flex_outputs = llm_flex.embed(prompts)
flex_outputs = llm_flex.embed(prompts)
# Run with default backend # Run with default backend
with ( with vllm_runner(
monkeypatch.context() as m, model_name,
vllm_runner( runner="pooling",
model_name, dtype=torch.bfloat16,
runner="pooling", tensor_parallel_size=1,
dtype=torch.bfloat16, max_model_len=100,
tensor_parallel_size=1, enforce_eager=True,
max_model_len=100, ) as llm_default:
enforce_eager=True,
) as llm_default,
):
default_outputs = llm_default.embed(prompts) default_outputs = llm_default.embed(prompts)
check_embeddings_close( check_embeddings_close(

View File

@ -35,10 +35,12 @@ audio_lora_path = MODEL_NAME
models = [MODEL_NAME] models = [MODEL_NAME]
@pytest.fixture(autouse=True) @pytest.fixture
def set_attention_backend_for_rocm(monkeypatch): def granite_speech_attention_config():
"""Return attention config for Granite Speech tests on ROCm."""
if current_platform.is_rocm(): if current_platform.is_rocm():
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") return {"backend": "TRITON_ATTN"}
return None
def run_test( def run_test(
@ -53,6 +55,7 @@ def run_test(
num_logprobs: int, num_logprobs: int,
tensor_parallel_size: int, tensor_parallel_size: int,
distributed_executor_backend: str | None = None, distributed_executor_backend: str | None = None,
attention_config: dict | None = None,
): ):
"""Inference result should be the same between hf and vllm. """Inference result should be the same between hf and vllm.
@ -80,6 +83,7 @@ def run_test(
enable_lora=True, enable_lora=True,
max_lora_rank=64, max_lora_rank=64,
enforce_eager=True, enforce_eager=True,
attention_config=attention_config,
) as vllm_model: ) as vllm_model:
lora_request = LoRARequest("audio", 1, audio_lora_path) lora_request = LoRARequest("audio", 1, audio_lora_path)
vllm_outputs_per_case = [ vllm_outputs_per_case = [
@ -131,6 +135,7 @@ def test_models(
vllm_runner, vllm_runner,
model: str, model: str,
audio_assets: AudioTestAssets, audio_assets: AudioTestAssets,
granite_speech_attention_config,
dtype: str, dtype: str,
max_model_len: int, max_model_len: int,
max_tokens: int, max_tokens: int,
@ -157,4 +162,5 @@ def test_models(
max_tokens=max_tokens, max_tokens=max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
tensor_parallel_size=1, tensor_parallel_size=1,
attention_config=granite_speech_attention_config,
) )

View File

@ -2,23 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pytest configuration for vLLM pooling tests.""" """Pytest configuration for vLLM pooling tests."""
import os import pytest
import warnings
from vllm.platforms import current_platform from vllm.platforms import current_platform
def pytest_collection_modifyitems(config, items): @pytest.fixture
"""Set FLEX_ATTENTION backend for SigLIP tests on ROCm.""" def siglip_attention_config():
if not current_platform.is_rocm(): """Return attention config for SigLIP tests on ROCm.
return
siglip_tests = [item for item in items if "test_siglip" in item.nodeid] On ROCm, SigLIP tests require FLEX_ATTENTION backend.
"""
if siglip_tests: if current_platform.is_rocm():
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" return {"backend": "FLEX_ATTENTION"}
warnings.warn( return None
"ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests",
UserWarning,
stacklevel=1,
)

View File

@ -38,6 +38,7 @@ def _run_test(
*, *,
dtype: str, dtype: str,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
attention_config: dict[str, Any] | None = None,
) -> None: ) -> None:
if tokenization_kwargs is None: if tokenization_kwargs is None:
tokenization_kwargs = {} tokenization_kwargs = {}
@ -49,6 +50,7 @@ def _run_test(
enforce_eager=True, enforce_eager=True,
max_model_len=64, max_model_len=64,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
attention_config=attention_config,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.embed( vllm_outputs = vllm_model.embed(
input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs
@ -90,6 +92,7 @@ def test_models_text(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
image_assets, image_assets,
siglip_attention_config,
model: str, model: str,
dtype: str, dtype: str,
) -> None: ) -> None:
@ -108,6 +111,7 @@ def test_models_text(
"padding": "max_length", "padding": "max_length",
"max_length": 64, "max_length": 64,
}, # siglip2 was trained with this padding setting. }, # siglip2 was trained with this padding setting.
attention_config=siglip_attention_config,
) )
@ -117,6 +121,7 @@ def test_models_image(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
image_assets, image_assets,
siglip_attention_config,
model: str, model: str,
dtype: str, dtype: str,
) -> None: ) -> None:
@ -133,6 +138,7 @@ def test_models_image(
input_images, input_images,
model, model,
dtype=dtype, dtype=dtype,
attention_config=siglip_attention_config,
) )
@ -141,6 +147,7 @@ def test_models_image(
def test_models_text_image_no_crash( def test_models_text_image_no_crash(
vllm_runner, vllm_runner,
image_assets, image_assets,
siglip_attention_config,
model: str, model: str,
dtype: str, dtype: str,
) -> None: ) -> None:
@ -154,6 +161,7 @@ def test_models_text_image_no_crash(
enforce_eager=True, enforce_eager=True,
max_model_len=64, max_model_len=64,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
attention_config=siglip_attention_config,
) as vllm_model: ) as vllm_model:
with pytest.raises(ValueError, match="not both"): with pytest.raises(ValueError, match="not both"):
vllm_model.embed(texts, images=images) vllm_model.embed(texts, images=images)

View File

@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int):
total_num_patches.item() + num_tiles.item() + 3 total_num_patches.item() + num_tiles.item() + 3
) # image start, image, image end ) # image start, image, image end
profiled_tokens = profiler.get_mm_max_contiguous_tokens( profiled_tokens = profiler.get_mm_max_tokens(
max_model_len, max_model_len,
mm_counts=mm_counts, mm_counts=mm_counts,
) )
assert total_tokens == profiled_tokens["image"] assert total_num_patches == profiled_tokens["image"]
assert total_tokens == sum( assert total_tokens == sum(
placeholder.length placeholder.length
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"] for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]

View File

@ -75,7 +75,6 @@ def test_models(
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", "true") m.setenv("TOKENIZERS_PARALLELISM", "true")
m.setenv("VLLM_ATTENTION_BACKEND", backend)
MAX_MODEL_LEN = 1024 MAX_MODEL_LEN = 1024
NUM_LOG_PROBS = 8 NUM_LOG_PROBS = 8
@ -86,6 +85,7 @@ def test_models(
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
kv_cache_dtype="auto", kv_cache_dtype="auto",
attention_config={"backend": backend},
) as vllm_model: ) as vllm_model:
baseline_outputs = vllm_model.generate_greedy_logprobs( baseline_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS example_prompts, max_tokens, NUM_LOG_PROBS
@ -97,6 +97,7 @@ def test_models(
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
attention_config={"backend": backend},
) as vllm_model: ) as vllm_model:
test_outputs = vllm_model.generate_greedy_logprobs( test_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS example_prompts, max_tokens, NUM_LOG_PROBS

View File

@ -108,11 +108,12 @@ def can_initialize(
patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1),
monkeypatch.context() as m, monkeypatch.context() as m,
): ):
if model_arch == "GptOssForCausalLM": # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
# FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when # L4 supports FA3.
# L4 supports FA3. attention_config = (
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") {"backend": "TRITON_ATTN"} if model_arch == "GptOssForCausalLM" else None
)
if model_arch == "WhisperForConditionalGeneration": if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
@ -143,6 +144,7 @@ def can_initialize(
else "vllm", else "vllm",
hf_overrides=hf_overrides_fn, hf_overrides=hf_overrides_fn,
max_num_seqs=model_info.max_num_seqs, max_num_seqs=model_info.max_num_seqs,
attention_config=attention_config,
) )

View File

@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory
import numpy as np import numpy as np
import pytest import pytest
import torch
from PIL import Image, ImageChops from PIL import Image, ImageChops
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
@ -410,6 +411,97 @@ def test_argsort_mm_positions(case):
assert modality_idxs == expected_modality_idxs assert modality_idxs == expected_modality_idxs
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, 5),
        (torch.tensor([True, True, True, True, True]), 5),
        (torch.tensor([False, False, False, False, False]), 0),
        (torch.tensor([True, False, True, False, True]), 3),
        (torch.tensor([True]), 1),
    ],
)
def test_placeholder_range_get_num_embeds(is_embed, expected):
    """Check ``PlaceholderRange.get_num_embeds`` for several masks.

    When no mask is supplied the placeholder is built with length 5 and
    all five positions are expected to be counted; otherwise the expected
    value is the number of ``True`` entries in the mask.
    """
    # The range length mirrors the mask length; 5 is the no-mask default.
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0,
        length=num_positions,
        is_embed=is_embed,
    )

    assert placeholder.get_num_embeds == expected
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, None),
        (
            torch.tensor([False, True, False, True, True]),
            torch.tensor([0, 1, 1, 2, 3]),
        ),
        (torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
    ],
)
def test_placeholder_range_embeds_cumsum(is_embed, expected):
    """Check ``PlaceholderRange.embeds_cumsum`` with and without a mask.

    With no mask the attribute is expected to be ``None``; with a mask it
    is expected to equal the running count of ``True`` entries.
    """
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0,
        length=num_positions,
        is_embed=is_embed,
    )

    if expected is not None:
        assert torch.equal(placeholder.embeds_cumsum, expected)
        # cached_property should return the same object on repeated access
        assert placeholder.embeds_cumsum is placeholder.embeds_cumsum
    else:
        assert placeholder.embeds_cumsum is None
@pytest.mark.parametrize(
    "is_embed,start_idx,end_idx,expected",
    [
        (None, 2, 4, (2, 4)),
        (torch.tensor([False, True, False, True, True]), 3, 5, (1, 3)),
        (torch.tensor([False, True, False, True, True]), 0, 2, (0, 1)),
        (torch.tensor([True, False, True, False]), 2, 2, (1, 1)),
    ],
)
def test_placeholder_range_get_embeds_indices_in_range(
    is_embed, start_idx, end_idx, expected
):
    """Check ``PlaceholderRange.get_embeds_indices_in_range``.

    Each case passes a ``[start_idx, end_idx)`` window in placeholder
    positions and compares the returned pair with ``expected``. Without a
    mask the pair is expected to equal the window itself.
    """
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0,
        length=num_positions,
        is_embed=is_embed,
    )

    actual = placeholder.get_embeds_indices_in_range(start_idx, end_idx)
    assert actual == expected
@pytest.mark.parametrize(
    "offset,is_embed,expected",
    [
        (0, None, [(0, 4)]),
        (2, torch.tensor([False, True, False, True, True]), [(3, 3), (5, 6)]),
        (0, torch.tensor([True, True, True, True]), [(0, 3)]),
        (0, torch.tensor([False, False, False, False]), []),
    ],
)
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
    """Check ``PlaceholderRange.extract_embeds_range`` for several masks.

    The expected values are lists of ``(first, last)`` pairs that already
    include ``offset``; an all-``False`` mask is expected to yield an
    empty list.
    """
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=offset,
        length=num_positions,
        is_embed=is_embed,
    )

    actual = placeholder.extract_embeds_range()
    assert actual == expected
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800]) @pytest.mark.parametrize("num_frames", [-1, 32, 1800])

View File

@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from vllm.tool_parsers.minimax_m2_tool_parser import (
MinimaxM2ToolParser,
)
pytestmark = pytest.mark.cpu_test
class FakeTokenizer:
    """Stand-in tokenizer carrying only what the tool parser reads.

    It provides a truthy ``model_tokenizer`` marker and a ``vocab``
    mapping (also returned by ``get_vocab``) that holds the two special
    tool-call tokens.
    """

    def __init__(self):
        self.model_tokenizer = True
        # The parser will look up start/end tokens by their literal strings
        special_tokens = ("<minimax:tool_call>", "</minimax:tool_call>")
        self.vocab = {
            token: token_id for token_id, token in enumerate(special_tokens, start=1)
        }

    def get_vocab(self):
        """Return the token-to-id mapping held in ``self.vocab``."""
        return self.vocab
@pytest.fixture
def minimax_m2_tool_parser():
    """Provide a ``MinimaxM2ToolParser`` built on a ``FakeTokenizer``."""
    tokenizer = FakeTokenizer()
    return MinimaxM2ToolParser(tokenizer)
def test_extract_tool_calls_streaming_incremental(minimax_m2_tool_parser):
    """Stream one tool call chunk by chunk and check the recorded call.

    After all chunks are fed, the parser is expected to hold exactly one
    entry in ``prev_tool_call_arr`` naming ``get_weather`` with the
    argument ``city == "Seattle"``.
    """
    parser = minimax_m2_tool_parser
    parser._reset_streaming_state()

    chunks = [
        "<minimax:tool_call>",
        '<invoke name="get_weather">',
        '<parameter name="city">',
        "Seattle</parameter>",
        "</invoke></minimax:tool_call>",
    ]

    # Text accumulated so far; each call sees it as previous_text and
    # the same text plus the new chunk as current_text.
    streamed = ""
    for chunk in chunks:
        parser.extract_tool_calls_streaming(
            previous_text=streamed,
            current_text=streamed + chunk,
            delta_text=chunk,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )
        streamed += chunk

    assert len(parser.prev_tool_call_arr) == 1
    tool_call = parser.prev_tool_call_arr[0]

    assert tool_call["name"] == "get_weather"
    assert tool_call["arguments"]["city"] == "Seattle"
def test_streaming_minimax_m2_multiple_invokes(minimax_m2_tool_parser):
    """Stream two ``<invoke>`` blocks inside one tool-call section.

    Both recorded calls are expected to name ``search_web`` and carry the
    streamed list arguments, and ``streamed_args_for_tool`` is expected to
    hold the JSON dump of each call's arguments.
    """
    parser = minimax_m2_tool_parser
    parser._reset_streaming_state()

    chunks = [
        "<minimax:tool_call>",
        '<invoke name="search_web">',
        '<parameter name="query_tag">',
        '["technology", "events"]</parameter>',
        '<parameter name="query_list">',
        '["OpenAI", "latest", "release"]</parameter>',
        "</invoke>",
        '<invoke name="search_web">',
        '<parameter name="query_tag">',
        '["technology", "events"]</parameter>',
        '<parameter name="query_list">',
        '["Gemini", "latest", "release"]</parameter>',
        "</invoke>",
        "</minimax:tool_call>",
    ]

    # Text accumulated so far; each call sees it as previous_text and
    # the same text plus the new chunk as current_text.
    streamed = ""
    for chunk in chunks:
        parser.extract_tool_calls_streaming(
            previous_text=streamed,
            current_text=streamed + chunk,
            delta_text=chunk,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )
        streamed += chunk

    recorded_calls = parser.prev_tool_call_arr
    assert len(recorded_calls) == 2

    for tool_call, expect_model in zip(recorded_calls, ["OpenAI", "Gemini"]):
        assert tool_call["name"] == "search_web"
        serialized = json.dumps(tool_call["arguments"])
        assert "technology" in serialized and "events" in serialized
        assert expect_model in serialized

    # check streamed_args_for_tool for serving_chat.py
    for index, tool_call in enumerate(recorded_calls):
        expected_call = json.dumps(tool_call.get("arguments", {}))
        assert expected_call == parser.streamed_args_for_tool[index]

View File

@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
) )
# Call the function # Call the function
result = make_local_attention_virtual_batches( result, _ = make_local_attention_virtual_batches(
attn_chunk_size, common_attn_metadata, block_size attn_chunk_size, common_attn_metadata, block_size
) )

View File

@ -94,26 +94,20 @@ def mock_on_gfx9():
None, None,
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(), AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
), ),
# Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 # Test Case 9: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
(
{"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
None,
AttentionBackendEnum.ROCM_ATTN.get_path(),
),
# Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
( (
{"VLLM_ROCM_USE_AITER": "1"}, {"VLLM_ROCM_USE_AITER": "1"},
"TRITON_ATTN", "TRITON_ATTN",
AttentionBackendEnum.TRITON_ATTN.get_path(), AttentionBackendEnum.TRITON_ATTN.get_path(),
), ),
# Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0 # Test Case 10: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
# (explicitly disabled) # (explicitly disabled)
( (
{"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"}, {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
None, None,
AttentionBackendEnum.TRITON_ATTN.get_path(), AttentionBackendEnum.TRITON_ATTN.get_path(),
), ),
# Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN # Test Case 11: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
( (
{"VLLM_ROCM_USE_AITER": "1"}, {"VLLM_ROCM_USE_AITER": "1"},
"ROCM_ATTN", "ROCM_ATTN",

View File

@ -249,8 +249,8 @@ def create_dummy_kv_cache(
@dataclass @dataclass
class BackendConfig: class BackendConfig:
name: str name: str
env_vars: dict attention_config: dict
comp_config: dict # compilation config comp_config: dict
specific_gpu_arch: tuple | None = None specific_gpu_arch: tuple | None = None
@ -259,10 +259,10 @@ full_cg_backend_configs = {
# FA3 on Hopper # FA3 on Hopper
"FA3": BackendConfig( "FA3": BackendConfig(
name="FA3", name="FA3",
env_vars={ attention_config={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN", "backend": "FLASH_ATTN",
"VLLM_FLASH_ATTN_VERSION": "3", "flash_attn_version": 3,
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", "flash_attn_max_num_splits_for_cuda_graph": 16,
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL", "cudagraph_mode": "FULL",
@ -272,9 +272,7 @@ full_cg_backend_configs = {
# FlashMLA on Hopper # FlashMLA on Hopper
"FlashMLA": BackendConfig( "FlashMLA": BackendConfig(
name="FlashMLA", name="FlashMLA",
env_vars={ attention_config={"backend": "FLASHMLA"},
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
@ -283,9 +281,7 @@ full_cg_backend_configs = {
# Cutlass MLA on Blackwell # Cutlass MLA on Blackwell
"CutlassMLA": BackendConfig( "CutlassMLA": BackendConfig(
name="CutlassMLA", name="CutlassMLA",
env_vars={ attention_config={"backend": "CUTLASS_MLA"},
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
@ -294,9 +290,7 @@ full_cg_backend_configs = {
# FlashInfer MLA on Blackwell # FlashInfer MLA on Blackwell
"FlashInferMLA": BackendConfig( "FlashInferMLA": BackendConfig(
name="FlashInferMLA", name="FlashInferMLA",
env_vars={ attention_config={"backend": "FLASHINFER_MLA"},
"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
@ -305,9 +299,9 @@ full_cg_backend_configs = {
# FlashAttention MLA on Hopper # FlashAttention MLA on Hopper
"FlashAttentionMLA": BackendConfig( "FlashAttentionMLA": BackendConfig(
name="FlashAttentionMLA", name="FlashAttentionMLA",
env_vars={ attention_config={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", "backend": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", "flash_attn_max_num_splits_for_cuda_graph": 16,
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_mode": "FULL_DECODE_ONLY",
@ -317,10 +311,10 @@ full_cg_backend_configs = {
# FA2 # FA2
"FA2": BackendConfig( "FA2": BackendConfig(
name="FA2", name="FA2",
env_vars={ attention_config={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN", "backend": "FLASH_ATTN",
"VLLM_FLASH_ATTN_VERSION": "2", "flash_attn_version": 2,
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", "flash_attn_max_num_splits_for_cuda_graph": 16,
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
@ -329,7 +323,7 @@ full_cg_backend_configs = {
# Triton Attention # Triton Attention
"TritonAttn": BackendConfig( "TritonAttn": BackendConfig(
name="TritonAttn", name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"}, attention_config={"backend": "TRITON_ATTN"},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
@ -337,14 +331,17 @@ full_cg_backend_configs = {
# FlashInfer # FlashInfer
"FlashInfer": BackendConfig( "FlashInfer": BackendConfig(
name="FlashInfer", name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, attention_config={"backend": "FLASHINFER"},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
), ),
"RocmAttn": BackendConfig( "RocmAttn": BackendConfig(
name="RocmAttn", name="RocmAttn",
env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"}, attention_config={
"backend": "ROCM_ATTN",
"use_prefill_decode_attention": True,
},
comp_config={ comp_config={
"cudagraph_mode": "FULL", "cudagraph_mode": "FULL",
}, },

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
@ -23,7 +24,7 @@ class MockRequest:
) )
self.mm_features.append(feature) self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int: def get_num_encoder_embeds(self, input_id: int) -> int:
return self._token_counts[input_id] return self._token_counts[input_id]
@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit():
num_tokens_to_schedule = 0 num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0) num_tokens_to_schedule += req.get_num_encoder_embeds(0)
compute_budget -= req.get_num_encoder_tokens(0) compute_budget -= req.get_num_encoder_embeds(0)
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit():
compute_budget = 10 compute_budget = 10
num_tokens_to_schedule = 0 num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0) num_tokens_to_schedule += req.get_num_encoder_embeds(0)
compute_budget -= req.get_num_encoder_tokens(0) compute_budget -= req.get_num_encoder_embeds(0)
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
def test_encoder_cache_with_is_embed_mask():
    """Allocate a masked image and check how many cache slots it uses.

    The placeholder spans 100 positions but only 8 are marked as embeds,
    so a cache of size 100 is expected to have 92 free slots afterwards.
    """

    class MockRequestWithMask(MockRequest):
        # Report the embed count from the placeholder instead of the
        # token counts handed to the MockRequest constructor.
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    embed_positions = torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])
    is_embed = torch.zeros(100, dtype=torch.bool)
    is_embed[embed_positions] = True

    position = PlaceholderRange(offset=0, length=100, is_embed=is_embed)
    request = MockRequestWithMask("r1", ["img1"], [100])
    request.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=position,
    )

    manager = EncoderCacheManager(cache_size=100)
    manager.allocate(request, 0)

    assert manager.num_free_slots == 92
    assert "img1" in manager.cached

    # 100 placeholder positions versus the 8 that are actual embeds.
    old_size = 100
    new_size = request.mm_features[0].mm_position.get_num_embeds
    assert new_size == 8
    assert old_size / new_size == 12.5
def test_encoder_cache_mask_based_retrieval():
    """Allocate a masked image and check embed counts over two windows.

    The mask marks 5 of 10 positions. After allocation, the number of
    embeds before and inside each ``[start_idx, end_idx)`` window is
    computed directly from the mask and compared with fixed values.
    """

    class MockRequestWithMask(MockRequest):
        # Report the embed count from the placeholder instead of the
        # token counts handed to the MockRequest constructor.
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    is_embed = torch.tensor(
        [False, False, True, True, False, True, True, True, False, False]
    )

    position = PlaceholderRange(offset=0, length=10, is_embed=is_embed)
    request = MockRequestWithMask("r1", ["img1"], [10])
    request.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=position,
    )

    manager = EncoderCacheManager(cache_size=50)
    manager.allocate(request, 0)

    assert request.mm_features[0].mm_position.get_num_embeds == 5

    # (start_idx, end_idx, embeds before start, embeds inside the window)
    windows = [
        (2, 8, 0, 5),
        (0, 5, 0, 2),
    ]
    for start_idx, end_idx, expected_before, expected_in_range in windows:
        # An empty prefix (start_idx == 0) sums to 0.
        num_embeds_before = is_embed[:start_idx].sum().item()
        num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
        assert num_embeds_before == expected_before
        assert num_embeds_in_range == expected_in_range

View File

@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref import weakref
from contextlib import ExitStack from contextlib import ExitStack
@ -13,26 +11,6 @@ from vllm import LLM
from vllm.config import CompilationConfig, CompilationMode from vllm.config import CompilationConfig, CompilationMode
from vllm.platforms import current_platform from vllm.platforms import current_platform
@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Apply ``env_vars`` to ``os.environ`` for the duration of the block.

    On exit (including when the block raises) every touched variable is
    put back to its earlier value, or removed if it was not set before.
    This is used in place of monkeypatch because monkeypatch doesn't work
    with "module" scoped fixtures.
    """
    # Snapshot first; None records "was not set before".
    saved = {name: os.environ.get(name) for name in env_vars}
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for name, previous in saved.items():
            if previous is not None:
                os.environ[name] = previous
            else:
                os.environ.pop(name, None)
# test attention backend and cudagraph_mode combo # test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported) # (backend_name, cudagraph_mode, supported)
if current_platform.is_rocm(): if current_platform.is_rocm():
@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
): ):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
env_vars = backend_configs[backend_name].env_vars attention_config = backend_config.attention_config
with temporary_environ(env_vars), ExitStack() as stack: with ExitStack() as stack:
if not supported: if not supported:
stack.enter_context(pytest.raises(Exception)) stack.enter_context(pytest.raises(Exception))
@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
trust_remote_code=True, trust_remote_code=True,
gpu_memory_utilization=0.45, gpu_memory_utilization=0.45,
max_model_len=1024, max_model_len=1024,
attention_config=attention_config,
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
), ),
@ -122,9 +101,10 @@ combo_cases_2 = [
def test_cudagraph_compilation_combo( def test_cudagraph_compilation_combo(
backend_name, cudagraph_mode, compilation_mode, supported backend_name, cudagraph_mode, compilation_mode, supported
): ):
env_vars = backend_configs[backend_name].env_vars backend_config = backend_configs[backend_name]
attention_config = backend_config.attention_config
with temporary_environ(env_vars), ExitStack() as stack: with ExitStack() as stack:
if not supported: if not supported:
stack.enter_context(pytest.raises(Exception)) stack.enter_context(pytest.raises(Exception))
@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo(
trust_remote_code=True, trust_remote_code=True,
gpu_memory_utilization=0.45, gpu_memory_utilization=0.45,
max_model_len=1024, max_model_len=1024,
attention_config=attention_config,
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=compilation_mode, cudagraph_mode=cudagraph_mode mode=compilation_mode, cudagraph_mode=cudagraph_mode
), ),

View File

@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
BACKENDS, BACKENDS,
) )
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
backend, monkeypatch: pytest.MonkeyPatch backend,
): ):
""" """
Ensures that the same request (the 'needle' prompt) yields identical output Ensures that the same request (the 'needle' prompt) yields identical output
@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) attention_config = {"backend": backend}
# Allow overrides from environment (useful for CI tuning) # Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism # "facebook/opt-125m" is too small, doesn't reliably test determinism
model = resolve_model_name(backend) model = resolve_model_name(backend)
@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
max_num_seqs=max_batch_size, max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util, gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len, max_model_len=max_model_len,
attention_config=attention_config,
) )
# Baseline generation for the needle prompt alone. # Baseline generation for the needle prompt alone.
@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
max_num_seqs=max_batch_size, max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util, gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len, max_model_len=max_model_len,
attention_config=attention_config,
) )
mismatches = 0 mismatches = 0
@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
BACKENDS, BACKENDS,
) )
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend, monkeypatch: pytest.MonkeyPatch backend,
): ):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = resolve_model_name(backend) model_name = resolve_model_name(backend)
@ -188,12 +188,12 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm = LLM( llm = LLM(
model=model_name, model=model_name,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
# enable_prefix_caching=False,
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", # not everything is supported dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
# Use more realistic prompts for better token generation # Use more realistic prompts for better token generation
@ -382,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
"backend", "backend",
BACKENDS, BACKENDS,
) )
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): def test_simple_generation(backend):
""" """
Simple test that runs the model with a basic prompt and prints the output. Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging. Useful for quick smoke testing and debugging.
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model = resolve_model_name(backend) model = resolve_model_name(backend)
llm = LLM( llm = LLM(
@ -399,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
dtype="bfloat16", dtype="bfloat16",
enable_prefix_caching=False, enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
prompt = "the capital of france is" prompt = "the capital of france is"
@ -445,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail(
The test will PASS if we detect differences (proving batch invariance matters). The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed). The test will FAIL if everything matches (suggesting batch invariance isn't needed).
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# CRITICAL: Disable batch invariance for this test # CRITICAL: Disable batch invariance for this test
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0") monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False) monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
@ -466,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
# build ragged prompts to change shapes significantly across BS=1 vs BS=N # build ragged prompts to change shapes significantly across BS=1 vs BS=N
@ -650,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail(
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
def test_decode_logprobs_match_prefill_logprobs( def test_decode_logprobs_match_prefill_logprobs(
backend, monkeypatch: pytest.MonkeyPatch backend,
): ):
""" """
Test that verifies decode logprobs match prefill logprobs. Test that verifies decode logprobs match prefill logprobs.
@ -665,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs(
This ensures that the logprobs from decode are consistent with what This ensures that the logprobs from decode are consistent with what
we would get if we ran prefill on each prefix. we would get if we ran prefill on each prefix.
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = resolve_model_name(backend) model_name = resolve_model_name(backend)
@ -690,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs(
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
# Use a few test prompts # Use a few test prompts
@ -921,6 +919,7 @@ def LLM_with_max_seqs(
max_num_seqs: int, max_num_seqs: int,
gpu_memory_utilization: float, gpu_memory_utilization: float,
max_model_len: int, max_model_len: int,
attention_config: dict | None = None,
) -> LLM: ) -> LLM:
""" """
Helper to construct an LLM with a specific max_num_seqs (batch-size limit) Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
@ -935,6 +934,7 @@ def LLM_with_max_seqs(
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False, enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config=attention_config,
# Enable for MOE models # Enable for MOE models
# enable_expert_parallel=True, # enable_expert_parallel=True,
) )

View File

@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process(
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("backend", BACKENDS)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend: str, monkeypatch: pytest.MonkeyPatch backend: str,
) -> None: ) -> None:
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345"))) random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
# Override backend for this test (and the RemoteOpenAIServer child process).
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model_name = resolve_model_name(backend) model_name = resolve_model_name(backend)
prompts_all = [_random_prompt(10, 50) for _ in range(32)] prompts_all = [_random_prompt(10, 50) for _ in range(32)]
@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
server_args: list[str] = [ server_args: list[str] = [
"--max-model-len=8192", "--max-model-len=8192",
"--max-num-seqs=32", "--max-num-seqs=32",
f"--attention-backend={backend}",
] ]
if tp_size: if tp_size:
server_args += ["-tp", tp_size] server_args += ["-tp", tp_size]

View File

@ -142,16 +142,17 @@ def run_tests(
"""Test consistency of combos of async scheduling, preemption, """Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding.""" uni/multiproc executor with spec decoding."""
with monkeypatch.context() as m: # Determine attention config based on platform
# avoid precision errors if current_platform.is_rocm():
if current_platform.is_rocm(): if is_testing_with_spec_decoding:
if is_testing_with_spec_decoding: # Use TRITON_ATTN for spec decoding test for consistency
# Use TRITON_ATTN for spec decoding test for consistency attention_config = {"backend": "TRITON_ATTN"}
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
else:
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
else: else:
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") attention_config = {"backend": "ROCM_AITER_FA"}
else:
attention_config = {"backend": "FLEX_ATTENTION"}
with monkeypatch.context() as m:
# lock matmul precision to full FP32 (IEEE) # lock matmul precision to full FP32 (IEEE)
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee") m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
# m.setenv("VLLM_BATCH_INVARIANT", "1") # m.setenv("VLLM_BATCH_INVARIANT", "1")
@ -174,6 +175,7 @@ def run_tests(
spec_config, spec_config,
test_prefill_chunking=test_prefill_chunking, test_prefill_chunking=test_prefill_chunking,
is_testing_with_spec_decoding=is_testing_with_spec_decoding, is_testing_with_spec_decoding=is_testing_with_spec_decoding,
attention_config=attention_config,
) )
outputs.append(test_results) outputs.append(test_results)
@ -262,6 +264,7 @@ def run_test(
spec_config: dict[str, Any] | None, spec_config: dict[str, Any] | None,
test_prefill_chunking: bool, test_prefill_chunking: bool,
is_testing_with_spec_decoding: bool = False, is_testing_with_spec_decoding: bool = False,
attention_config: dict[str, Any] | None = None,
): ):
spec_decoding = spec_config is not None spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = ( cache_arg: dict[str, Any] = (
@ -301,6 +304,7 @@ def run_test(
dtype=dtype, dtype=dtype,
speculative_config=spec_config, speculative_config=spec_config,
disable_log_stats=False, disable_log_stats=False,
attention_config=attention_config,
**cache_arg, **cache_arg,
) as vllm_model: ) as vllm_model:
results = [] results = []

View File

@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"]) @pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_cascade_attention(example_system_message, monkeypatch, attn_backend): def test_cascade_attention(example_system_message, attn_backend):
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:" prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
if attn_backend == "FLASHINFER": if attn_backend == "FLASHINFER":
@ -19,19 +19,18 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
"needs investigation. See issue #25679." "needs investigation. See issue #25679."
) )
with monkeypatch.context() as m: llm = LLM(
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) model="Qwen/Qwen2-1.5B-Instruct", attention_config={"backend": attn_backend}
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct") # No cascade attention.
sampling_params = SamplingParams(temperature=0.0, max_tokens=100) single_prompt = [example_system_message + prompt]
responses = llm.generate(single_prompt, sampling_params)
ref_output = responses[0].outputs[0].text
# No cascade attention. # (Probably) Use cascade attention.
single_prompt = [example_system_message + prompt] prompts = [example_system_message + prompt] * 64
responses = llm.generate(single_prompt, sampling_params) responses = llm.generate(prompts, sampling_params)
ref_output = responses[0].outputs[0].text for response in responses:
assert response.outputs[0].text == ref_output
# (Probably) Use cascade attention.
prompts = [example_system_message + prompt] * 64
responses = llm.generate(prompts, sampling_params)
for response in responses:
assert response.outputs[0].text == ref_output

View File

@ -438,25 +438,26 @@ def test_eagle_correctness(
should be the same when using eagle speculative decoding. should be the same when using eagle speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size) model_setup: (method, model_name, eagle_model_name, tp_size)
""" """
# Determine attention config
# Scout requires default backend selection because vision encoder has
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
# to Flex Attn
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
attention_config = None # Let it fall back to default
else:
attention_config = {"backend": attn_backend}
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
with monkeypatch.context() as m: with monkeypatch.context() as m:
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN": m.setenv("VLLM_MLA_DISABLE", "1")
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
# pass if not ROCm
if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
else:
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
if "deepseek" in model_setup[1].lower(): if "deepseek" in model_setup[1].lower():
@ -471,7 +472,10 @@ def test_eagle_correctness(
max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
ref_llm = LLM( ref_llm = LLM(
model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size model=model_name,
max_model_len=max_model_len,
tensor_parallel_size=tp_size,
attention_config=attention_config,
) )
ref_outputs = ref_llm.chat(test_prompts, sampling_config) ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm del ref_llm
@ -492,6 +496,7 @@ def test_eagle_correctness(
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
model_impl=model_impl, model_impl=model_impl,
attention_config=attention_config,
) )
spec_outputs = spec_llm.chat(test_prompts, sampling_config) spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0 matches = 0

View File

@ -38,7 +38,7 @@ class MockRequest:
) )
self.mm_features.append(feature) self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int: def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self._token_counts) assert input_id < len(self._token_counts)
return self._token_counts[input_id] return self._token_counts[input_id]

View File

@ -3,21 +3,29 @@ set -xe
# Parse command line arguments # Parse command line arguments
KV_BUFFER_DEVICE="cuda" # Default to cuda KV_BUFFER_DEVICE="cuda" # Default to cuda
ATTENTION_BACKEND="" # Default to empty (use vllm default)
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
case $1 in case $1 in
--kv_buffer_device) --kv_buffer_device)
KV_BUFFER_DEVICE="$2" KV_BUFFER_DEVICE="$2"
shift 2 shift 2
;; ;;
--attention-backend)
ATTENTION_BACKEND="$2"
shift 2
;;
*) *)
echo "Unknown option $1" echo "Unknown option $1"
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]" echo "Usage: $0 [--kv_buffer_device <cuda|cpu>] [--attention-backend <backend>]"
exit 1 exit 1
;; ;;
esac esac
done done
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE" echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if [[ -n "$ATTENTION_BACKEND" ]]; then
echo "Using attention backend: $ATTENTION_BACKEND"
fi
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
@ -148,6 +156,11 @@ run_tests_for_model() {
--tensor-parallel-size $PREFILLER_TP_SIZE \ --tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'" --kv-transfer-config '$KV_CONFIG'"
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
if [ -n "$model_args" ]; then if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args" FULL_CMD="$BASE_CMD $model_args"
else else
@ -188,7 +201,12 @@ run_tests_for_model() {
--block-size ${DECODE_BLOCK_SIZE} \ --block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'" --kv-transfer-config '$KV_CONFIG'"
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
# DP-EP attention mode # DP-EP attention mode
if [[ -z "$DP_EP" ]]; then if [[ -z "$DP_EP" ]]; then
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE" BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"

View File

@ -15,14 +15,14 @@ configs=(
run_tests() { run_tests() {
local label=$1 local label=$1
local extra_env=$2 local extra_args=$2
echo "=== Running tests (${label}) ===" echo "=== Running tests (${label}) ==="
for cfg in "${configs[@]}"; do for cfg in "${configs[@]}"; do
echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}" echo "-> Running with ${cfg} ${extra_args:+and ${extra_args}}"
# Use 'env' to safely set variables without eval # Use 'env' to safely set variables without eval
if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then if ! env ${cfg} bash "${SCRIPT}" ${extra_args}; then
echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}" echo "❌ Test failed for config: ${cfg} ${extra_args:+(${extra_args})}"
exit 1 exit 1
fi fi
done done
@ -34,8 +34,8 @@ run_tests "default backend" ""
# Check if FLASHINFER is set (non-empty) # Check if FLASHINFER is set (non-empty)
if [[ -n "${FLASHINFER:-}" ]]; then if [[ -n "${FLASHINFER:-}" ]]; then
echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER" echo "FLASHINFER is set, rerunning with --attention-backend FLASHINFER"
run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER" run_tests "FLASHINFER backend" "--attention-backend FLASHINFER"
else else
echo "FLASHINFER not set, skipping FLASHINFER runs." echo "FLASHINFER not set, skipping FLASHINFER runs."
fi fi

View File

@ -1132,7 +1132,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
"TRITON_ATTN", "TRITON_ATTN",
], ],
) )
def test_register_kv_caches(dist_init, attn_backend, monkeypatch): def test_register_kv_caches(dist_init, attn_backend):
""" """
Test that register_kv_caches() properly calls nixl_wrapper methods with Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data. correct data.
@ -1144,9 +1144,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
block layout info block layout info
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) vllm_config = create_vllm_config(attention_backend=attn_backend)
vllm_config = create_vllm_config()
# Import the appropriate backend based on the parameter # Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN": if attn_backend == "FLASH_ATTN":

View File

@ -11,6 +11,7 @@ import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import ( from vllm.config import (
AttentionConfig,
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
KVTransferConfig, KVTransferConfig,
@ -94,6 +95,7 @@ def create_vllm_config(
dtype: str = "float16", dtype: str = "float16",
cache_dtype: str = "auto", cache_dtype: str = "auto",
hf_overrides: dict[str, Any] | None = None, hf_overrides: dict[str, Any] | None = None,
attention_backend: str | None = None,
) -> VllmConfig: ) -> VllmConfig:
"""Initialize VllmConfig For Testing.""" """Initialize VllmConfig For Testing."""
model_config = ModelConfig( model_config = ModelConfig(
@ -124,12 +126,14 @@ def create_vllm_config(
enable_permute_local_kv=enable_permute_local_kv, enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {}, kv_connector_extra_config=kv_connector_extra_config or {},
) )
attention_config = AttentionConfig(backend=attention_backend)
return VllmConfig( return VllmConfig(
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
kv_transfer_config=kv_transfer_config, kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"), device_config=DeviceConfig("cpu"),
attention_config=attention_config,
) )

View File

@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
CPU_BLOCK_SIZES = [48] CPU_BLOCK_SIZES = [48]
ATTN_BACKENDS = ["FLASH_ATTN"] ATTN_BACKENDS = ["FLASH_ATTN"]
@ -180,13 +179,13 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
topic="test", topic="test",
) )
with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend): llm = LLM(
llm = LLM( model="meta-llama/Llama-3.2-1B-Instruct",
model="meta-llama/Llama-3.2-1B-Instruct", gpu_memory_utilization=0.5,
gpu_memory_utilization=0.5, kv_events_config=kv_events_config,
kv_events_config=kv_events_config, kv_transfer_config=kv_transfer_config,
kv_transfer_config=kv_transfer_config, attention_config={"backend": attn_backend},
) )
events_endpoint = events_endpoint.replace("*", "127.0.0.1") events_endpoint = events_endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic) subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)

Some files were not shown because too many files have changed in this diff Show More