mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-01 15:27:08 +08:00
Merge branch 'main' into fix-blockstored-kvevent
This commit is contained in:
commit
a3a2971c96
@ -141,7 +141,6 @@ if [[ $commands == *" entrypoints/openai "* ]]; then
|
||||
--ignore=entrypoints/openai/test_audio.py \
|
||||
--ignore=entrypoints/openai/test_shutdown.py \
|
||||
--ignore=entrypoints/openai/test_completion.py \
|
||||
--ignore=entrypoints/openai/test_sleep.py \
|
||||
--ignore=entrypoints/openai/test_models.py \
|
||||
--ignore=entrypoints/openai/test_lora_adapters.py \
|
||||
--ignore=entrypoints/openai/test_return_tokens_as_ids.py \
|
||||
|
||||
@ -39,7 +39,7 @@ docker run \
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
|
||||
python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager
|
||||
VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
|
||||
cd tests
|
||||
pytest -v -s v1/core
|
||||
pytest -v -s v1/engine
|
||||
|
||||
@ -128,7 +128,7 @@ steps:
|
||||
- tests/entrypoints/
|
||||
commands:
|
||||
- pytest -v -s entrypoints/openai/tool_parsers
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/rpc --ignore=entrypoints/sleep --ignore=entrypoints/instrumentator --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
|
||||
- label: Entrypoints Integration Test (LLM) # 30min
|
||||
timeout_in_minutes: 40
|
||||
@ -148,7 +148,7 @@ steps:
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Entrypoints Integration Test (API Server) # 100min
|
||||
- label: Entrypoints Integration Test (API Server 1) # 100min
|
||||
timeout_in_minutes: 130
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
@ -162,10 +162,28 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: Entrypoints Integration Test (API Server 2)
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/sleep
|
||||
- tests/entrypoints/rpc
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/sleep
|
||||
- pytest -v -s tool_use
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
|
||||
|
||||
- label: Entrypoints Integration Test (Pooling)
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -722,7 +740,7 @@ steps:
|
||||
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
|
||||
# we can only upgrade after this is resolved
|
||||
# TODO(jerryzh168): resolve the above comment
|
||||
- uv pip install --system torchao==0.13.0
|
||||
- uv pip install --system torchao==0.14.1
|
||||
- uv pip install --system conch-triton-kernels
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
|
||||
|
||||
@ -751,17 +769,6 @@ steps:
|
||||
# Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442
|
||||
- pytest -s entrypoints/openai/correctness/
|
||||
|
||||
- label: OpenAI-Compatible Tool Use # 23 min
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
fast_check: false
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- pytest -v -s tool_use
|
||||
|
||||
##### models test #####
|
||||
|
||||
|
||||
@ -114,7 +114,7 @@ steps:
|
||||
- tests/entrypoints/
|
||||
commands:
|
||||
- pytest -v -s entrypoints/openai/tool_parsers
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/rpc --ignore=entrypoints/sleep --ignore=entrypoints/instrumentator --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
|
||||
- label: Entrypoints Integration Test (LLM) # 30min
|
||||
timeout_in_minutes: 40
|
||||
@ -132,7 +132,7 @@ steps:
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Entrypoints Integration Test (API Server) # 100min
|
||||
- label: Entrypoints Integration Test (API Server 1) # 100min
|
||||
timeout_in_minutes: 130
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
@ -144,10 +144,26 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: Entrypoints Integration Test (API Server 2)
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/sleep
|
||||
- tests/entrypoints/rpc
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/sleep
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
|
||||
- pytest -v -s tool_use
|
||||
|
||||
- label: Entrypoints Integration Test (Pooling)
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -642,7 +658,7 @@ steps:
|
||||
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
|
||||
# we can only upgrade after this is resolved
|
||||
# TODO(jerryzh168): resolve the above comment
|
||||
- uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129
|
||||
- uv pip install --system torchao==0.14.1 --index-url https://download.pytorch.org/whl/cu129
|
||||
- uv pip install --system conch-triton-kernels
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
|
||||
|
||||
@ -654,7 +670,7 @@ steps:
|
||||
- vllm/model_executor/layers/quantization
|
||||
autorun_on_main: true
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
||||
|
||||
- label: OpenAI API correctness # 22min
|
||||
timeout_in_minutes: 30
|
||||
@ -666,16 +682,6 @@ steps:
|
||||
commands: # LMEval+Transcription WER check
|
||||
- pytest -s entrypoints/openai/correctness/
|
||||
|
||||
- label: OpenAI-Compatible Tool Use # 23 min
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental]
|
||||
fast_check: false
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- pytest -v -s tool_use
|
||||
|
||||
##### models test #####
|
||||
|
||||
- label: Basic Models Tests (Initialization)
|
||||
@ -1064,7 +1070,7 @@ steps:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
@ -32,6 +32,7 @@ steps:
|
||||
- label: Prime-RL Integration (2 GPUs)
|
||||
timeout_in_minutes: 30
|
||||
optional: true
|
||||
soft_fail: true
|
||||
num_gpus: 2
|
||||
working_dir: "/vllm-workspace"
|
||||
source_file_dependencies:
|
||||
@ -39,21 +40,3 @@ steps:
|
||||
- .buildkite/scripts/run-prime-rl-test.sh
|
||||
commands:
|
||||
- bash .buildkite/scripts/run-prime-rl-test.sh
|
||||
|
||||
- label: DeepSeek V2-Lite Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
|
||||
|
||||
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
||||
|
||||
@ -10,7 +10,7 @@ steps:
|
||||
- tests/entrypoints/
|
||||
commands:
|
||||
- pytest -v -s entrypoints/openai/tool_parsers
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/rpc --ignore=entrypoints/sleep --ignore=entrypoints/instrumentator --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
|
||||
- label: Entrypoints Integration (LLM)
|
||||
timeout_in_minutes: 40
|
||||
@ -25,7 +25,7 @@ steps:
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Entrypoints Integration (API Server)
|
||||
- label: Entrypoints Integration (API Server 1)
|
||||
timeout_in_minutes: 130
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
source_file_dependencies:
|
||||
@ -34,11 +34,26 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
|
||||
- label: Entrypoints Integration (API Server 2)
|
||||
timeout_in_minutes: 130
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
- tests/entrypoints/sleep
|
||||
- tests/entrypoints/instrumentator
|
||||
- tests/entrypoints/rpc
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
|
||||
- pytest -v -s entrypoints/instrumentator
|
||||
- pytest -v -s entrypoints/sleep
|
||||
- pytest -v -s tool_use
|
||||
|
||||
- label: Entrypoints Integration (Pooling)
|
||||
timeout_in_minutes: 50
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
|
||||
@ -9,7 +9,7 @@ steps:
|
||||
- vllm/model_executor/layers/quantization
|
||||
autorun_on_main: true
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
||||
|
||||
- label: LM Eval Large Models (4 GPUs)(A100)
|
||||
gpu: a100
|
||||
@ -43,4 +43,4 @@ steps:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt
|
||||
|
||||
@ -22,6 +22,8 @@ steps:
|
||||
# FIXIT: find out which code initializes CUDA before running the test
|
||||
# before the fix, we need to use spawn to test it
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
# A lot of these tests are on the edge of OOMing
|
||||
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
# There is some Tensor Parallelism related processing logic in LoRA that
|
||||
# requires multi-GPU testing for validation.
|
||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||
|
||||
@ -9,6 +9,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/test_initialization.py
|
||||
- tests/models/registry.py
|
||||
commands:
|
||||
# Run a subset of model initialization tests
|
||||
- pytest -v -s models/test_initialization.py::test_can_initialize_small_subset
|
||||
@ -20,6 +21,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/models/
|
||||
- tests/models/test_initialization.py
|
||||
- tests/models/registry.py
|
||||
commands:
|
||||
# Only when vLLM model source is modified - test initialization of a large
|
||||
# subset of supported models (the complement of the small subset in the above
|
||||
|
||||
@ -13,7 +13,9 @@ steps:
|
||||
# tests covered elsewhere.
|
||||
# Use `find` to launch multiple instances of pytest so that
|
||||
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965
|
||||
- "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\;"
|
||||
# However, find does not normally propagate error codes, so we combine it with xargs
|
||||
# (using -0 for proper path handling)
|
||||
- "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test
|
||||
timeout_in_minutes: 30
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
group: Tool use
|
||||
depends_on:
|
||||
- image-build
|
||||
steps:
|
||||
- label: OpenAI-Compatible Tool Use
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental]
|
||||
fast_check: false
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- pytest -v -s tool_use
|
||||
14
.github/mergify.yml
vendored
14
.github/mergify.yml
vendored
@ -235,6 +235,20 @@ pull_request_rules:
|
||||
add:
|
||||
- rocm
|
||||
|
||||
- name: label-cpu
|
||||
description: Automatically apply cpu label
|
||||
conditions:
|
||||
- label != stale
|
||||
- files~=^(?!.*kv_offload)(?!.*cpu_offload).*\bcpu.*
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- cpu
|
||||
assign:
|
||||
users:
|
||||
- "fadara01"
|
||||
- "aditew01"
|
||||
|
||||
- name: label-structured-output
|
||||
description: Automatically apply structured-output label
|
||||
conditions:
|
||||
|
||||
113
CMakeLists.txt
113
CMakeLists.txt
@ -56,8 +56,8 @@ endif()
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from docker/Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.9.0")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.9.0")
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.9.1")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.9.1")
|
||||
|
||||
#
|
||||
# Try to find python package with an executable that exactly matches
|
||||
@ -357,6 +357,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# marlin arches for fp16 output
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
|
||||
# marlin has limited support for turing
|
||||
cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
|
||||
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
|
||||
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
||||
# marlin arches for fp8 input
|
||||
@ -364,8 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
|
||||
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
|
||||
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
|
||||
# marlin arches for other files
|
||||
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
|
||||
|
||||
if (MARLIN_ARCHS)
|
||||
if (MARLIN_OTHER_ARCHS)
|
||||
|
||||
#
|
||||
# For the Marlin kernels we automatically generate sources for various
|
||||
@ -406,25 +410,39 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Marlin generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
if (MARLIN_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
|
||||
endif()
|
||||
|
||||
if (MARLIN_SM75_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_SM75_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_SM75_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_SM75_KERNEL_SRC})
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
|
||||
|
||||
if (MARLIN_FP8_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
|
||||
@ -446,14 +464,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_SRCS}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
CUDA_ARCHS "${MARLIN_OTHER_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
set_source_files_properties(${MARLIN_SRCS}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
|
||||
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_OTHER_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
@ -980,12 +998,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# note that we always set `use_atomic_add=False` for moe marlin now,
|
||||
# so we don't need 9.0 for bf16 atomicAdd PTX
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
|
||||
# moe marlin has limited support for turing
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
|
||||
# moe marlin arches for fp8 input
|
||||
# - sm80 doesn't support fp8 computation
|
||||
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
|
||||
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
# moe marlin arches for other files
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_OTHER_ARCHS)
|
||||
|
||||
#
|
||||
# For the Marlin MOE kernels we automatically generate sources for various
|
||||
@ -1026,16 +1048,29 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
|
||||
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
|
||||
endif()
|
||||
|
||||
if (MARLIN_MOE_SM75_ARCHS)
|
||||
file(GLOB MARLIN_MOE_SM75_SRC "csrc/moe/marlin_moe_wna16/sm75_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SM75_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_SM75_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SM75_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SM75_SRC})
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
|
||||
|
||||
if (MARLIN_MOE_FP8_ARCHS)
|
||||
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
|
||||
@ -1049,7 +1084,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
|
||||
set(MARLIN_MOE_OTHER_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_OTHER_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_OTHER_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_OTHER_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_OTHER_SRC}")
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_OTHER_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
|
||||
@ -13,8 +13,8 @@ from vllm.triton_utils import triton
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
batch_size_range = [1, 16, 32, 64, 128]
|
||||
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
batch_size_range = [1, 16, 128]
|
||||
seq_len_range = [1, 16, 64, 1024, 4096]
|
||||
intermediate_size = [3072, 9728, 12288]
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))
|
||||
|
||||
|
||||
@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
|
||||
const scalar_t& y) {
|
||||
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
|
||||
}
|
||||
// Activation and gating kernel template.
|
||||
|
||||
// Check if all pointers are 16-byte aligned for int4 vectorized access
|
||||
__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
|
||||
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
|
||||
}
|
||||
|
||||
// Activation and gating kernel template.
|
||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
||||
bool act_first>
|
||||
__global__ void act_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const int d) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
|
||||
const scalar_t* x_ptr = input + token_idx * 2 * d;
|
||||
const scalar_t* y_ptr = x_ptr + d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access.
|
||||
// All three pointers must be 16-byte aligned for safe int4 operations.
|
||||
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
|
||||
is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
|
||||
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
|
||||
auto* xp = reinterpret_cast<scalar_t*>(&x);
|
||||
auto* yp = reinterpret_cast<scalar_t*>(&y);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
|
||||
VLLM_LDG(&y_ptr[i]));
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
|
||||
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
|
||||
out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
|
||||
__global__ void act_and_mul_kernel_with_param(
|
||||
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
|
||||
const float param) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x, param) * y;
|
||||
const scalar_t* x_ptr = input + token_idx * 2 * d;
|
||||
const scalar_t* y_ptr = x_ptr + d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access
|
||||
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
|
||||
is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
|
||||
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
|
||||
auto* xp = reinterpret_cast<scalar_t*>(&x);
|
||||
auto* yp = reinterpret_cast<scalar_t*>(&y);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = ACT_FN(xp[j], param) * yp[j];
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
|
||||
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
|
||||
out_ptr[idx] = ACT_FN(x, param) * y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
|
||||
float alpha, float limit) {
|
||||
// clamp gate: min=None, max=limit
|
||||
const float gate_f = (float)gate;
|
||||
const float clamped_gate = gate_f > limit ? limit : gate_f;
|
||||
|
||||
// clamp up: min=-limit, max=limit
|
||||
const float up_f = (float)up;
|
||||
const float clamped_up =
|
||||
up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
|
||||
|
||||
// glu = gate * sigmoid(gate * alpha)
|
||||
const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
|
||||
const float glu = clamped_gate * sigmoid_val;
|
||||
|
||||
// (up + 1) * glu
|
||||
return (T)((clamped_up + 1.0f) * glu);
|
||||
// Clamp gate to (-inf, limit] and up to [-limit, limit]
|
||||
const float g = fminf((float)gate, limit);
|
||||
const float u = fmaxf(fminf((float)up, limit), -limit);
|
||||
// glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
|
||||
return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha)));
|
||||
}
|
||||
|
||||
// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
|
||||
template <typename scalar_t,
|
||||
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
|
||||
const float)>
|
||||
__global__ void swigluoai_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const scalar_t* __restrict__ input, // [..., 2 * d] (interleaved)
|
||||
const int d, const float alpha, const float limit) {
|
||||
// For interleaved data: input has 2*d elements per token (gate/up pairs)
|
||||
// output has d elements per token
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
constexpr int PAIRS = VEC_SIZE / 2; // Number of gate/up pairs per int4 load
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
// TODO: Vectorize loads and stores.
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
// gate = x[..., ::2] (even indices)
|
||||
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
|
||||
// up = x[..., 1::2] (odd indices)
|
||||
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
|
||||
const scalar_t* in_ptr = input + token_idx * 2 * d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
|
||||
// Check alignment for 128-bit vectorized access on input.
|
||||
// For output we use int2 (64-bit) which has 8-byte alignment requirement.
|
||||
const bool in_aligned = is_16byte_aligned(in_ptr);
|
||||
const bool out_aligned =
|
||||
(reinterpret_cast<uintptr_t>(out_ptr) & 7) == 0; // 8-byte for int2
|
||||
|
||||
if (in_aligned && out_aligned && d >= PAIRS) {
|
||||
// Fast path: vectorized loop
|
||||
// Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs
|
||||
// Each int2 store writes PAIRS output elements
|
||||
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
|
||||
int2* out_vec = reinterpret_cast<int2*>(out_ptr);
|
||||
const int num_vecs = d / PAIRS;
|
||||
const int vec_end = num_vecs * PAIRS;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 v = VLLM_LDG(&in_vec[i]);
|
||||
int2 r;
|
||||
auto* vp = reinterpret_cast<scalar_t*>(&v);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < PAIRS; j++) {
|
||||
rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]),
|
||||
VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit);
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
// gate = x[..., ::2] (even indices)
|
||||
const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]);
|
||||
// up = x[..., 1::2] (odd indices)
|
||||
const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]);
|
||||
out_ptr[idx] = ACT_FN(gate, up, alpha, limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,10 +324,41 @@ __global__ void activation_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., d]
|
||||
const int d) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x);
|
||||
const scalar_t* in_ptr = input + token_idx * d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access
|
||||
const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 v = VLLM_LDG(&in_vec[i]), r;
|
||||
auto* vp = reinterpret_cast<scalar_t*>(&v);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = ACT_FN(vp[j]);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&in_ptr[idx]);
|
||||
out_ptr[idx] = ACT_FN(x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -107,6 +107,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
|
||||
prop.location.id = device;
|
||||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
|
||||
|
||||
#ifndef USE_ROCM
|
||||
int flag = 0;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED,
|
||||
device));
|
||||
if (flag) { // support GPUDirect RDMA if possible
|
||||
prop.allocFlags.gpuDirectRDMACapable = 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Allocate memory using cuMemCreate
|
||||
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
|
||||
|
||||
@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) {
|
||||
|
||||
template <ScoringFunc SF, typename T>
|
||||
__device__ inline T apply_scoring(T val) {
|
||||
if constexpr (SF == SCORING_SIGMOID) {
|
||||
if constexpr (SF == SCORING_NONE) {
|
||||
return val;
|
||||
} else if constexpr (SF == SCORING_SIGMOID) {
|
||||
return apply_sigmoid(val);
|
||||
} else {
|
||||
static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
|
||||
"Unsupported ScoringFunc in apply_scoring");
|
||||
return val;
|
||||
}
|
||||
}
|
||||
@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
if (if_proceed_next_topk) {
|
||||
float scale = routed_scaling_factor;
|
||||
if (renormalize) {
|
||||
scale /= topk_sum;
|
||||
}
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float base = cuda_cast<float, T>(s_topk_value[i]);
|
||||
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
|
||||
: (base * routed_scaling_factor);
|
||||
float value = base * scale;
|
||||
topk_indices[i] = s_topk_idx[i];
|
||||
topk_values[i] = value;
|
||||
}
|
||||
|
||||
1
csrc/moe/marlin_moe_wna16/.gitignore
vendored
1
csrc/moe/marlin_moe_wna16/.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
kernel_*.cu
|
||||
|
||||
@ -10,6 +10,8 @@ import jinja2
|
||||
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
SUPPORT_SM75 = False
|
||||
SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
if arch == 75:
|
||||
SUPPORT_SM75 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
@ -157,6 +163,7 @@ def remove_old_kernels():
|
||||
|
||||
def generate_new_kernels():
|
||||
result_dict = {}
|
||||
sm_75_result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
@ -174,6 +181,8 @@ def generate_new_kernels():
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
|
||||
sm_75_result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
@ -197,78 +206,89 @@ def generate_new_kernels():
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"stages": 4,
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if SUPPORT_SM80:
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
|
||||
config_sm75 = config.copy()
|
||||
config_sm75["stages"] = 2
|
||||
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
for result_dict_tmp in [result_dict, sm_75_result_dict]:
|
||||
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
|
||||
all_template_str_list = []
|
||||
if not config_list:
|
||||
continue
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"stages == {config['stages']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
filename = filename.lower()
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
elif result_dict_tmp is sm_75_result_dict:
|
||||
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
#include "quantization/gptq_marlin/marlin.cuh"
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "quantization/gptq_marlin/dequant.h"
|
||||
#include "quantization/gptq_marlin/marlin_mma.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@ -35,7 +36,7 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
|
||||
@ -84,146 +85,6 @@ __global__ void Marlin(
|
||||
|
||||
#else
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
template <int count, vllm::ScalarTypeId type_id>
|
||||
@ -439,9 +300,20 @@ __global__ void Marlin(
|
||||
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
// Turing TensorCore only supports fp16 and int8
|
||||
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
|
||||
return;
|
||||
#endif
|
||||
|
||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
using Adtype = MarlinScalarType<a_type_id>;
|
||||
using Cdtype = MarlinScalarType<c_type_id>;
|
||||
|
||||
@ -618,7 +490,22 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
|
||||
if constexpr (moe_block_size >= 16)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16);
|
||||
if constexpr (moe_block_size >= 8)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8);
|
||||
if constexpr (moe_block_size >= 4)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4);
|
||||
if constexpr (moe_block_size >= 2)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2);
|
||||
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1);
|
||||
block_num_valid_tokens = local_count;
|
||||
#else
|
||||
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
|
||||
#endif
|
||||
|
||||
if (lane_id == 0)
|
||||
reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
|
||||
@ -1018,10 +905,6 @@ __global__ void Marlin(
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
// shared memory reused by reduction should be smaller than
|
||||
// shared memory used by weight.
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
@ -1545,11 +1428,13 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
} else {
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
|
||||
frag_c[i][j][0]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
|
||||
frag_c[i][j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1583,10 +1468,12 @@ __global__ void Marlin(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
}
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
@ -2132,6 +2019,21 @@ __global__ void Marlin(
|
||||
// While this pattern may not be the most readable, other ways of writing
|
||||
// the loop seemed to noticeably worse performance after compilation.
|
||||
if (slice_iters == 0) {
|
||||
// convert fp16 accum to fp32 for reduction
|
||||
if constexpr (use_fp16_accum) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
|
||||
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
|
||||
scalar_t* frag_c_part_half =
|
||||
reinterpret_cast<scalar_t*>(frag_c_part_float);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_a_8bit) {
|
||||
float frag_a_s[2 * thread_m_blocks];
|
||||
|
||||
|
||||
@ -142,7 +142,7 @@ typedef struct {
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool has_act_order, bool is_k_full, int stages) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
@ -160,13 +160,13 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
return tb_scales * stages;
|
||||
}
|
||||
}
|
||||
|
||||
@ -174,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n,
|
||||
int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full, int has_zp,
|
||||
int is_zp_float, bool is_a_8bit) {
|
||||
int is_zp_float, bool is_a_8bit, int stages) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
@ -185,8 +185,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
|
||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
int sh_block_meta_size = tb_m * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
@ -195,8 +195,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
group_size, has_act_order, is_k_full, stages);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
@ -217,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
|
||||
int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, int has_zp, int is_zp_float,
|
||||
int max_shared_mem, bool is_a_8bit) {
|
||||
bool is_a_8bit, int stages, int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
@ -243,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int cache_size =
|
||||
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
@ -252,7 +252,7 @@ MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int threads, bool is_zp_float, int stages) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
|
||||
@ -266,8 +266,8 @@ exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
|
||||
bool is_a_8bit) {
|
||||
bool is_k_full, bool has_zp, bool is_zp_float, bool is_a_8bit, int stages,
|
||||
int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@ -284,15 +284,15 @@ exec_config_t determine_exec_config(
|
||||
|
||||
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
|
||||
is_a_8bit)) {
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
is_a_8bit);
|
||||
is_a_8bit, stages);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
@ -303,7 +303,7 @@ exec_config_t determine_exec_config(
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
th_config.num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
@ -433,8 +433,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
|
||||
"marlin kernel only support Turing or newer GPUs.");
|
||||
int stages = 4;
|
||||
if (major_capability == 7 && minor_capability == 5) {
|
||||
stages = 2;
|
||||
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
|
||||
"Turing only support FP16 or INT8 activation.");
|
||||
}
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
|
||||
"FP8 only support Ada Lovelace or newer GPUs.");
|
||||
@ -461,8 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
|
||||
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
|
||||
is_a_8bit);
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem, sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
}
|
||||
|
||||
@ -479,7 +485,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
|
||||
prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem, is_a_8bit),
|
||||
is_a_8bit, stages, max_shared_mem),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
@ -493,12 +499,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int sh_cache_size =
|
||||
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
|
||||
1
csrc/quantization/gptq_marlin/.gitignore
vendored
1
csrc/quantization/gptq_marlin/.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
kernel_*.cu
|
||||
|
||||
@ -67,7 +67,7 @@ where `scale_factor * multiplier` can be computed at weight loading.
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
|
||||
@ -10,6 +10,8 @@ import jinja2
|
||||
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
SUPPORT_SM75 = False
|
||||
SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
if arch == 75:
|
||||
SUPPORT_SM75 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
@ -166,6 +172,7 @@ def remove_old_kernels():
|
||||
|
||||
def generate_new_kernels():
|
||||
result_dict = {}
|
||||
sm_75_result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
@ -184,6 +191,8 @@ def generate_new_kernels():
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
|
||||
sm_75_result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
@ -207,78 +216,89 @@ def generate_new_kernels():
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"stages": 4,
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "true" if is_zp_float else "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if SUPPORT_SM80:
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
|
||||
config_sm75 = config.copy()
|
||||
config_sm75["stages"] = 2
|
||||
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
for result_dict_tmp in [result_dict, sm_75_result_dict]:
|
||||
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
|
||||
all_template_str_list = []
|
||||
if not config_list:
|
||||
continue
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"stages == {config['stages']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
filename = filename.lower()
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
elif result_dict_tmp is sm_75_result_dict:
|
||||
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
|
||||
@ -37,7 +37,7 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
|
||||
|
||||
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int const* __restrict__ perm_int_ptr,
|
||||
@ -148,7 +148,7 @@ typedef struct {
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool has_act_order, bool is_k_full, int stages) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
@ -166,28 +166,29 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
return tb_scales * stages;
|
||||
}
|
||||
}
|
||||
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float) {
|
||||
int has_zp, bool is_zp_float, bool is_a_8bit,
|
||||
int stages) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
@ -196,8 +197,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
group_size, has_act_order, is_k_full, stages);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
@ -217,7 +218,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float, int max_shared_mem) {
|
||||
int has_zp, bool is_zp_float, bool is_a_8bit, int stages,
|
||||
int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
@ -242,7 +244,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
// Check that pipeline fits into cache
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
@ -251,7 +253,7 @@ MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int threads, bool is_zp_float, int stages) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
|
||||
@ -265,7 +267,8 @@ exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
|
||||
int num_bits, int group_size, bool has_act_order, bool is_k_full,
|
||||
bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
|
||||
bool has_zp, bool is_zp_float, int is_a_8bit, int stages,
|
||||
int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@ -280,13 +283,15 @@ exec_config_t determine_exec_config(
|
||||
|
||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, max_shared_mem - 512)) {
|
||||
is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
int cache_size = get_kernel_cache_size(th_config, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, is_a_8bit, stages);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
@ -297,14 +302,10 @@ exec_config_t determine_exec_config(
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
th_config.num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
// int m_tiles = div_ceil(prob_m, thread_m_blocks * 16);
|
||||
// int n_tiles = prob_n / th_config.thread_n;
|
||||
// int k_tiles = prob_k / th_config.thread_k;
|
||||
|
||||
return {1, th_config};
|
||||
}
|
||||
|
||||
@ -321,6 +322,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k_init,
|
||||
int thread_n_init, int sms, bool use_atomic_add,
|
||||
bool use_fp32_reduce, bool is_zp_float) {
|
||||
bool is_a_8bit = a_type.size_bits() == 8;
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
@ -389,8 +391,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
|
||||
"marlin kernel only support Turing or newer GPUs.");
|
||||
int stages = 4;
|
||||
if (major_capability == 7 && minor_capability == 5) {
|
||||
stages = 2;
|
||||
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
|
||||
"Turing only support FP16 or INT8 activation.");
|
||||
}
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(
|
||||
major_capability * 10 + minor_capability == 89 ||
|
||||
@ -431,7 +439,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
|
||||
thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages, max_shared_mem,
|
||||
sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
if (thread_tfg.thread_n != -1) {
|
||||
if (prob_n / thread_tfg.thread_n *
|
||||
@ -440,7 +449,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem_new)) {
|
||||
is_a_8bit, stages, max_shared_mem_new)) {
|
||||
thread_tfg = {128, 64, 128};
|
||||
exec_cfg = {1, thread_tfg};
|
||||
}
|
||||
@ -466,7 +475,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
TORCH_CHECK(
|
||||
is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n,
|
||||
prob_k, num_bits, group_size, has_act_order, is_k_full,
|
||||
has_zp, is_zp_float, max_shared_mem_new),
|
||||
has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem_new),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
@ -475,12 +485,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
", prob_m_split = ", prob_m_split, ", group_size = ", group_size,
|
||||
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem_new = ", max_shared_mem_new);
|
||||
", stages = ", stages, ", max_shared_mem_new = ", max_shared_mem_new);
|
||||
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
|
||||
@ -1,17 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#ifndef _marlin_cuh
|
||||
#define _marlin_cuh
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
@ -51,9 +53,51 @@ using I4 = Vec<int, 4>;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
// Fallback for targets with __CUDA_ARCH__ < 800: performs an ordinary
// (synchronous) copy of one int32_t from glob_ptr to smem_ptr instead of an
// asynchronous copy. Nothing is copied when pred is false.
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  if (!pred) return;
  const int32_t* src = static_cast<const int32_t*>(glob_ptr);
  int32_t* dst = static_cast<int32_t*>(smem_ptr);
  *dst = *src;
}
|
||||
|
||||
// Fallback for targets with __CUDA_ARCH__ < 800: performs an ordinary
// (synchronous) copy of one int64_t from glob_ptr to smem_ptr instead of an
// asynchronous copy. Nothing is copied when pred is false.
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  if (!pred) return;
  const int64_t* src = static_cast<const int64_t*>(glob_ptr);
  int64_t* dst = static_cast<int64_t*>(smem_ptr);
  *dst = *src;
}
|
||||
|
||||
// Fallback for targets with __CUDA_ARCH__ < 800: performs an ordinary
// (synchronous) copy of one int4 from glob_ptr to smem_ptr instead of an
// asynchronous copy. Nothing is copied when pred is false.
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  if (!pred) return;
  const int4* src = static_cast<const int4*>(glob_ptr);
  int4* dst = static_cast<int4*>(smem_ptr);
  *dst = *src;
}
|
||||
|
||||
// Fallback for targets with __CUDA_ARCH__ < 800: performs an ordinary
// (synchronous) copy of one int4 from glob_ptr to smem_ptr instead of an
// asynchronous copy. Nothing is copied when pred is false.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  if (!pred) return;
  const int4* src = static_cast<const int4*>(glob_ptr);
  int4* dst = static_cast<int4*>(smem_ptr);
  *dst = *src;
}
|
||||
|
||||
// Fallback for targets with __CUDA_ARCH__ < 800: unconditionally copies one
// int4 from glob_ptr to smem_ptr with an ordinary (synchronous) load/store.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int4* src = static_cast<const int4*>(glob_ptr);
  int4* dst = static_cast<int4*>(smem_ptr);
  *dst = *src;
}
|
||||
|
||||
// Fallback for targets with __CUDA_ARCH__ < 800: intentionally a no-op. The
// cp_async* fallbacks in this branch copy with plain assignments, so there is
// no outstanding copy group to commit.
__device__ inline void cp_async_fence() {}
|
||||
|
||||
// Fallback for targets with __CUDA_ARCH__ < 800: intentionally a no-op. The
// template argument n is accepted only so call sites match the sm80+ version
// (which waits via cp.async.wait_group); it is unused here.
template <int n>
|
||||
__device__ inline void cp_async_wait() {}
|
||||
|
||||
#else
|
||||
|
||||
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
@ -126,6 +170,8 @@ __device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
269
csrc/quantization/gptq_marlin/marlin_mma.h
Normal file
269
csrc/quantization/gptq_marlin/marlin_mma.h
Normal file
@ -0,0 +1,269 @@
|
||||
|
||||
#include "marlin_dtypes.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
|
||||
static_assert(!use_fp16_accum);
|
||||
}
|
||||
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
#else
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value &&
|
||||
use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
|
||||
#else
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[1]), "r"(b[0]), "r"(c[2]), "r"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[2]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[3]), "r"(b[1]), "r"(c[2]), "r"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
|
||||
static_assert(!use_fp16_accum);
|
||||
}
|
||||
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
#else
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value &&
|
||||
use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
|
||||
#else
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b2[1]), "r"(a[0]), "r"(c[2]), "r"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b2[1]), "r"(a[1]), "r"(c[2]), "r"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
@ -26,6 +26,7 @@
|
||||
#include "marlin.cuh"
|
||||
#include "marlin_dtypes.cuh"
|
||||
#include "dequant.h"
|
||||
#include "marlin_mma.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@ -35,7 +36,7 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
|
||||
@ -75,137 +76,6 @@ __global__ void Marlin(
|
||||
|
||||
#else
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
template <int count, vllm::ScalarTypeId type_id>
|
||||
@ -415,6 +285,17 @@ __global__ void Marlin(
|
||||
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
// Turing TensorCore only supports fp16 and int8
|
||||
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
|
||||
return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
using Adtype = MarlinScalarType<a_type_id>;
|
||||
using Cdtype = MarlinScalarType<c_type_id>;
|
||||
const int4* A = A0;
|
||||
@ -873,10 +754,6 @@ __global__ void Marlin(
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
// shared memory reused by reduction should be smaller than
|
||||
// shared memory used by weight.
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
@ -1395,11 +1272,13 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
} else {
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
|
||||
frag_c[i][j][0]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
|
||||
frag_c[i][j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1433,10 +1312,12 @@ __global__ void Marlin(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
}
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
@ -1956,6 +1837,21 @@ __global__ void Marlin(
|
||||
// While this pattern may not be the most readable, other ways of writing
|
||||
// the loop seemed to yield noticeably worse performance after compilation.
|
||||
if (slice_iters == 0) {
|
||||
// convert fp16 accum to fp32 for reduction
|
||||
if constexpr (use_fp16_accum) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
|
||||
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
|
||||
scalar_t* frag_c_part_half =
|
||||
reinterpret_cast<scalar_t*>(frag_c_part_float);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_a_8bit) {
|
||||
float frag_a_s[2 * thread_m_blocks];
|
||||
|
||||
|
||||
@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
|
||||
int rowEnd = rowEnds[rowIdx];
|
||||
|
||||
// Local pointers to this block
|
||||
outIndices += rowIdx * topK;
|
||||
logits += rowIdx * stride0;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||
|
||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
|
||||
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
|
||||
@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
||||
|
||||
// Local pointers to this block
|
||||
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
|
||||
outIndices += rowIdx * topK;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
} else if constexpr (multipleBlocksPerRow) {
|
||||
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
|
||||
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
|
||||
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
|
||||
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||
outIndices +=
|
||||
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||
outLogits +=
|
||||
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||
} else if constexpr (mergeBlocks) {
|
||||
rowEnd = numBlocksToMerge * topK;
|
||||
indices += rowIdx * numBlocksToMerge * topK;
|
||||
outIndices += rowIdx * topK;
|
||||
indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
}
|
||||
logits += rowIdx * stride0;
|
||||
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||
|
||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
|
||||
multipleBlocksPerRow, mergeBlocks>(
|
||||
|
||||
@ -621,7 +621,7 @@ ENV UV_HTTP_TIMEOUT=500
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \
|
||||
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
|
||||
uv pip install --system -r /tmp/kv_connectors.txt; \
|
||||
uv pip install --system -r /tmp/kv_connectors.txt || true; \
|
||||
fi
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
@ -109,7 +109,7 @@ Every plugin has three parts:
|
||||
- `init_device`: This function is called to set up the device for the worker.
|
||||
- `initialize_cache`: This function is called to set cache config for the worker.
|
||||
- `load_model`: This function is called to load the model weights to device.
|
||||
- `get_kv_cache_spaces`: This function is called to generate the kv cache spaces for the model.
|
||||
- `get_kv_cache_spec`: This function is called to generate the kv cache spec for the model.
|
||||
- `determine_available_memory`: This function is called to profile the peak memory usage of the model to determine how much memory can be used for KV cache without OOMs.
|
||||
- `initialize_from_config`: This function is called to allocate device KV cache with the specified kv_cache_config
|
||||
- `execute_model`: This function is called every step to run inference on the model.
|
||||
|
||||
@ -181,3 +181,4 @@ If you have PRs touching the area, please feel free to ping the area owner for r
|
||||
|
||||
- Ascend NPU: [@wangxiyuan](https://github.com/wangxiyuan) and [see more details](https://vllm-ascend.readthedocs.io/en/latest/community/contributors.html#maintainers)
|
||||
- Intel Gaudi HPU: [@xuechendi](https://github.com/xuechendi) and [@kzawora-intel](https://github.com/kzawora-intel)
|
||||
- Semantic Router: [@xunzhuo](https://github.com/xunzhuo), [@rootfs](https://github.com/rootfs) and [see more details](https://vllm-semantic-router.com/community/team)
|
||||
|
||||
@ -47,6 +47,8 @@ We currently support the following OpenAI APIs:
|
||||
- [Completions API](#completions-api) (`/v1/completions`)
|
||||
- Only applicable to [text generation models](../models/generative_models.md).
|
||||
- *Note: `suffix` parameter is not supported.*
|
||||
- [Responses API](#responses-api) (`/v1/responses`)
|
||||
- Only applicable to [text generation models](../models/generative_models.md).
|
||||
- [Chat Completions API](#chat-api) (`/v1/chat/completions`)
|
||||
- Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template).
|
||||
- *Note: `user` parameter is ignored.*
|
||||
@ -229,6 +231,31 @@ The following extra parameters are supported:
|
||||
--8<-- "vllm/entrypoints/openai/protocol.py:chat-completion-extra-params"
|
||||
```
|
||||
|
||||
### Responses API
|
||||
|
||||
Our Responses API is compatible with [OpenAI's Responses API](https://platform.openai.com/docs/api-reference/responses);
|
||||
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
|
||||
|
||||
Code example: [examples/online_serving/openai_responses_client_with_tools.py](../../examples/online_serving/openai_responses_client_with_tools.py)
|
||||
|
||||
#### Extra parameters
|
||||
|
||||
The following extra parameters in the request object are supported:
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
--8<-- "vllm/entrypoints/openai/protocol.py:responses-extra-params"
|
||||
```
|
||||
|
||||
The following extra parameters in the response object are supported:
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
--8<-- "vllm/entrypoints/openai/protocol.py:responses-response-extra-params"
|
||||
```
|
||||
|
||||
### Embeddings API
|
||||
|
||||
Our Embeddings API is compatible with [OpenAI's Embeddings API](https://platform.openai.com/docs/api-reference/embeddings);
|
||||
|
||||
@ -6,7 +6,7 @@ requires = [
|
||||
"packaging>=24.2",
|
||||
"setuptools>=77.0.3,<81.0.0",
|
||||
"setuptools-scm>=8.0",
|
||||
"torch == 2.9.0",
|
||||
"torch == 2.9.1",
|
||||
"wheel",
|
||||
"jinja2",
|
||||
]
|
||||
|
||||
@ -4,7 +4,7 @@ ninja
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<81.0.0
|
||||
setuptools-scm>=8
|
||||
torch==2.9.0
|
||||
torch==2.9.1
|
||||
wheel
|
||||
jinja2>=3.1.6
|
||||
regex
|
||||
|
||||
@ -37,7 +37,7 @@ pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||
einops # Required for Qwen2-VL.
|
||||
compressed-tensors == 0.12.2 # required for compressed-tensors
|
||||
compressed-tensors == 0.13.0 # required for compressed-tensors
|
||||
depyf==0.20.0 # required for profiling and debugging with compilation config
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
watchfiles # required for http server to monitor the updates of TLS files
|
||||
@ -50,5 +50,5 @@ ijson # Required for mistral streaming tool parser
|
||||
setproctitle # Used to set process names for better debugging and monitoring
|
||||
openai-harmony >= 0.0.3 # Required for gpt-oss
|
||||
anthropic == 0.71.0
|
||||
model-hosting-container-standards >= 0.1.9, < 1.0.0
|
||||
mcp
|
||||
model-hosting-container-standards >= 0.1.10, < 1.0.0
|
||||
mcp
|
||||
|
||||
@ -5,9 +5,9 @@ numba == 0.61.2 # Required for N-gram speculative decoding
|
||||
|
||||
# Dependencies for NVIDIA GPUs
|
||||
ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1.
|
||||
torch==2.9.0
|
||||
torchaudio==2.9.0
|
||||
torch==2.9.1
|
||||
torchaudio==2.9.1
|
||||
# These must be updated alongside torch
|
||||
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
torchvision==0.24.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
flashinfer-python==0.5.3
|
||||
|
||||
@ -2,11 +2,11 @@
|
||||
-r common.txt
|
||||
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.4
|
||||
torch==2.9.0
|
||||
torchvision==0.24.0
|
||||
torchaudio==2.9.0
|
||||
torch==2.9.1
|
||||
torchvision==0.24.1
|
||||
torchaudio==2.9.1
|
||||
|
||||
triton==3.5.0
|
||||
triton==3.5.1
|
||||
cmake>=3.26.1,<4
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<80.0.0
|
||||
|
||||
@ -24,9 +24,9 @@ soundfile # required for audio tests
|
||||
jiwer # required for audio tests
|
||||
tblib # for pickling test exceptions
|
||||
timm >=1.0.17 # required for internvl and gemma3n-mm test
|
||||
torch==2.9.0
|
||||
torchaudio==2.9.0
|
||||
torchvision==0.24.0
|
||||
torch==2.9.1
|
||||
torchaudio==2.9.1
|
||||
torchvision==0.24.1
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[image,audio] >= 1.8.5 # required for voxtral test
|
||||
|
||||
@ -1123,7 +1123,7 @@ tomli==2.2.1
|
||||
# via schemathesis
|
||||
tomli-w==1.2.0
|
||||
# via schemathesis
|
||||
torch==2.9.0+cu129
|
||||
torch==2.9.1+cu129
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# accelerate
|
||||
@ -1152,7 +1152,7 @@ torch==2.9.0+cu129
|
||||
# torchvision
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
torchaudio==2.9.0+cu129
|
||||
torchaudio==2.9.1+cu129
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# encodec
|
||||
@ -1165,7 +1165,7 @@ torchmetrics==1.7.4
|
||||
# pytorch-lightning
|
||||
# terratorch
|
||||
# torchgeo
|
||||
torchvision==0.24.0+cu129
|
||||
torchvision==0.24.1+cu129
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# lightly
|
||||
@ -1206,7 +1206,7 @@ transformers==4.57.3
|
||||
# transformers-stream-generator
|
||||
transformers-stream-generator==0.0.5
|
||||
# via -r requirements/test.in
|
||||
triton==3.5.0
|
||||
triton==3.5.1
|
||||
# via torch
|
||||
tritonclient==2.51.0
|
||||
# via
|
||||
|
||||
@ -67,7 +67,6 @@ def _fix_prompt_embed_outputs(
|
||||
@pytest.mark.parametrize("model_executor", ["uni", "mp"])
|
||||
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
||||
def test_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
hf_runner,
|
||||
model: str,
|
||||
backend: str,
|
||||
@ -77,48 +76,46 @@ def test_models(
|
||||
model_executor: str,
|
||||
enable_prompt_embeds: bool,
|
||||
) -> None:
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
# 5042 tokens for gemma2
|
||||
# gemma2 has alternating sliding window size of 4096
|
||||
# we need a prompt with more than 4096 tokens to test the sliding window
|
||||
prompt = (
|
||||
"The following numbers of the sequence "
|
||||
+ ", ".join(str(i) for i in range(1024))
|
||||
+ " are:"
|
||||
)
|
||||
example_prompts = [prompt]
|
||||
|
||||
# 5042 tokens for gemma2
|
||||
# gemma2 has alternating sliding window size of 4096
|
||||
# we need a prompt with more than 4096 tokens to test the sliding window
|
||||
prompt = (
|
||||
"The following numbers of the sequence "
|
||||
+ ", ".join(str(i) for i in range(1024))
|
||||
+ " are:"
|
||||
)
|
||||
example_prompts = [prompt]
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
if enable_prompt_embeds:
|
||||
with torch.no_grad():
|
||||
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
if enable_prompt_embeds:
|
||||
with torch.no_grad():
|
||||
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=8192,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7,
|
||||
async_scheduling=async_scheduling,
|
||||
distributed_executor_backend=model_executor,
|
||||
attention_config={"backend": backend},
|
||||
) as vllm_model:
|
||||
if enable_prompt_embeds:
|
||||
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
|
||||
vllm_outputs = _fix_prompt_embed_outputs(
|
||||
vllm_outputs, hf_model, example_prompts
|
||||
)
|
||||
else:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=8192,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7,
|
||||
async_scheduling=async_scheduling,
|
||||
distributed_executor_backend=model_executor,
|
||||
) as vllm_model:
|
||||
if enable_prompt_embeds:
|
||||
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
|
||||
vllm_outputs = _fix_prompt_embed_outputs(
|
||||
vllm_outputs, hf_model, example_prompts
|
||||
)
|
||||
else:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@ -161,12 +158,6 @@ def test_models_distributed(
|
||||
): # noqa
|
||||
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
|
||||
|
||||
if attention_backend:
|
||||
monkeypatch_context.setenv(
|
||||
"VLLM_ATTENTION_BACKEND",
|
||||
attention_backend,
|
||||
)
|
||||
|
||||
for k, v in extra_env.items():
|
||||
monkeypatch_context.setenv(k, v)
|
||||
|
||||
@ -178,6 +169,7 @@ def test_models_distributed(
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method
|
||||
# (the default method).
|
||||
attention_config = {"backend": attention_backend} if attention_backend else None
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
@ -185,6 +177,7 @@ def test_models_distributed(
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7,
|
||||
attention_config=attention_config,
|
||||
) as vllm_model:
|
||||
if enable_prompt_embeds:
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
|
||||
@ -19,21 +19,18 @@ def server():
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_serve(server):
|
||||
# Test default model detection and input/output len
|
||||
command = [
|
||||
"vllm",
|
||||
"bench",
|
||||
"serve",
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
"--host",
|
||||
server.host,
|
||||
"--port",
|
||||
str(server.port),
|
||||
"--dataset-name",
|
||||
"random",
|
||||
"--random-input-len",
|
||||
"--input-len",
|
||||
"32",
|
||||
"--random-output-len",
|
||||
"--output-len",
|
||||
"4",
|
||||
"--num-prompts",
|
||||
"5",
|
||||
|
||||
@ -208,7 +208,8 @@ def test_attn_quant(
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
model_kwargs["attention_config"] = {"backend": backend.name}
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
@ -297,7 +298,8 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
model_kwargs["attention_config"] = {"backend": backend.name}
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
@ -409,7 +411,8 @@ def test_tp2_attn_quant_async_tp(
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
model_kwargs["attention_config"] = {"backend": backend.name}
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
@ -523,6 +526,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
|
||||
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
# TODO: remove skip after we fix the fusion thoroughly
|
||||
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
|
||||
def test_rms_group_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
@ -562,7 +567,9 @@ def test_rms_group_quant(
|
||||
splitting_ops=splitting_ops,
|
||||
# Common
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True),
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
|
||||
),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
)
|
||||
|
||||
@ -89,7 +89,6 @@ class TestSetting:
|
||||
],
|
||||
)
|
||||
def test_compile_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_setting: TestSetting,
|
||||
):
|
||||
# this test is run under multiple suits, with different GPUs.
|
||||
@ -107,49 +106,48 @@ def test_compile_correctness(
|
||||
f"{cuda_device_count_stateless()}"
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
final_args = [
|
||||
*model_args,
|
||||
"-pp",
|
||||
str(pp_size),
|
||||
"-tp",
|
||||
str(tp_size),
|
||||
"-cc.cudagraph_mode=none",
|
||||
]
|
||||
final_args = [
|
||||
*model_args,
|
||||
"-pp",
|
||||
str(pp_size),
|
||||
"-tp",
|
||||
str(tp_size),
|
||||
"-cc.cudagraph_mode=none",
|
||||
f"--attention-backend={attn_backend}",
|
||||
]
|
||||
|
||||
all_args: list[list[str]] = []
|
||||
all_envs: list[dict[str, str] | None] = []
|
||||
all_args: list[list[str]] = []
|
||||
all_envs: list[dict[str, str] | None] = []
|
||||
|
||||
for comp_mode in [
|
||||
CompilationMode.STOCK_TORCH_COMPILE,
|
||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
for mode in [CompilationMode.NONE, comp_mode]:
|
||||
all_args.append(
|
||||
final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
|
||||
)
|
||||
|
||||
# inductor will change the output, so we only compare if the output
|
||||
# is close, not exactly the same.
|
||||
compare_all_settings(
|
||||
model,
|
||||
all_args,
|
||||
all_envs,
|
||||
method=method if method != "generate" else "generate_close",
|
||||
for comp_mode in [
|
||||
CompilationMode.STOCK_TORCH_COMPILE,
|
||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
for mode in [CompilationMode.NONE, comp_mode]:
|
||||
all_args.append(
|
||||
final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
|
||||
)
|
||||
all_envs.clear()
|
||||
all_args.clear()
|
||||
|
||||
for mode in [
|
||||
CompilationMode.NONE,
|
||||
CompilationMode.STOCK_TORCH_COMPILE,
|
||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
|
||||
all_envs.append({})
|
||||
all_envs.append({})
|
||||
# inductor will change the output, so we only compare if the output
|
||||
# is close, not exactly the same.
|
||||
compare_all_settings(
|
||||
model,
|
||||
all_args,
|
||||
all_envs,
|
||||
method=method if method != "generate" else "generate_close",
|
||||
)
|
||||
all_envs.clear()
|
||||
all_args.clear()
|
||||
|
||||
compare_all_settings(model, all_args * 3, all_envs, method=method)
|
||||
for mode in [
|
||||
CompilationMode.NONE,
|
||||
CompilationMode.STOCK_TORCH_COMPILE,
|
||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
|
||||
all_envs.append({})
|
||||
all_envs.append({})
|
||||
|
||||
compare_all_settings(model, all_args * 3, all_envs, method=method)
|
||||
|
||||
@ -74,7 +74,6 @@ def llm_pair(request):
|
||||
# Force native sampler to avoid potential nondeterminism in FlashInfer
|
||||
# when per-request generators are not used in V1.
|
||||
"VLLM_USE_FLASHINFER_SAMPLER": "0",
|
||||
**backend_config.env_vars,
|
||||
}
|
||||
with temporary_environ(env_vars):
|
||||
full = LLM(
|
||||
@ -170,16 +169,10 @@ class TestFullCUDAGraph:
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
def test_full_cudagraph_with_invalid_backend():
|
||||
with (
|
||||
temporary_environ(
|
||||
{
|
||||
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
|
||||
# Flex_Attention is not supported with full cuda graph
|
||||
}
|
||||
),
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
# Flex_Attention is not supported with full cuda graph
|
||||
with pytest.raises(RuntimeError):
|
||||
LLM(
|
||||
model="Qwen/Qwen2-1.5B-Instruct",
|
||||
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
|
||||
attention_config={"backend": "FLEX_ATTENTION"},
|
||||
)
|
||||
|
||||
@ -197,20 +197,19 @@ def test_custom_compile_config(
|
||||
],
|
||||
)
|
||||
def test_fp8_kv_scale_compile(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
compilation_mode: int,
|
||||
model: str,
|
||||
backend: AttentionBackendEnum | None,
|
||||
):
|
||||
if backend:
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
model_kwargs = {
|
||||
"quantization": "fp8",
|
||||
"kv_cache_dtype": "fp8_e4m3",
|
||||
"calculate_kv_scales": True,
|
||||
"max_model_len": 512,
|
||||
}
|
||||
if backend:
|
||||
model_kwargs["attention_config"] = {"backend": backend.name}
|
||||
|
||||
run_model(compilation_mode, model, **model_kwargs)
|
||||
|
||||
|
||||
|
||||
@ -219,14 +219,12 @@ def _test_cp_gsm8k(
|
||||
]
|
||||
)
|
||||
|
||||
server_env = {}
|
||||
if attn_backend:
|
||||
server_env["VLLM_ATTENTION_BACKEND"] = attn_backend
|
||||
server_args.append(f"--attention-backend={attn_backend}")
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_id,
|
||||
server_args,
|
||||
env_dict=server_env,
|
||||
max_wait_seconds=720,
|
||||
) as remote_server:
|
||||
host = f"http://{remote_server.host}"
|
||||
|
||||
@ -20,23 +20,21 @@ from ..utils import compare_two_settings, create_new_process_for_each_test
|
||||
)
|
||||
@create_new_process_for_each_test()
|
||||
def test_pp_cudagraph(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
PP_SIZE: int,
|
||||
MODEL_NAME: str,
|
||||
ATTN_BACKEND: LiteralString,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
cudagraph_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"float16",
|
||||
"--pipeline-parallel-size",
|
||||
str(PP_SIZE),
|
||||
"--distributed-executor-backend",
|
||||
"mp",
|
||||
]
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", ATTN_BACKEND)
|
||||
cudagraph_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"float16",
|
||||
"--pipeline-parallel-size",
|
||||
str(PP_SIZE),
|
||||
"--distributed-executor-backend",
|
||||
"mp",
|
||||
f"--attention-backend={ATTN_BACKEND}",
|
||||
]
|
||||
|
||||
eager_args = cudagraph_args + ["--enforce-eager"]
|
||||
eager_args = cudagraph_args + ["--enforce-eager"]
|
||||
|
||||
compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)
|
||||
compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import Annotated, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import CompilationConfig, config
|
||||
from vllm.config import AttentionConfig, CompilationConfig, config
|
||||
from vllm.engine.arg_utils import (
|
||||
EngineArgs,
|
||||
contains_type,
|
||||
@ -298,6 +298,139 @@ def test_compilation_config():
|
||||
)
|
||||
|
||||
|
||||
def test_attention_config():
    """Exercise every CLI spelling of the attention settings.

    Covers the default value, ``--attention-config.<field>`` dot notation,
    the ``--attention-backend`` shorthand, a JSON-encoded
    ``--attention-config=...`` value, propagation of the backend into the
    config returned by ``create_engine_config()``, and the ``ValueError``
    expected when both backend spellings are given together.
    """
    from vllm.attention.backends.registry import AttentionBackendEnum

    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    def build_engine_args(argv):
        # Parse one command line and lift the namespace into EngineArgs.
        namespace = parser.parse_args(argv)
        assert namespace is not None
        return EngineArgs.from_cli_args(namespace)

    # Boolean AttentionConfig fields checked in both "all fields" scenarios.
    bool_fields = (
        "use_prefill_decode_attention",
        "use_cudnn_prefill",
        "use_trtllm_ragged_deepseek_prefill",
        "use_trtllm_attention",
        "disable_flashinfer_prefill",
        "disable_flashinfer_q_quantization",
    )
    model_argv = ["--model", "facebook/opt-125m"]

    # default value
    assert build_engine_args([]).attention_config == AttentionConfig()

    # set backend via dot notation
    dotted = build_engine_args(["--attention-config.backend", "FLASH_ATTN"])
    assert dotted.attention_config.backend is not None
    assert dotted.attention_config.backend.name == "FLASH_ATTN"

    # set backend via --attention-backend shorthand
    shorthand = build_engine_args(["--attention-backend", "FLASHINFER"])
    assert shorthand.attention_backend is not None
    assert shorthand.attention_backend == "FLASHINFER"

    # set all fields via dot notation
    all_dotted_argv = [
        "--attention-config.backend",
        "FLASH_ATTN",
        "--attention-config.flash_attn_version",
        "3",
        "--attention-config.use_prefill_decode_attention",
        "true",
        "--attention-config.flash_attn_max_num_splits_for_cuda_graph",
        "16",
        "--attention-config.use_cudnn_prefill",
        "true",
        "--attention-config.use_trtllm_ragged_deepseek_prefill",
        "true",
        "--attention-config.use_trtllm_attention",
        "true",
        "--attention-config.disable_flashinfer_prefill",
        "true",
        "--attention-config.disable_flashinfer_q_quantization",
        "true",
    ]
    from_dots = build_engine_args(all_dotted_argv).attention_config
    assert from_dots.backend is not None
    assert from_dots.backend.name == "FLASH_ATTN"
    assert from_dots.flash_attn_version == 3
    assert from_dots.flash_attn_max_num_splits_for_cuda_graph == 16
    for field_name in bool_fields:
        assert getattr(from_dots, field_name) is True

    # set to string form of a dict with all fields
    json_flag = (
        "--attention-config="
        '{"backend": "FLASHINFER", "flash_attn_version": 2, '
        '"use_prefill_decode_attention": false, '
        '"flash_attn_max_num_splits_for_cuda_graph": 8, '
        '"use_cudnn_prefill": false, '
        '"use_trtllm_ragged_deepseek_prefill": false, '
        '"use_trtllm_attention": false, '
        '"disable_flashinfer_prefill": false, '
        '"disable_flashinfer_q_quantization": false}'
    )
    from_json = build_engine_args([json_flag]).attention_config
    assert from_json.backend is not None
    assert from_json.backend.name == "FLASHINFER"
    assert from_json.flash_attn_version == 2
    assert from_json.flash_attn_max_num_splits_for_cuda_graph == 8
    for field_name in bool_fields:
        assert getattr(from_json, field_name) is False

    # test --attention-backend flows into VllmConfig.attention_config
    via_shorthand = build_engine_args(
        model_argv + ["--attention-backend", "FLASH_ATTN"]
    ).create_engine_config()
    assert via_shorthand.attention_config.backend == AttentionBackendEnum.FLASH_ATTN

    # test --attention-config.backend flows into VllmConfig.attention_config
    via_dotted = build_engine_args(
        model_argv + ["--attention-config.backend", "FLASHINFER"]
    ).create_engine_config()
    assert via_dotted.attention_config.backend == AttentionBackendEnum.FLASHINFER

    # test --attention-backend and --attention-config.backend are mutually exclusive
    conflicting = build_engine_args(
        model_argv
        + [
            "--attention-backend",
            "FLASH_ATTN",
            "--attention-config.backend",
            "FLASHINFER",
        ]
    )
    # Parsing succeeds; the conflict is only expected when the config is built.
    with pytest.raises(ValueError, match="mutually exclusive"):
        conflicting.create_engine_config()
|
||||
|
||||
|
||||
def test_prefix_cache_default():
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
args = parser.parse_args([])
|
||||
|
||||
0
tests/entrypoints/instrumentator/__init__.py
Normal file
0
tests/entrypoints/instrumentator/__init__.py
Normal file
@ -14,11 +14,10 @@ import requests
|
||||
from prometheus_client.parser import text_string_to_metric_families
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.conftest import LocalAssetServer
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm import version
|
||||
|
||||
from ...conftest import LocalAssetServer
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODELS = {
|
||||
"text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct",
|
||||
@ -254,7 +254,9 @@ async def test_single_chat_session_input_audio(
|
||||
async def test_chat_streaming_audio(
|
||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||
):
|
||||
messages = dummy_messages_from_audio_url(audio_url)
|
||||
messages = dummy_messages_from_audio_url(
|
||||
audio_url, "What's a short title for this audio?"
|
||||
)
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
|
||||
@ -76,15 +76,10 @@ def default_server_args(with_tool_parser: bool):
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gptoss_server(
|
||||
monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
|
||||
):
|
||||
with monkeypatch_module.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
|
||||
with RemoteOpenAIServer(
|
||||
GPT_OSS_MODEL_NAME, default_server_args
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
def gptoss_server(default_server_args: list[str]):
|
||||
server_args = default_server_args + ["--attention-backend=TRITON_ATTN"]
|
||||
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
||||
@ -244,3 +244,35 @@ async def test_audio_with_timestamp(mary_had_lamb, whisper_client):
|
||||
)
|
||||
assert transcription.segments is not None
|
||||
assert len(transcription.segments) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_with_max_tokens(whisper_client, mary_had_lamb):
    """Check that ``max_completion_tokens`` bounds the transcription length.

    A limit of 1 must produce exactly one token when the returned text is
    re-tokenized with the model's tokenizer. A limit of one million must
    still produce fewer than 450 tokens.
    """
    from transformers import AutoTokenizer

    async def transcribe(token_limit):
        # One English, temperature-0 transcription; returns the decoded text.
        raw_response = await whisper_client.audio.transcriptions.create(
            model=MODEL_NAME,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            temperature=0.0,
            extra_body={"max_completion_tokens": token_limit},
        )
        return json.loads(raw_response)["text"]

    def token_count(text):
        # Count tokens without special tokens so only generated text is measured.
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    short_text = await transcribe(1)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    assert token_count(short_text) == 1

    # max_completion_tokens > max_model_len
    long_text = await transcribe(int(1e6))
    assert token_count(long_text) < 450  # ~Whisper max output len
|
||||
|
||||
@ -227,3 +227,36 @@ async def test_long_audio_request(foscolo, client_and_model):
|
||||
)
|
||||
out = json.loads(translation)["text"].strip().lower()
|
||||
assert out.count("greek sea") == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_with_max_tokens(mary_had_lamb, client_and_model):
    """Check that ``max_completion_tokens`` bounds the translation length.

    A limit of 1 must produce exactly one token when the returned text is
    re-tokenized with the model's tokenizer. A limit of one million must
    still produce fewer than 450 tokens.

    Both requests go to the translations endpoint. The second request
    previously called ``client.audio.transcriptions.create``, so the
    translation path was never exercised with an oversized limit.
    """
    from transformers import AutoTokenizer

    client, model_name = client_and_model

    async def translate(token_limit):
        # One temperature-0 translation request; returns the decoded text.
        raw_response = await client.audio.translations.create(
            model=model_name,
            file=mary_had_lamb,
            response_format="text",
            temperature=0.0,
            extra_body={"max_completion_tokens": token_limit},
        )
        text = json.loads(raw_response)["text"]
        print(text)
        return text

    short_text = await translate(1)
    tok = AutoTokenizer.from_pretrained(model_name)
    short_tokens = tok(short_text, add_special_tokens=False)["input_ids"]
    assert len(short_tokens) == 1

    # max_completion_tokens > max_model_len
    long_text = await translate(int(1e6))
    long_tokens = tok(long_text, add_special_tokens=False)["input_ids"]
    assert len(long_tokens) < 450  # ~Whisper max output len
|
||||
|
||||
0
tests/entrypoints/rpc/__init__.py
Normal file
0
tests/entrypoints/rpc/__init__.py
Normal file
@ -37,7 +37,7 @@ def server():
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--worker-extension-cls",
|
||||
"tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension",
|
||||
"tests.entrypoints.rpc.test_collective_rpc.TestWorkerExtension",
|
||||
]
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME,
|
||||
0
tests/entrypoints/sleep/__init__.py
Normal file
0
tests/entrypoints/sleep/__init__.py
Normal file
@ -4,7 +4,7 @@
|
||||
import requests
|
||||
from prometheus_client.parser import text_string_to_metric_families
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B"
|
||||
|
||||
@ -7,9 +7,8 @@ This directory contains a replacement for the lm-eval-harness GSM8K evaluation,
|
||||
### Run tests with pytest (like buildkite)
|
||||
|
||||
```bash
|
||||
pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt \
|
||||
--tp-size=1
|
||||
pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt
|
||||
```
|
||||
|
||||
### Run standalone evaluation script
|
||||
@ -31,5 +30,11 @@ model_name: "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
accuracy_threshold: 0.54 # Minimum expected accuracy
|
||||
num_questions: 1319 # Number of questions (default: full test set)
|
||||
num_fewshot: 5 # Few-shot examples from train set
|
||||
max_model_len: 4096 # Model context length
|
||||
server_args: "--max-model-len 4096 --tensor-parallel-size 2" # Server arguments
|
||||
env: # Environment variables (optional)
|
||||
VLLM_USE_FLASHINFER_MOE_FP4: "1"
|
||||
```
|
||||
|
||||
The `server_args` field accepts any arguments that can be passed to `vllm serve`.
|
||||
|
||||
The `env` field accepts a dictionary of environment variables to set for the server process.
|
||||
|
||||
@ -2,5 +2,4 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
|
||||
accuracy_threshold: 0.72
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,4 +2,4 @@ model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
|
||||
accuracy_threshold: 0.74
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,4 +2,4 @@ model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
|
||||
accuracy_threshold: 0.31
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,4 +2,4 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
|
||||
accuracy_threshold: 0.45
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,4 +2,4 @@ model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
|
||||
accuracy_threshold: 0.60
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-0.6B-FP8"
|
||||
accuracy_threshold: 0.375
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,5 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-FP4"
|
||||
accuracy_threshold: 0.89
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
12
tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
Normal file
12
tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
Normal file
@ -0,0 +1,12 @@
|
||||
model_name: "nm-testing/Qwen3-Next-80B-A3B-Instruct-NVFP4"
|
||||
accuracy_threshold: 0.75
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
server_args: >-
|
||||
--enforce-eager
|
||||
--max-model-len 4096
|
||||
--tensor-parallel-size 2
|
||||
--enable-expert-parallel
|
||||
--speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}'
|
||||
env:
|
||||
VLLM_USE_FLASHINFER_MOE_FP4: "1"
|
||||
@ -3,3 +3,4 @@ Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
|
||||
Qwen1.5-MoE-W4A16-CT.yaml
|
||||
DeepSeek-V2-Lite-Instruct-FP8.yaml
|
||||
Qwen3-30B-A3B-NVFP4.yaml
|
||||
Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
|
||||
|
||||
@ -11,14 +11,12 @@ def pytest_addoption(parser):
|
||||
default="configs/models-small.txt",
|
||||
help="File containing list of config files to test",
|
||||
)
|
||||
parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size")
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
"""Generate test parameters from config files."""
|
||||
if "config_filename" in metafunc.fixturenames:
|
||||
config_list_file = metafunc.config.getoption("--config-list-file")
|
||||
tp_size = metafunc.config.getoption("--tp-size")
|
||||
|
||||
# Handle both relative and absolute paths
|
||||
config_list_path = Path(config_list_file)
|
||||
@ -55,9 +53,9 @@ def pytest_generate_tests(metafunc):
|
||||
# Generate test parameters
|
||||
if config_files:
|
||||
metafunc.parametrize(
|
||||
["config_filename", "tp_size"],
|
||||
[(config_file, int(tp_size)) for config_file in config_files],
|
||||
ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files],
|
||||
"config_filename",
|
||||
config_files,
|
||||
ids=[config_file.stem for config_file in config_files],
|
||||
)
|
||||
else:
|
||||
print("No config files found, test will be skipped")
|
||||
|
||||
@ -5,30 +5,31 @@ GSM8K evaluation using vLLM server and isolated GSM8K script.
|
||||
Replacement for lm-eval-harness with better performance and control.
|
||||
|
||||
Usage:
|
||||
pytest -s -v test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt \
|
||||
--tp-size=1
|
||||
pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt
|
||||
"""
|
||||
|
||||
import shlex
|
||||
|
||||
import yaml
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
from .gsm8k_eval import evaluate_gsm8k
|
||||
|
||||
RTOL = 0.08 # Relative tolerance for accuracy comparison
|
||||
TOL = 0.08 # Absolute tolerance for accuracy comparison
|
||||
|
||||
|
||||
def launch_gsm8k_eval(eval_config, server_url, tp_size):
|
||||
"""Launch GSM8K evaluation using our isolated script."""
|
||||
def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
|
||||
"""Run GSM8K evaluation using our isolated script."""
|
||||
# Extract host and port from server URL
|
||||
if "://" in server_url:
|
||||
server_url = server_url.split("://")[1]
|
||||
|
||||
host_port = server_url.split("/")[0] # Remove path if present
|
||||
if ":" in host_port:
|
||||
host, port = host_port.split(":")
|
||||
port = int(port)
|
||||
host, p = host_port.split(":")
|
||||
port = int(p)
|
||||
else:
|
||||
host = host_port
|
||||
port = 8000
|
||||
@ -48,46 +49,57 @@ def launch_gsm8k_eval(eval_config, server_url, tp_size):
|
||||
return results
|
||||
|
||||
|
||||
def test_gsm8k_correctness_param(config_filename, tp_size):
|
||||
def test_gsm8k_correctness(config_filename):
|
||||
"""Test GSM8K correctness for a given model configuration."""
|
||||
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
|
||||
|
||||
# Server arguments
|
||||
server_args = [
|
||||
"--max-model-len",
|
||||
str(eval_config.get("max_model_len", 4096)),
|
||||
"--enforce-eager",
|
||||
"--trust-remote-code",
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
]
|
||||
# Parse server arguments from config (use shlex to handle quoted strings)
|
||||
server_args_str = eval_config.get("server_args", "")
|
||||
server_args = shlex.split(server_args_str) if server_args_str else []
|
||||
|
||||
# Add standard server arguments
|
||||
server_args.extend(
|
||||
[
|
||||
"--trust-remote-code",
|
||||
]
|
||||
)
|
||||
|
||||
env_dict = eval_config.get("env", None)
|
||||
|
||||
print(f"Starting GSM8K evaluation for model: {eval_config['model_name']}")
|
||||
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
|
||||
print(f"Number of questions: {eval_config['num_questions']}")
|
||||
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
|
||||
print(f"Server args: {' '.join(server_args)}")
|
||||
|
||||
# Launch server and run evaluation
|
||||
with RemoteOpenAIServer(
|
||||
eval_config["model_name"], server_args, env_dict=env_dict, max_wait_seconds=480
|
||||
eval_config["model_name"],
|
||||
server_args,
|
||||
env_dict=env_dict,
|
||||
max_wait_seconds=600,
|
||||
) as remote_server:
|
||||
server_url = remote_server.url_for("v1")
|
||||
print(f"Server started at: {server_url}")
|
||||
|
||||
results = launch_gsm8k_eval(eval_config, server_url, tp_size)
|
||||
results = run_gsm8k_eval(eval_config, server_url)
|
||||
|
||||
# Check accuracy against threshold
|
||||
measured_accuracy = results["accuracy"]
|
||||
expected_accuracy = eval_config["accuracy_threshold"]
|
||||
measured_metric = results["accuracy"]
|
||||
expected_metric = eval_config["accuracy_threshold"]
|
||||
|
||||
print(f"GSM8K Results for {eval_config['model_name']}:")
|
||||
print(f" Accuracy: {measured_accuracy:.3f}")
|
||||
print(f" Expected: {expected_accuracy:.3f}")
|
||||
print(f" Measured metric: {measured_metric:.4f}")
|
||||
print(f" Expected metric: {expected_metric:.4f}")
|
||||
print(f" Tolerance: {TOL:.4f}")
|
||||
print(f" Questions: {results['num_questions']}")
|
||||
print(f" Invalid rate: {results['invalid_rate']:.3f}")
|
||||
print(f" Latency: {results['latency']:.1f}s")
|
||||
print(f" QPS: {results['questions_per_second']:.1f}")
|
||||
|
||||
# Verify accuracy is within tolerance
|
||||
assert measured_accuracy >= expected_accuracy - RTOL, (
|
||||
f"Accuracy too low: {measured_accuracy:.3f} < "
|
||||
f"{expected_accuracy:.3f} - {RTOL:.3f}"
|
||||
# Verify metric is within tolerance
|
||||
assert measured_metric >= expected_metric - TOL, (
|
||||
f"GSM8K metric too low: {measured_metric:.4f} < "
|
||||
f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}"
|
||||
)
|
||||
|
||||
print(f"✅ GSM8K test passed for {eval_config['model_name']}")
|
||||
|
||||
@ -6,7 +6,9 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
|
||||
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
@ -73,18 +75,18 @@ def generate_params():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
|
||||
def test_env(
|
||||
def test_backend_selection(
|
||||
device: str,
|
||||
name: str,
|
||||
use_mla: bool,
|
||||
block_size: int,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
"""Test attention backend selection with valid device-backend pairs."""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", name)
|
||||
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
||||
# Create AttentionConfig with the specified backend
|
||||
attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
|
||||
vllm_config = VllmConfig(attention_config=attention_config)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
if device == "cpu":
|
||||
with patch("vllm.platforms.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, None, block_size)
|
||||
@ -217,27 +219,32 @@ def test_env(
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
|
||||
def test_fp32_fallback(device: str):
|
||||
"""Test attention backend selection with fp32."""
|
||||
if device == "cpu":
|
||||
with patch("vllm.platforms.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "CPU_ATTN"
|
||||
# Use default config (no backend specified)
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.platforms.current_platform", CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "FLEX_ATTENTION"
|
||||
with set_current_vllm_config(vllm_config):
|
||||
if device == "cpu":
|
||||
with patch("vllm.platforms.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "CPU_ATTN"
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.platforms.current_platform", CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "FLEX_ATTENTION"
|
||||
|
||||
|
||||
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test FlashAttn validation."""
|
||||
pytest.skip(
|
||||
"Skipping as current backend selector does not "
|
||||
"handle fallbacks when a backend is set via env var."
|
||||
"handle fallbacks when a backend is explicitly set."
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
|
||||
attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
|
||||
vllm_config = VllmConfig(attention_config=attention_config)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
# Unsupported CUDA arch
|
||||
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
|
||||
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||
@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
|
||||
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
|
||||
def test_invalid_backend():
|
||||
"""Test that invalid attention backend names raise ValueError."""
|
||||
with (
|
||||
monkeypatch.context() as m,
|
||||
patch("vllm.platforms.current_platform", CudaPlatform()),
|
||||
pytest.raises(ValueError),
|
||||
):
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "INVALID")
|
||||
|
||||
# Should raise ValueError for invalid backend
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_attn_backend(32, torch.float16, None, 16)
|
||||
assert "Invalid value 'INVALID'" in str(exc_info.value)
|
||||
# Invalid backend name should raise ValueError when creating enum
|
||||
AttentionConfig(backend=AttentionBackendEnum["INVALID"])
|
||||
|
||||
@ -455,3 +455,38 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}",
|
||||
)
|
||||
|
||||
|
||||
def test_trtllm_attention_rejects_num_kv_heads_1() -> None:
|
||||
"""Test that TRTLLM attention correctly rejects num_kv_heads=1.
|
||||
|
||||
When num_kv_heads=1 (MQA), the KV cache strides become degenerate
|
||||
(stride_heads == stride_batch), which causes CUDA's cuTensorMapEncodeTiled
|
||||
to fail because TMA descriptors cannot handle degenerate 4D tensors with
|
||||
singleton dimensions.
|
||||
|
||||
This test verifies that can_use_trtllm_attention returns False for
|
||||
num_kv_heads=1 configurations.
|
||||
"""
|
||||
from vllm.utils.flashinfer import can_use_trtllm_attention
|
||||
|
||||
# num_kv_heads=1 should be rejected
|
||||
assert not can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1), (
|
||||
"can_use_trtllm_attention should return False for num_kv_heads=1"
|
||||
)
|
||||
assert not can_use_trtllm_attention(num_qo_heads=32, num_kv_heads=1), (
|
||||
"can_use_trtllm_attention should return False for num_kv_heads=1"
|
||||
)
|
||||
|
||||
# num_kv_heads > 1 should be accepted (if platform supports it)
|
||||
# Note: This may return False on non-Blackwell platforms, which is fine
|
||||
result_kv8 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=8)
|
||||
result_kv1 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1)
|
||||
|
||||
# Even if platform doesn't support TRTLLM, num_kv_heads=1 should never
|
||||
# return True when num_kv_heads > 1 returns True
|
||||
if result_kv8:
|
||||
assert not result_kv1, (
|
||||
"If TRTLLM is supported for num_kv_heads=8, "
|
||||
"it must be rejected for num_kv_heads=1"
|
||||
)
|
||||
|
||||
@ -4,7 +4,9 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
|
||||
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
|
||||
@ -16,40 +18,56 @@ def clear_cache():
|
||||
|
||||
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
|
||||
def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
|
||||
# Set the current platform to ROCm using monkeypatch
|
||||
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
|
||||
|
||||
# Set the current platform to ROCm using monkeypatch
|
||||
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
|
||||
# Test standard ROCm attention
|
||||
attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_ATTN)
|
||||
vllm_config = VllmConfig(attention_config=attention_config)
|
||||
|
||||
# Test standard ROCm attention
|
||||
with set_current_vllm_config(vllm_config):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
|
||||
assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
|
||||
|
||||
# MLA test for deepseek related
|
||||
# MLA test for deepseek related
|
||||
# Change the attention backend to triton MLA
|
||||
attention_config = AttentionConfig(backend=AttentionBackendEnum.TRITON_MLA)
|
||||
vllm_config = VllmConfig(attention_config=attention_config)
|
||||
|
||||
# change the attention backend to triton MLA
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
|
||||
assert backend.get_name() == "TRITON_MLA"
|
||||
|
||||
# If attention backend is None
|
||||
# If use_mla is true
|
||||
# The selected backend is triton MLA
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "")
|
||||
# If attention backend is None
|
||||
# If use_mla is true
|
||||
# The selected backend is triton MLA
|
||||
attention_config = AttentionConfig(backend=None)
|
||||
vllm_config = VllmConfig(attention_config=attention_config)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
|
||||
assert backend.get_name() == "TRITON_MLA"
|
||||
|
||||
# change the attention backend to AITER MLA
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA")
|
||||
# Change the attention backend to AITER MLA
|
||||
attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_AITER_MLA)
|
||||
vllm_config = VllmConfig(attention_config=attention_config)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
|
||||
assert backend.get_name() == "ROCM_AITER_MLA"
|
||||
|
||||
# If attention backend is None
|
||||
# If use_mla is true
|
||||
# If VLLM_ROCM_USE_AITER is enabled
|
||||
# The selected backend is ROCM_AITER_MLA
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "")
|
||||
# If attention backend is None
|
||||
# If use_mla is true
|
||||
# If VLLM_ROCM_USE_AITER is enabled
|
||||
# The selected backend is ROCM_AITER_MLA
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
|
||||
assert backend.get_name() == "ROCM_AITER_MLA"
|
||||
|
||||
attention_config = AttentionConfig(backend=None)
|
||||
vllm_config = VllmConfig(attention_config=attention_config)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
backend = get_attn_backend(
|
||||
576, torch.bfloat16, "auto", 1, False, use_mla=True
|
||||
)
|
||||
assert backend.get_name() == "ROCM_AITER_MLA"
|
||||
|
||||
@ -9,8 +9,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
GroupedTopk,
|
||||
fused_grouped_topk,
|
||||
grouped_topk,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -50,15 +50,17 @@ def test_grouped_topk(
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
|
||||
baseline_topk_weights, baseline_topk_ids = grouped_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=gating_output,
|
||||
grouped_topk = GroupedTopk(
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
baseline_topk_weights, baseline_topk_ids = grouped_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=gating_output,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ def set_seed(seed):
|
||||
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
|
||||
reason="CUDA not available or PyTorch version < 2.7",
|
||||
)
|
||||
def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
def test_flex_attention_vs_default_backend(vllm_runner):
|
||||
"""Test that FlexAttention produces the same outputs as the default backend.
|
||||
|
||||
This test compares the outputs from the FlexAttention backend with
|
||||
@ -54,35 +54,32 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
]
|
||||
|
||||
# Run with flex attention
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
|
||||
set_seed(seed)
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="generate",
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True,
|
||||
) as llm_flex:
|
||||
output_flex = llm_flex.generate_greedy_logprobs(
|
||||
prompts, max_tokens, num_logprobs
|
||||
)
|
||||
set_seed(seed)
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="generate",
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True,
|
||||
attention_config={"backend": "FLEX_ATTENTION"},
|
||||
) as llm_flex:
|
||||
output_flex = llm_flex.generate_greedy_logprobs(
|
||||
prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
# Run with default backend
|
||||
with monkeypatch.context() as m:
|
||||
set_seed(seed)
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="generate",
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.85,
|
||||
) as llm_default:
|
||||
output_default = llm_default.generate_greedy_logprobs(
|
||||
prompts, max_tokens, num_logprobs
|
||||
)
|
||||
set_seed(seed)
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="generate",
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.85,
|
||||
) as llm_default:
|
||||
output_default = llm_default.generate_greedy_logprobs(
|
||||
prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=output_flex,
|
||||
@ -96,7 +93,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
|
||||
reason="CUDA not available or PyTorch version < 2.7",
|
||||
)
|
||||
def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
def test_encoder_flex_attention_vs_default_backend(vllm_runner):
|
||||
"""Test that FlexAttention produces the same outputs as the default backend.
|
||||
|
||||
This test compares the outputs from the FlexAttention backend with
|
||||
@ -110,30 +107,26 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
]
|
||||
|
||||
# Run with flex attention
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="pooling",
|
||||
dtype=torch.bfloat16,
|
||||
tensor_parallel_size=1,
|
||||
max_model_len=100,
|
||||
enforce_eager=True,
|
||||
) as llm_flex:
|
||||
flex_outputs = llm_flex.embed(prompts)
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="pooling",
|
||||
dtype=torch.bfloat16,
|
||||
tensor_parallel_size=1,
|
||||
max_model_len=100,
|
||||
enforce_eager=True,
|
||||
attention_config={"backend": "FLEX_ATTENTION"},
|
||||
) as llm_flex:
|
||||
flex_outputs = llm_flex.embed(prompts)
|
||||
|
||||
# Run with default backend
|
||||
with (
|
||||
monkeypatch.context() as m,
|
||||
vllm_runner(
|
||||
model_name,
|
||||
runner="pooling",
|
||||
dtype=torch.bfloat16,
|
||||
tensor_parallel_size=1,
|
||||
max_model_len=100,
|
||||
enforce_eager=True,
|
||||
) as llm_default,
|
||||
):
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="pooling",
|
||||
dtype=torch.bfloat16,
|
||||
tensor_parallel_size=1,
|
||||
max_model_len=100,
|
||||
enforce_eager=True,
|
||||
) as llm_default:
|
||||
default_outputs = llm_default.embed(prompts)
|
||||
|
||||
check_embeddings_close(
|
||||
|
||||
@ -35,10 +35,12 @@ audio_lora_path = MODEL_NAME
|
||||
models = [MODEL_NAME]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_attention_backend_for_rocm(monkeypatch):
|
||||
@pytest.fixture
|
||||
def granite_speech_attention_config():
|
||||
"""Return attention config for Granite Speech tests on ROCm."""
|
||||
if current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
|
||||
return {"backend": "TRITON_ATTN"}
|
||||
return None
|
||||
|
||||
|
||||
def run_test(
|
||||
@ -53,6 +55,7 @@ def run_test(
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: str | None = None,
|
||||
attention_config: dict | None = None,
|
||||
):
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
@ -80,6 +83,7 @@ def run_test(
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
enforce_eager=True,
|
||||
attention_config=attention_config,
|
||||
) as vllm_model:
|
||||
lora_request = LoRARequest("audio", 1, audio_lora_path)
|
||||
vllm_outputs_per_case = [
|
||||
@ -131,6 +135,7 @@ def test_models(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
audio_assets: AudioTestAssets,
|
||||
granite_speech_attention_config,
|
||||
dtype: str,
|
||||
max_model_len: int,
|
||||
max_tokens: int,
|
||||
@ -157,4 +162,5 @@ def test_models(
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
attention_config=granite_speech_attention_config,
|
||||
)
|
||||
|
||||
@ -2,23 +2,17 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Pytest configuration for vLLM pooling tests."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
"""Set FLEX_ATTENTION backend for SigLIP tests on ROCm."""
|
||||
if not current_platform.is_rocm():
|
||||
return
|
||||
@pytest.fixture
|
||||
def siglip_attention_config():
|
||||
"""Return attention config for SigLIP tests on ROCm.
|
||||
|
||||
siglip_tests = [item for item in items if "test_siglip" in item.nodeid]
|
||||
|
||||
if siglip_tests:
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
|
||||
warnings.warn(
|
||||
"ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests",
|
||||
UserWarning,
|
||||
stacklevel=1,
|
||||
)
|
||||
On ROCm, SigLIP tests require FLEX_ATTENTION backend.
|
||||
"""
|
||||
if current_platform.is_rocm():
|
||||
return {"backend": "FLEX_ATTENTION"}
|
||||
return None
|
||||
|
||||
@ -38,6 +38,7 @@ def _run_test(
|
||||
*,
|
||||
dtype: str,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
attention_config: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = {}
|
||||
@ -49,6 +50,7 @@ def _run_test(
|
||||
enforce_eager=True,
|
||||
max_model_len=64,
|
||||
gpu_memory_utilization=0.7,
|
||||
attention_config=attention_config,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.embed(
|
||||
input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs
|
||||
@ -90,6 +92,7 @@ def test_models_text(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
siglip_attention_config,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
@ -108,6 +111,7 @@ def test_models_text(
|
||||
"padding": "max_length",
|
||||
"max_length": 64,
|
||||
}, # siglip2 was trained with this padding setting.
|
||||
attention_config=siglip_attention_config,
|
||||
)
|
||||
|
||||
|
||||
@ -117,6 +121,7 @@ def test_models_image(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
siglip_attention_config,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
@ -133,6 +138,7 @@ def test_models_image(
|
||||
input_images,
|
||||
model,
|
||||
dtype=dtype,
|
||||
attention_config=siglip_attention_config,
|
||||
)
|
||||
|
||||
|
||||
@ -141,6 +147,7 @@ def test_models_image(
|
||||
def test_models_text_image_no_crash(
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
siglip_attention_config,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
@ -154,6 +161,7 @@ def test_models_text_image_no_crash(
|
||||
enforce_eager=True,
|
||||
max_model_len=64,
|
||||
gpu_memory_utilization=0.7,
|
||||
attention_config=siglip_attention_config,
|
||||
) as vllm_model:
|
||||
with pytest.raises(ValueError, match="not both"):
|
||||
vllm_model.embed(texts, images=images)
|
||||
|
||||
@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
total_num_patches.item() + num_tiles.item() + 3
|
||||
) # image start, image, image end
|
||||
|
||||
profiled_tokens = profiler.get_mm_max_contiguous_tokens(
|
||||
profiled_tokens = profiler.get_mm_max_tokens(
|
||||
max_model_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
|
||||
assert total_tokens == profiled_tokens["image"]
|
||||
assert total_num_patches == profiled_tokens["image"]
|
||||
assert total_tokens == sum(
|
||||
placeholder.length
|
||||
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
|
||||
|
||||
@ -75,7 +75,6 @@ def test_models(
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("TOKENIZERS_PARALLELISM", "true")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
NUM_LOG_PROBS = 8
|
||||
@ -86,6 +85,7 @@ def test_models(
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype="auto",
|
||||
attention_config={"backend": backend},
|
||||
) as vllm_model:
|
||||
baseline_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS
|
||||
@ -97,6 +97,7 @@ def test_models(
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
attention_config={"backend": backend},
|
||||
) as vllm_model:
|
||||
test_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS
|
||||
|
||||
@ -108,11 +108,12 @@ def can_initialize(
|
||||
patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1),
|
||||
monkeypatch.context() as m,
|
||||
):
|
||||
if model_arch == "GptOssForCausalLM":
|
||||
# FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
|
||||
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
|
||||
# L4 supports FA3.
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
|
||||
# FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
|
||||
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
|
||||
# L4 supports FA3.
|
||||
attention_config = (
|
||||
{"backend": "TRITON_ATTN"} if model_arch == "GptOssForCausalLM" else None
|
||||
)
|
||||
if model_arch == "WhisperForConditionalGeneration":
|
||||
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
|
||||
@ -143,6 +144,7 @@ def can_initialize(
|
||||
else "vllm",
|
||||
hf_overrides=hf_overrides_fn,
|
||||
max_num_seqs=model_info.max_num_seqs,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
@ -410,6 +411,97 @@ def test_argsort_mm_positions(case):
|
||||
assert modality_idxs == expected_modality_idxs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_embed,expected",
|
||||
[
|
||||
(None, 5),
|
||||
(torch.tensor([True, True, True, True, True]), 5),
|
||||
(torch.tensor([False, False, False, False, False]), 0),
|
||||
(torch.tensor([True, False, True, False, True]), 3),
|
||||
(torch.tensor([True]), 1),
|
||||
],
|
||||
)
|
||||
def test_placeholder_range_get_num_embeds(is_embed, expected):
|
||||
length = len(is_embed) if is_embed is not None else 5
|
||||
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||
assert pr.get_num_embeds == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_embed,expected",
|
||||
[
|
||||
(None, None),
|
||||
(
|
||||
torch.tensor([False, True, False, True, True]),
|
||||
torch.tensor([0, 1, 1, 2, 3]),
|
||||
),
|
||||
(torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
|
||||
],
|
||||
)
|
||||
def test_placeholder_range_embeds_cumsum(is_embed, expected):
|
||||
length = len(is_embed) if is_embed is not None else 5
|
||||
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||
|
||||
if expected is None:
|
||||
assert pr.embeds_cumsum is None
|
||||
return
|
||||
|
||||
assert torch.equal(pr.embeds_cumsum, expected)
|
||||
# cached_property should return the same object on repeated access
|
||||
assert pr.embeds_cumsum is pr.embeds_cumsum
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_embed,start_idx,end_idx,expected",
|
||||
[
|
||||
(None, 2, 4, (2, 4)),
|
||||
(
|
||||
torch.tensor([False, True, False, True, True]),
|
||||
3,
|
||||
5,
|
||||
(1, 3),
|
||||
),
|
||||
(
|
||||
torch.tensor([False, True, False, True, True]),
|
||||
0,
|
||||
2,
|
||||
(0, 1),
|
||||
),
|
||||
(
|
||||
torch.tensor([True, False, True, False]),
|
||||
2,
|
||||
2,
|
||||
(1, 1),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_placeholder_range_get_embeds_indices_in_range(
|
||||
is_embed, start_idx, end_idx, expected
|
||||
):
|
||||
length = len(is_embed) if is_embed is not None else 5
|
||||
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||
assert pr.get_embeds_indices_in_range(start_idx, end_idx) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"offset,is_embed,expected",
|
||||
[
|
||||
(0, None, [(0, 4)]),
|
||||
(
|
||||
2,
|
||||
torch.tensor([False, True, False, True, True]),
|
||||
[(3, 3), (5, 6)],
|
||||
),
|
||||
(0, torch.tensor([True, True, True, True]), [(0, 3)]),
|
||||
(0, torch.tensor([False, False, False, False]), []),
|
||||
],
|
||||
)
|
||||
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
|
||||
length = len(is_embed) if is_embed is not None else 5
|
||||
pr = PlaceholderRange(offset=offset, length=length, is_embed=is_embed)
|
||||
assert pr.extract_embeds_range() == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||
|
||||
119
tests/tool_use/test_minimax_m2_tool_parser.py
Normal file
119
tests/tool_use/test_minimax_m2_tool_parser.py
Normal file
@ -0,0 +1,119 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.tool_parsers.minimax_m2_tool_parser import (
|
||||
MinimaxM2ToolParser,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
class FakeTokenizer:
|
||||
"""Minimal fake tokenizer that exposes the attributes used by the
|
||||
parser: a truthy model_tokenizer marker and a vocab mapping for the
|
||||
special tokens.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.model_tokenizer = True
|
||||
# The parser will look up start/end tokens by their literal strings
|
||||
self.vocab = {
|
||||
"<minimax:tool_call>": 1,
|
||||
"</minimax:tool_call>": 2,
|
||||
}
|
||||
|
||||
def get_vocab(self):
|
||||
return self.vocab
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def minimax_m2_tool_parser():
|
||||
return MinimaxM2ToolParser(FakeTokenizer())
|
||||
|
||||
|
||||
def test_extract_tool_calls_streaming_incremental(minimax_m2_tool_parser):
|
||||
parser = minimax_m2_tool_parser
|
||||
parser._reset_streaming_state()
|
||||
chunks = [
|
||||
"<minimax:tool_call>",
|
||||
'<invoke name="get_weather">',
|
||||
'<parameter name="city">',
|
||||
"Seattle</parameter>",
|
||||
"</invoke></minimax:tool_call>",
|
||||
]
|
||||
previous = ""
|
||||
for chunk in chunks:
|
||||
current = previous + chunk
|
||||
delta = chunk
|
||||
parser.extract_tool_calls_streaming(
|
||||
previous_text=previous,
|
||||
current_text=current,
|
||||
delta_text=delta,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
previous = current
|
||||
|
||||
assert len(parser.prev_tool_call_arr) == 1
|
||||
entry = parser.prev_tool_call_arr[0]
|
||||
|
||||
assert entry["name"] == "get_weather"
|
||||
args = entry["arguments"]
|
||||
assert args["city"] == "Seattle"
|
||||
|
||||
|
||||
def test_streaming_minimax_m2_multiple_invokes(minimax_m2_tool_parser):
|
||||
parser = minimax_m2_tool_parser
|
||||
parser._reset_streaming_state()
|
||||
|
||||
chunks = [
|
||||
"<minimax:tool_call>",
|
||||
'<invoke name="search_web">',
|
||||
'<parameter name="query_tag">',
|
||||
'["technology", "events"]</parameter>',
|
||||
'<parameter name="query_list">',
|
||||
'["OpenAI", "latest", "release"]</parameter>',
|
||||
"</invoke>",
|
||||
'<invoke name="search_web">',
|
||||
'<parameter name="query_tag">',
|
||||
'["technology", "events"]</parameter>',
|
||||
'<parameter name="query_list">',
|
||||
'["Gemini", "latest", "release"]</parameter>',
|
||||
"</invoke>",
|
||||
"</minimax:tool_call>",
|
||||
]
|
||||
previous = ""
|
||||
for chunk in chunks:
|
||||
current = previous + chunk
|
||||
delta = chunk
|
||||
parser.extract_tool_calls_streaming(
|
||||
previous_text=previous,
|
||||
current_text=current,
|
||||
delta_text=delta,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
previous = current
|
||||
|
||||
assert len(parser.prev_tool_call_arr) == 2
|
||||
|
||||
for entry, expect_model in zip(parser.prev_tool_call_arr, ["OpenAI", "Gemini"]):
|
||||
assert entry["name"] == "search_web"
|
||||
args = json.dumps(entry["arguments"])
|
||||
assert "technology" in args and "events" in args
|
||||
assert expect_model in args
|
||||
|
||||
# check streamed_args_for_tool for serving_chat.py
|
||||
for index in range(2):
|
||||
expected_call = parser.prev_tool_call_arr[index].get("arguments", {})
|
||||
expected_call = json.dumps(expected_call)
|
||||
actual_call = parser.streamed_args_for_tool[index]
|
||||
assert expected_call == actual_call
|
||||
@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
|
||||
)
|
||||
|
||||
# Call the function
|
||||
result = make_local_attention_virtual_batches(
|
||||
result, _ = make_local_attention_virtual_batches(
|
||||
attn_chunk_size, common_attn_metadata, block_size
|
||||
)
|
||||
|
||||
|
||||
@ -94,26 +94,20 @@ def mock_on_gfx9():
|
||||
None,
|
||||
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
|
||||
),
|
||||
# Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1
|
||||
(
|
||||
{"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
|
||||
None,
|
||||
AttentionBackendEnum.ROCM_ATTN.get_path(),
|
||||
),
|
||||
# Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
|
||||
# Test Case 9: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
|
||||
(
|
||||
{"VLLM_ROCM_USE_AITER": "1"},
|
||||
"TRITON_ATTN",
|
||||
AttentionBackendEnum.TRITON_ATTN.get_path(),
|
||||
),
|
||||
# Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
|
||||
# Test Case 10: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
|
||||
# (explicitly disabled)
|
||||
(
|
||||
{"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
|
||||
None,
|
||||
AttentionBackendEnum.TRITON_ATTN.get_path(),
|
||||
),
|
||||
# Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
|
||||
# Test Case 11: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
|
||||
(
|
||||
{"VLLM_ROCM_USE_AITER": "1"},
|
||||
"ROCM_ATTN",
|
||||
|
||||
@ -249,8 +249,8 @@ def create_dummy_kv_cache(
|
||||
@dataclass
|
||||
class BackendConfig:
|
||||
name: str
|
||||
env_vars: dict
|
||||
comp_config: dict # compilation config
|
||||
attention_config: dict
|
||||
comp_config: dict
|
||||
specific_gpu_arch: tuple | None = None
|
||||
|
||||
|
||||
@ -259,10 +259,10 @@ full_cg_backend_configs = {
|
||||
# FA3 on Hopper
|
||||
"FA3": BackendConfig(
|
||||
name="FA3",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
|
||||
"VLLM_FLASH_ATTN_VERSION": "3",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
attention_config={
|
||||
"backend": "FLASH_ATTN",
|
||||
"flash_attn_version": 3,
|
||||
"flash_attn_max_num_splits_for_cuda_graph": 16,
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
@ -272,9 +272,7 @@ full_cg_backend_configs = {
|
||||
# FlashMLA on Hopper
|
||||
"FlashMLA": BackendConfig(
|
||||
name="FlashMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
|
||||
},
|
||||
attention_config={"backend": "FLASHMLA"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
@ -283,9 +281,7 @@ full_cg_backend_configs = {
|
||||
# Cutlass MLA on Blackwell
|
||||
"CutlassMLA": BackendConfig(
|
||||
name="CutlassMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
|
||||
},
|
||||
attention_config={"backend": "CUTLASS_MLA"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
@ -294,9 +290,7 @@ full_cg_backend_configs = {
|
||||
# FlashInfer MLA on Blackwell
|
||||
"FlashInferMLA": BackendConfig(
|
||||
name="FlashInferMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
|
||||
},
|
||||
attention_config={"backend": "FLASHINFER_MLA"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
@ -305,9 +299,9 @@ full_cg_backend_configs = {
|
||||
# FlashAttention MLA on Hopper
|
||||
"FlashAttentionMLA": BackendConfig(
|
||||
name="FlashAttentionMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
attention_config={
|
||||
"backend": "FLASH_ATTN_MLA",
|
||||
"flash_attn_max_num_splits_for_cuda_graph": 16,
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
@ -317,10 +311,10 @@ full_cg_backend_configs = {
|
||||
# FA2
|
||||
"FA2": BackendConfig(
|
||||
name="FA2",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
|
||||
"VLLM_FLASH_ATTN_VERSION": "2",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
attention_config={
|
||||
"backend": "FLASH_ATTN",
|
||||
"flash_attn_version": 2,
|
||||
"flash_attn_max_num_splits_for_cuda_graph": 16,
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
@ -329,7 +323,7 @@ full_cg_backend_configs = {
|
||||
# Triton Attention
|
||||
"TritonAttn": BackendConfig(
|
||||
name="TritonAttn",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
|
||||
attention_config={"backend": "TRITON_ATTN"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
@ -337,14 +331,17 @@ full_cg_backend_configs = {
|
||||
# FlashInfer
|
||||
"FlashInfer": BackendConfig(
|
||||
name="FlashInfer",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
|
||||
attention_config={"backend": "FLASHINFER"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
),
|
||||
"RocmAttn": BackendConfig(
|
||||
name="RocmAttn",
|
||||
env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
|
||||
attention_config={
|
||||
"backend": "ROCM_ATTN",
|
||||
"use_prefill_decode_attention": True,
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
},
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
|
||||
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
||||
@ -23,7 +24,7 @@ class MockRequest:
|
||||
)
|
||||
self.mm_features.append(feature)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
return self._token_counts[input_id]
|
||||
|
||||
|
||||
@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit():
|
||||
|
||||
num_tokens_to_schedule = 0
|
||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
||||
compute_budget -= req.get_num_encoder_tokens(0)
|
||||
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||
compute_budget -= req.get_num_encoder_embeds(0)
|
||||
|
||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||
|
||||
@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit():
|
||||
compute_budget = 10
|
||||
num_tokens_to_schedule = 0
|
||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
||||
compute_budget -= req.get_num_encoder_tokens(0)
|
||||
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||
compute_budget -= req.get_num_encoder_embeds(0)
|
||||
|
||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||
|
||||
|
||||
def test_encoder_cache_with_is_embed_mask():
|
||||
class MockRequestWithMask(MockRequest):
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
return self.mm_features[input_id].mm_position.get_num_embeds
|
||||
|
||||
is_embed = torch.zeros(100, dtype=torch.bool)
|
||||
is_embed[torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])] = True
|
||||
|
||||
request = MockRequestWithMask("r1", ["img1"], [100])
|
||||
request.mm_features[0] = MultiModalFeatureSpec(
|
||||
data=None,
|
||||
modality="image",
|
||||
identifier="img1",
|
||||
mm_position=PlaceholderRange(offset=0, length=100, is_embed=is_embed),
|
||||
)
|
||||
|
||||
manager = EncoderCacheManager(cache_size=100)
|
||||
manager.allocate(request, 0)
|
||||
|
||||
assert manager.num_free_slots == 92
|
||||
assert "img1" in manager.cached
|
||||
|
||||
old_size = 100
|
||||
new_size = request.mm_features[0].mm_position.get_num_embeds
|
||||
assert new_size == 8
|
||||
savings_ratio = old_size / new_size
|
||||
assert savings_ratio == 12.5
|
||||
|
||||
|
||||
def test_encoder_cache_mask_based_retrieval():
|
||||
class MockRequestWithMask(MockRequest):
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
return self.mm_features[input_id].mm_position.get_num_embeds
|
||||
|
||||
is_embed = torch.tensor(
|
||||
[False, False, True, True, False, True, True, True, False, False]
|
||||
)
|
||||
|
||||
request = MockRequestWithMask("r1", ["img1"], [10])
|
||||
request.mm_features[0] = MultiModalFeatureSpec(
|
||||
data=None,
|
||||
modality="image",
|
||||
identifier="img1",
|
||||
mm_position=PlaceholderRange(offset=0, length=10, is_embed=is_embed),
|
||||
)
|
||||
|
||||
manager = EncoderCacheManager(cache_size=50)
|
||||
manager.allocate(request, 0)
|
||||
|
||||
assert request.mm_features[0].mm_position.get_num_embeds == 5
|
||||
|
||||
start_idx = 2
|
||||
end_idx = 8
|
||||
num_embeds_before = is_embed[:start_idx].sum().item()
|
||||
num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
|
||||
|
||||
assert num_embeds_before == 0
|
||||
assert num_embeds_in_range == 5
|
||||
|
||||
start_idx = 0
|
||||
end_idx = 5
|
||||
num_embeds_before = is_embed[:start_idx].sum().item() if start_idx > 0 else 0
|
||||
num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
|
||||
|
||||
assert num_embeds_before == 0
|
||||
assert num_embeds_in_range == 2
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import weakref
|
||||
from contextlib import ExitStack
|
||||
|
||||
@ -13,26 +11,6 @@ from vllm import LLM
|
||||
from vllm.config import CompilationConfig, CompilationMode
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temporary_environ(env_vars):
|
||||
"""
|
||||
Temporarily set environment variables and restore them afterward.
|
||||
We have to do this vs monkeypatch because monkeypatch doesn't work
|
||||
with "module" scoped fixtures.
|
||||
"""
|
||||
original_env = {k: os.environ.get(k) for k in env_vars}
|
||||
try:
|
||||
os.environ.update(env_vars)
|
||||
yield
|
||||
finally:
|
||||
for k, v in original_env.items():
|
||||
if v is None:
|
||||
os.environ.pop(k, None)
|
||||
else:
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
# test attention backend and cudagraph_mode combo
|
||||
# (backend_name, cudagraph_mode, supported)
|
||||
if current_platform.is_rocm():
|
||||
@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
|
||||
):
|
||||
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
||||
|
||||
env_vars = backend_configs[backend_name].env_vars
|
||||
attention_config = backend_config.attention_config
|
||||
|
||||
with temporary_environ(env_vars), ExitStack() as stack:
|
||||
with ExitStack() as stack:
|
||||
if not supported:
|
||||
stack.enter_context(pytest.raises(Exception))
|
||||
|
||||
@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
|
||||
trust_remote_code=True,
|
||||
gpu_memory_utilization=0.45,
|
||||
max_model_len=1024,
|
||||
attention_config=attention_config,
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
|
||||
),
|
||||
@ -122,9 +101,10 @@ combo_cases_2 = [
|
||||
def test_cudagraph_compilation_combo(
|
||||
backend_name, cudagraph_mode, compilation_mode, supported
|
||||
):
|
||||
env_vars = backend_configs[backend_name].env_vars
|
||||
backend_config = backend_configs[backend_name]
|
||||
attention_config = backend_config.attention_config
|
||||
|
||||
with temporary_environ(env_vars), ExitStack() as stack:
|
||||
with ExitStack() as stack:
|
||||
if not supported:
|
||||
stack.enter_context(pytest.raises(Exception))
|
||||
|
||||
@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo(
|
||||
trust_remote_code=True,
|
||||
gpu_memory_utilization=0.45,
|
||||
max_model_len=1024,
|
||||
attention_config=attention_config,
|
||||
compilation_config=CompilationConfig(
|
||||
mode=compilation_mode, cudagraph_mode=cudagraph_mode
|
||||
),
|
||||
|
||||
@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
|
||||
BACKENDS,
|
||||
)
|
||||
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
backend, monkeypatch: pytest.MonkeyPatch
|
||||
backend,
|
||||
):
|
||||
"""
|
||||
Ensures that the same request (the 'needle' prompt) yields identical output
|
||||
@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||
random.seed(seed)
|
||||
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
attention_config = {"backend": backend}
|
||||
# Allow overrides from environment (useful for CI tuning)
|
||||
# "facebook/opt-125m" is too small, doesn't reliably test determinism
|
||||
model = resolve_model_name(backend)
|
||||
@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
max_num_seqs=max_batch_size,
|
||||
gpu_memory_utilization=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
|
||||
# Baseline generation for the needle prompt alone.
|
||||
@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
max_num_seqs=max_batch_size,
|
||||
gpu_memory_utilization=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
|
||||
mismatches = 0
|
||||
@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
BACKENDS,
|
||||
)
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
backend, monkeypatch: pytest.MonkeyPatch
|
||||
backend,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
|
||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||
random.seed(seed)
|
||||
model_name = resolve_model_name(backend)
|
||||
@ -188,12 +188,12 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
# enable_prefix_caching=False,
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16", # not everything is supported
|
||||
gpu_memory_utilization=0.9,
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
attention_config={"backend": backend},
|
||||
)
|
||||
|
||||
# Use more realistic prompts for better token generation
|
||||
@ -382,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
"backend",
|
||||
BACKENDS,
|
||||
)
|
||||
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
|
||||
def test_simple_generation(backend):
|
||||
"""
|
||||
Simple test that runs the model with a basic prompt and prints the output.
|
||||
Useful for quick smoke testing and debugging.
|
||||
"""
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
model = resolve_model_name(backend)
|
||||
|
||||
llm = LLM(
|
||||
@ -399,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
|
||||
dtype="bfloat16",
|
||||
enable_prefix_caching=False,
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
attention_config={"backend": backend},
|
||||
)
|
||||
|
||||
prompt = "the capital of france is"
|
||||
@ -445,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail(
|
||||
The test will PASS if we detect differences (proving batch invariance matters).
|
||||
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
|
||||
"""
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
|
||||
# CRITICAL: Disable batch invariance for this test
|
||||
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
|
||||
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
|
||||
@ -466,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
attention_config={"backend": backend},
|
||||
)
|
||||
|
||||
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
|
||||
@ -650,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail(
|
||||
@skip_unsupported
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
|
||||
def test_decode_logprobs_match_prefill_logprobs(
|
||||
backend, monkeypatch: pytest.MonkeyPatch
|
||||
backend,
|
||||
):
|
||||
"""
|
||||
Test that verifies decode logprobs match prefill logprobs.
|
||||
@ -665,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs(
|
||||
This ensures that the logprobs from decode are consistent with what
|
||||
we would get if we ran prefill on each prefix.
|
||||
"""
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
|
||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||
random.seed(seed)
|
||||
model_name = resolve_model_name(backend)
|
||||
@ -690,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs(
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
attention_config={"backend": backend},
|
||||
)
|
||||
|
||||
# Use a few test prompts
|
||||
@ -921,6 +919,7 @@ def LLM_with_max_seqs(
|
||||
max_num_seqs: int,
|
||||
gpu_memory_utilization: float,
|
||||
max_model_len: int,
|
||||
attention_config: dict | None = None,
|
||||
) -> LLM:
|
||||
"""
|
||||
Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
|
||||
@ -935,6 +934,7 @@ def LLM_with_max_seqs(
|
||||
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||
enable_prefix_caching=False,
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
attention_config=attention_config,
|
||||
# Enable for MOE models
|
||||
# enable_expert_parallel=True,
|
||||
)
|
||||
|
||||
@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process(
|
||||
@skip_unsupported
|
||||
@pytest.mark.parametrize("backend", BACKENDS)
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
backend: str, monkeypatch: pytest.MonkeyPatch
|
||||
backend: str,
|
||||
) -> None:
|
||||
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
|
||||
# Override backend for this test (and the RemoteOpenAIServer child process).
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
model_name = resolve_model_name(backend)
|
||||
prompts_all = [_random_prompt(10, 50) for _ in range(32)]
|
||||
|
||||
@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
server_args: list[str] = [
|
||||
"--max-model-len=8192",
|
||||
"--max-num-seqs=32",
|
||||
f"--attention-backend={backend}",
|
||||
]
|
||||
if tp_size:
|
||||
server_args += ["-tp", tp_size]
|
||||
|
||||
@ -142,16 +142,17 @@ def run_tests(
|
||||
"""Test consistency of combos of async scheduling, preemption,
|
||||
uni/multiproc executor with spec decoding."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
# avoid precision errors
|
||||
if current_platform.is_rocm():
|
||||
if is_testing_with_spec_decoding:
|
||||
# Use TRITON_ATTN for spec decoding test for consistency
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
|
||||
else:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
|
||||
# Determine attention config based on platform
|
||||
if current_platform.is_rocm():
|
||||
if is_testing_with_spec_decoding:
|
||||
# Use TRITON_ATTN for spec decoding test for consistency
|
||||
attention_config = {"backend": "TRITON_ATTN"}
|
||||
else:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
attention_config = {"backend": "ROCM_AITER_FA"}
|
||||
else:
|
||||
attention_config = {"backend": "FLEX_ATTENTION"}
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
# lock matmul precision to full FP32 (IEEE)
|
||||
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
|
||||
# m.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||
@ -174,6 +175,7 @@ def run_tests(
|
||||
spec_config,
|
||||
test_prefill_chunking=test_prefill_chunking,
|
||||
is_testing_with_spec_decoding=is_testing_with_spec_decoding,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
outputs.append(test_results)
|
||||
|
||||
@ -262,6 +264,7 @@ def run_test(
|
||||
spec_config: dict[str, Any] | None,
|
||||
test_prefill_chunking: bool,
|
||||
is_testing_with_spec_decoding: bool = False,
|
||||
attention_config: dict[str, Any] | None = None,
|
||||
):
|
||||
spec_decoding = spec_config is not None
|
||||
cache_arg: dict[str, Any] = (
|
||||
@ -301,6 +304,7 @@ def run_test(
|
||||
dtype=dtype,
|
||||
speculative_config=spec_config,
|
||||
disable_log_stats=False,
|
||||
attention_config=attention_config,
|
||||
**cache_arg,
|
||||
) as vllm_model:
|
||||
results = []
|
||||
|
||||
@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
|
||||
def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
|
||||
def test_cascade_attention(example_system_message, attn_backend):
|
||||
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
|
||||
|
||||
if attn_backend == "FLASHINFER":
|
||||
@ -19,19 +19,18 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
|
||||
"needs investigation. See issue #25679."
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen2-1.5B-Instruct", attention_config={"backend": attn_backend}
|
||||
)
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
|
||||
|
||||
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
|
||||
# No cascade attention.
|
||||
single_prompt = [example_system_message + prompt]
|
||||
responses = llm.generate(single_prompt, sampling_params)
|
||||
ref_output = responses[0].outputs[0].text
|
||||
|
||||
# No cascade attention.
|
||||
single_prompt = [example_system_message + prompt]
|
||||
responses = llm.generate(single_prompt, sampling_params)
|
||||
ref_output = responses[0].outputs[0].text
|
||||
|
||||
# (Probably) Use cascade attention.
|
||||
prompts = [example_system_message + prompt] * 64
|
||||
responses = llm.generate(prompts, sampling_params)
|
||||
for response in responses:
|
||||
assert response.outputs[0].text == ref_output
|
||||
# (Probably) Use cascade attention.
|
||||
prompts = [example_system_message + prompt] * 64
|
||||
responses = llm.generate(prompts, sampling_params)
|
||||
for response in responses:
|
||||
assert response.outputs[0].text == ref_output
|
||||
|
||||
@ -438,25 +438,26 @@ def test_eagle_correctness(
|
||||
should be the same when using eagle speculative decoding.
|
||||
model_setup: (method, model_name, eagle_model_name, tp_size)
|
||||
"""
|
||||
# Determine attention config
|
||||
# Scout requires default backend selection because vision encoder has
|
||||
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
|
||||
# to Flex Attn
|
||||
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
|
||||
if current_platform.is_rocm():
|
||||
# TODO: Enable Flex Attn for spec_decode on ROCm
|
||||
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
|
||||
attention_config = None # Let it fall back to default
|
||||
else:
|
||||
attention_config = {"backend": attn_backend}
|
||||
|
||||
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"TRITON_ATTN does not support "
|
||||
"multi-token eagle spec decode on current platform"
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
|
||||
# Scout requires default backend selection
|
||||
# because vision encoder has head_dim 88 being incompatible
|
||||
# with FLASH_ATTN and needs to fall back to Flex Attn
|
||||
|
||||
# pass if not ROCm
|
||||
if current_platform.is_rocm():
|
||||
# TODO: Enable Flex Attn for spec_decode on ROCm
|
||||
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
|
||||
else:
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"TRITON_ATTN does not support "
|
||||
"multi-token eagle spec decode on current platform"
|
||||
)
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
|
||||
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
|
||||
if "deepseek" in model_setup[1].lower():
|
||||
@ -471,7 +472,10 @@ def test_eagle_correctness(
|
||||
max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
|
||||
|
||||
ref_llm = LLM(
|
||||
model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
|
||||
model=model_name,
|
||||
max_model_len=max_model_len,
|
||||
tensor_parallel_size=tp_size,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
@ -492,6 +496,7 @@ def test_eagle_correctness(
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
model_impl=model_impl,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
|
||||
@ -38,7 +38,7 @@ class MockRequest:
|
||||
)
|
||||
self.mm_features.append(feature)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
assert input_id < len(self._token_counts)
|
||||
return self._token_counts[input_id]
|
||||
|
||||
|
||||
@ -3,21 +3,29 @@ set -xe
|
||||
|
||||
# Parse command line arguments
|
||||
KV_BUFFER_DEVICE="cuda" # Default to cuda
|
||||
ATTENTION_BACKEND="" # Default to empty (use vllm default)
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--kv_buffer_device)
|
||||
KV_BUFFER_DEVICE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--attention-backend)
|
||||
ATTENTION_BACKEND="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option $1"
|
||||
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
|
||||
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>] [--attention-backend <backend>]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
|
||||
if [[ -n "$ATTENTION_BACKEND" ]]; then
|
||||
echo "Using attention backend: $ATTENTION_BACKEND"
|
||||
fi
|
||||
|
||||
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
|
||||
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
|
||||
@ -148,6 +156,11 @@ run_tests_for_model() {
|
||||
--tensor-parallel-size $PREFILLER_TP_SIZE \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
|
||||
# Add attention backend config if specified
|
||||
if [[ -n "$ATTENTION_BACKEND" ]]; then
|
||||
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
|
||||
fi
|
||||
|
||||
if [ -n "$model_args" ]; then
|
||||
FULL_CMD="$BASE_CMD $model_args"
|
||||
else
|
||||
@ -188,7 +201,12 @@ run_tests_for_model() {
|
||||
--block-size ${DECODE_BLOCK_SIZE} \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
|
||||
|
||||
# Add attention backend config if specified
|
||||
if [[ -n "$ATTENTION_BACKEND" ]]; then
|
||||
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
|
||||
fi
|
||||
|
||||
# DP-EP attention mode
|
||||
if [[ -z "$DP_EP" ]]; then
|
||||
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
|
||||
|
||||
@ -15,14 +15,14 @@ configs=(
|
||||
|
||||
run_tests() {
|
||||
local label=$1
|
||||
local extra_env=$2
|
||||
local extra_args=$2
|
||||
|
||||
echo "=== Running tests (${label}) ==="
|
||||
for cfg in "${configs[@]}"; do
|
||||
echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}"
|
||||
echo "-> Running with ${cfg} ${extra_args:+and ${extra_args}}"
|
||||
# Use 'env' to safely set variables without eval
|
||||
if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then
|
||||
echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}"
|
||||
if ! env ${cfg} bash "${SCRIPT}" ${extra_args}; then
|
||||
echo "❌ Test failed for config: ${cfg} ${extra_args:+(${extra_args})}"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
@ -34,8 +34,8 @@ run_tests "default backend" ""
|
||||
|
||||
# Check if FLASHINFER is set (non-empty)
|
||||
if [[ -n "${FLASHINFER:-}" ]]; then
|
||||
echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER"
|
||||
run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER"
|
||||
echo "FLASHINFER is set, rerunning with --attention-backend FLASHINFER"
|
||||
run_tests "FLASHINFER backend" "--attention-backend FLASHINFER"
|
||||
else
|
||||
echo "FLASHINFER not set, skipping FLASHINFER runs."
|
||||
fi
|
||||
|
||||
@ -1132,7 +1132,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
|
||||
"TRITON_ATTN",
|
||||
],
|
||||
)
|
||||
def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
|
||||
def test_register_kv_caches(dist_init, attn_backend):
|
||||
"""
|
||||
Test that register_kv_caches() properly calls nixl_wrapper methods with
|
||||
correct data.
|
||||
@ -1144,9 +1144,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
|
||||
block layout info
|
||||
"""
|
||||
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
vllm_config = create_vllm_config(attention_backend=attn_backend)
|
||||
|
||||
# Import the appropriate backend based on the parameter
|
||||
if attn_backend == "FLASH_ATTN":
|
||||
|
||||
@ -11,6 +11,7 @@ import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (
|
||||
AttentionConfig,
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
KVTransferConfig,
|
||||
@ -94,6 +95,7 @@ def create_vllm_config(
|
||||
dtype: str = "float16",
|
||||
cache_dtype: str = "auto",
|
||||
hf_overrides: dict[str, Any] | None = None,
|
||||
attention_backend: str | None = None,
|
||||
) -> VllmConfig:
|
||||
"""Initialize VllmConfig For Testing."""
|
||||
model_config = ModelConfig(
|
||||
@ -124,12 +126,14 @@ def create_vllm_config(
|
||||
enable_permute_local_kv=enable_permute_local_kv,
|
||||
kv_connector_extra_config=kv_connector_extra_config or {},
|
||||
)
|
||||
attention_config = AttentionConfig(backend=attention_backend)
|
||||
return VllmConfig(
|
||||
scheduler_config=scheduler_config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
device_config=DeviceConfig("cpu"),
|
||||
attention_config=attention_config,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt
|
||||
from vllm.config import KVEventsConfig, KVTransferConfig
|
||||
from vllm.distributed.kv_events import BlockStored, KVEventBatch
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import set_env_var
|
||||
|
||||
CPU_BLOCK_SIZES = [48]
|
||||
ATTN_BACKENDS = ["FLASH_ATTN"]
|
||||
@ -180,13 +179,13 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
|
||||
topic="test",
|
||||
)
|
||||
|
||||
with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend):
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
gpu_memory_utilization=0.5,
|
||||
kv_events_config=kv_events_config,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
)
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
gpu_memory_utilization=0.5,
|
||||
kv_events_config=kv_events_config,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
attention_config={"backend": attn_backend},
|
||||
)
|
||||
|
||||
events_endpoint = events_endpoint.replace("*", "127.0.0.1")
|
||||
subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user