Merge branch 'main' into rename_file_info_to_pkg/file

commit 910f89c0c5 by Ning Xie, 2025-12-02 09:56:42 +08:00 (committed via GitHub)
438 changed files with 14103 additions and 8009 deletions

View File

@@ -21,8 +21,8 @@ trap remove_docker_container EXIT
 remove_docker_container
 # Try building the docker image
-numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu .
-numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
+numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --progress plain --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu .
+numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --progress plain --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
 # Run the image, setting --shm-size=4g for tensor parallel.
 docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"

View File

@@ -35,7 +35,7 @@ docker run \
 echo $ZE_AFFINITY_MASK
 pip install tblib==3.1.0
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
-python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
+python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -cc.cudagraph_mode=NONE
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
 VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager

View File

@@ -39,9 +39,9 @@ steps:
   # if this test fails, it means the nightly torch version is not compatible with some
   # of the dependencies. Please check the error message and add the package to whitelist
   # in /vllm/tools/pre_commit/generate_nightly_torch_test.py
-  mirror_hardwares: [amdexperimental]
+  mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
   agent_pool: mi325_1
-  # grade: Blocking
+  grade: Blocking
   soft_fail: true
   source_file_dependencies:
   - requirements/nightly_torch_test.txt
@@ -50,9 +50,9 @@ steps:
 - label: Async Engine, Inputs, Utils, Worker Test # 10min
   timeout_in_minutes: 15
-  mirror_hardwares: [amdexperimental, amdproduction]
+  mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
   agent_pool: mi325_1
-  # grade: Blocking
+  grade: Blocking
   source_file_dependencies:
   - vllm/
   - tests/multimodal
@@ -61,17 +61,18 @@ steps:
   - pytest -v -s -m 'not cpu_test' multimodal
   - pytest -v -s utils_

-- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins
-  timeout_in_minutes: 10
-  mirror_hardwares: [amdexperimental, amdproduction]
+- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
+  timeout_in_minutes: 20
+  mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
   agent_pool: mi325_1
-  # grade: Blocking
+  grade: Blocking
   source_file_dependencies:
   - vllm/
   - tests/test_inputs.py
   - tests/test_outputs.py
   - tests/multimodal
   - tests/standalone_tests/lazy_imports.py
+  - tests/tokenizers_
   - tests/transformers_utils
   - tests/config
   no_gpu: true
@@ -80,6 +81,7 @@ steps:
   - pytest -v -s test_inputs.py
   - pytest -v -s test_outputs.py
   - pytest -v -s -m 'cpu_test' multimodal
+  - pytest -v -s tokenizers_
   - pytest -v -s transformers_utils
   - pytest -v -s config
@ -113,9 +115,9 @@ steps:
- pytest -v -s basic_correctness/test_cpu_offload.py - pytest -v -s basic_correctness/test_cpu_offload.py
- label: Entrypoints Unit Tests # 5min - label: Entrypoints Unit Tests # 5min
mirror_hardwares: [amdexperimental, amdproduction] mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
agent_pool: mi325_1 agent_pool: mi325_1
# grade: Blocking grade: Blocking
timeout_in_minutes: 10 timeout_in_minutes: 10
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
fast_check: true fast_check: true
@@ -212,6 +214,7 @@
   # test with internal dp
   - python3 ../examples/offline_inference/data_parallel.py --enforce-eager
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+  - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
   - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
   - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
@ -250,9 +253,9 @@ steps:
- torchrun --nproc-per-node=8 ../examples/offline_inference/torchrun_dp_example.py --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep - torchrun --nproc-per-node=8 ../examples/offline_inference/torchrun_dp_example.py --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep
- label: EPLB Algorithm Test # 5min - label: EPLB Algorithm Test # 5min
mirror_hardwares: [amdexperimental, amdproduction] mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
agent_pool: mi325_1 agent_pool: mi325_1
# grade: Blocking grade: Blocking
timeout_in_minutes: 15 timeout_in_minutes: 15
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
source_file_dependencies: source_file_dependencies:
@@ -308,23 +311,20 @@
   - pytest -v -s test_regression.py
   working_dir: "/vllm-workspace/tests" # optional

-- label: Engine Test # 25min
-  timeout_in_minutes: 40
+- label: Engine Test # 9min
+  timeout_in_minutes: 15
   mirror_hardwares: [amdexperimental, amdproduction]
   agent_pool: mi325_1
   # grade: Blocking
   source_file_dependencies:
   - vllm/
   - tests/engine
-  - tests/tokenization
   - tests/test_sequence
   - tests/test_config
   - tests/test_logger
   - tests/test_vllm_port
   commands:
   - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
-  # OOM in the CI unless we run this separately
-  - pytest -v -s tokenization
- label: V1 Test e2e + engine # 30min - label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45 timeout_in_minutes: 45
@ -342,9 +342,9 @@ steps:
- label: V1 Test entrypoints # 35min - label: V1 Test entrypoints # 35min
timeout_in_minutes: 50 timeout_in_minutes: 50
mirror_hardwares: [amdexperimental, amdproduction] mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
agent_pool: mi325_1 agent_pool: mi325_1
# grade: Blocking grade: Blocking
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/v1 - tests/v1
@ -392,6 +392,20 @@ steps:
commands: commands:
- pytest -v -s v1/attention - pytest -v -s v1/attention
- label: Batch Invariance Tests (H100) # 10min
mirror_hardwares: [amdexperimental]
agent_pool: mi325_1
timeout_in_minutes: 25
gpu: h100
source_file_dependencies:
- vllm/
- tests/v1/determinism/
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pip install pytest-timeout pytest-forked
- pytest -v -s v1/determinism/test_batch_invariance.py
- pytest -v -s v1/determinism/test_rms_norm_batch_invariant.py
- label: V1 Test attention (B200) # 10min - label: V1 Test attention (B200) # 10min
timeout_in_minutes: 30 timeout_in_minutes: 30
gpu: b200 gpu: b200
@ -402,9 +416,9 @@ steps:
- VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this
- label: V1 Test others (CPU) # 5 mins - label: V1 Test others (CPU) # 5 mins
mirror_hardwares: [amdexperimental, amdproduction] mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
agent_pool: mi325_1 agent_pool: mi325_1
# grade: Blocking grade: Blocking
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/v1 - tests/v1
@ -496,7 +510,7 @@ steps:
- label: PyTorch Compilation Unit Tests # 15min - label: PyTorch Compilation Unit Tests # 15min
timeout_in_minutes: 30 timeout_in_minutes: 30
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1 agent_pool: mi325_1
# grade: Blocking # grade: Blocking
torch_nightly: true torch_nightly: true
@ -513,7 +527,7 @@ steps:
- label: PyTorch Fullgraph Smoke Test # 15min - label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30 timeout_in_minutes: 30
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1 agent_pool: mi325_1
# grade: Blocking # grade: Blocking
torch_nightly: true torch_nightly: true
@ -569,7 +583,7 @@ steps:
- label: Kernels Attention Test %N # 23min - label: Kernels Attention Test %N # 23min
timeout_in_minutes: 35 timeout_in_minutes: 35
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_8 agent_pool: mi325_8
# grade: Blocking # grade: Blocking
source_file_dependencies: source_file_dependencies:
@ -596,7 +610,7 @@ steps:
- label: Kernels MoE Test %N # 40min - label: Kernels MoE Test %N # 40min
timeout_in_minutes: 60 timeout_in_minutes: 60
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_8 agent_pool: mi325_8
# grade: Blocking # grade: Blocking
source_file_dependencies: source_file_dependencies:
@ -623,6 +637,26 @@ steps:
commands: commands:
- pytest -v -s kernels/mamba - pytest -v -s kernels/mamba
- label: Kernels DeepGEMM Test (H100) # Nvidia-centric
# Not replicating for CUTLAS & CuTe
timeout_in_minutes: 45
gpu: h100
num_gpus: 1
source_file_dependencies:
- tools/install_deepgemm.sh
- vllm/utils/deep_gemm.py
- vllm/model_executor/layers/fused_moe
- vllm/model_executor/layers/quantization
- tests/kernels/quantization/test_block_fp8.py
- tests/kernels/moe/test_deepgemm.py
- tests/kernels/moe/test_batched_deepgemm.py
- tests/kernels/attention/test_deepgemm_attention.py
commands:
- pytest -v -s kernels/quantization/test_block_fp8.py -k deep_gemm
- pytest -v -s kernels/moe/test_deepgemm.py
- pytest -v -s kernels/moe/test_batched_deepgemm.py
- pytest -v -s kernels/attention/test_deepgemm_attention.py
- label: Model Executor Test # 23min - label: Model Executor Test # 23min
timeout_in_minutes: 35 timeout_in_minutes: 35
torch_nightly: true torch_nightly: true
@ -1056,6 +1090,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/kernels/moe/test_flashinfer.py
- pytest -v -s tests/kernels/moe/test_cutedsl_moe.py
- label: Blackwell Fusion and Compile Tests # 30 min - label: Blackwell Fusion and Compile Tests # 30 min
timeout_in_minutes: 40 timeout_in_minutes: 40
@ -1065,11 +1100,19 @@ steps:
- csrc/quantization/fp4/ - csrc/quantization/fp4/
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py - vllm/v1/attention/backends/flashinfer.py
- vllm/v1/worker/
- vllm/v1/cudagraph_dispatcher.py
- vllm/compilation/ - vllm/compilation/
# can affect pattern matching # can affect pattern matching
- vllm/model_executor/layers/layernorm.py - vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py - vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py - vllm/model_executor/layers/quantization/input_quant_fp8.py
- vllm/model_executor/layers/fused_moe/layer.py
- tests/compile/test_fusion_attn.py
- tests/compile/test_silu_mul_quant_fusion.py
- tests/compile/distributed/test_fusion_all_reduce.py
- tests/compile/distributed/test_fusions_e2e.py
- tests/compile/fullgraph/test_full_graph.py
commands: commands:
- nvidia-smi - nvidia-smi
- pytest -v -s tests/compile/test_fusion_attn.py - pytest -v -s tests/compile/test_fusion_attn.py
@ -1080,7 +1123,7 @@ steps:
# Wrap with quotes to escape yaml # Wrap with quotes to escape yaml
- "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'" - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- pytest -v -s tests/compile/distributed/test_full_graph.py::test_fp8_kv_scale_compile - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell Fusion E2E Tests # 30 min - label: Blackwell Fusion E2E Tests # 30 min
timeout_in_minutes: 40 timeout_in_minutes: 40
@ -1102,7 +1145,7 @@ steps:
commands: commands:
- nvidia-smi - nvidia-smi
# Run all e2e fusion tests # Run all e2e fusion tests
- pytest -v -s tests/compile/test_fusions_e2e.py - pytest -v -s tests/compile/distributed/test_fusions_e2e.py
- label: ROCm GPT-OSS Eval - label: ROCm GPT-OSS Eval
timeout_in_minutes: 60 timeout_in_minutes: 60
@ -1217,6 +1260,7 @@ steps:
- tests/v1/worker/test_worker_memory_snapshot.py - tests/v1/worker/test_worker_memory_snapshot.py
commands: commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_collective_rpc.py
@ -1252,7 +1296,7 @@ steps:
- label: Plugin Tests (2 GPUs) # 40min - label: Plugin Tests (2 GPUs) # 40min
timeout_in_minutes: 60 timeout_in_minutes: 60
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_2 agent_pool: mi325_2
# grade: Blocking # grade: Blocking
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
@ -1328,7 +1372,7 @@ steps:
- label: Weight Loading Multiple GPU Test # 33min - label: Weight Loading Multiple GPU Test # 33min
timeout_in_minutes: 45 timeout_in_minutes: 45
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_2 agent_pool: mi325_2
# grade: Blocking # grade: Blocking
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
@ -1433,7 +1477,7 @@ steps:
- pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
#- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm #- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
- "pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
- pytest -v -s tests/compile/distributed/test_sequence_parallel.py - pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- pytest -v -s tests/v1/distributed/test_dbo.py - pytest -v -s tests/v1/distributed/test_dbo.py
@ -1465,7 +1509,7 @@ steps:
- bash .buildkite/scripts/run-prime-rl-test.sh - bash .buildkite/scripts/run-prime-rl-test.sh
- label: DeepSeek V2-Lite Accuracy - label: DeepSeek V2-Lite Accuracy
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_4 agent_pool: mi325_4
# grade: Blocking # grade: Blocking
timeout_in_minutes: 60 timeout_in_minutes: 60
@@ -1476,8 +1520,8 @@
   commands:
   - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010

-- label: Qwen3-30B-A3B-FP8-block Accuracy
-  mirror_hardwares: [amdexperimental]
+- label: Qwen3-30B-A3B-FP8-block Accuracy (H100)
+  mirror_hardwares: [amdexperimental, amdproduction]
   agent_pool: mi325_4
   # grade: Blocking
   timeout_in_minutes: 60
@@ -1487,3 +1531,12 @@
   working_dir: "/vllm-workspace"
   commands:
   - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020
+
+- label: Qwen3-30B-A3B-FP8-block Accuracy (B200)
+  timeout_in_minutes: 60
+  gpu: b200
+  optional: true
+  num_gpus: 2
+  working_dir: "/vllm-workspace"
+  commands:
+  - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1

View File

@ -57,14 +57,15 @@ steps:
- pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s -m 'not cpu_test' multimodal
- pytest -v -s utils_ - pytest -v -s utils_
- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins - label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
timeout_in_minutes: 10 timeout_in_minutes: 20
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/test_inputs.py - tests/test_inputs.py
- tests/test_outputs.py - tests/test_outputs.py
- tests/multimodal - tests/multimodal
- tests/standalone_tests/lazy_imports.py - tests/standalone_tests/lazy_imports.py
- tests/tokenizers_
- tests/transformers_utils - tests/transformers_utils
- tests/config - tests/config
no_gpu: true no_gpu: true
@ -73,6 +74,7 @@ steps:
- pytest -v -s test_inputs.py - pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py - pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal - pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s tokenizers_
- pytest -v -s transformers_utils - pytest -v -s transformers_utils
- pytest -v -s config - pytest -v -s config
@ -213,6 +215,7 @@ steps:
timeout_in_minutes: 10 timeout_in_minutes: 10
gpu: h100 gpu: h100
num_gpus: 8 num_gpus: 8
optional: true
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
source_file_dependencies: source_file_dependencies:
- examples/offline_inference/torchrun_dp_example.py - examples/offline_inference/torchrun_dp_example.py
@ -276,21 +279,18 @@ steps:
- pytest -v -s test_regression.py - pytest -v -s test_regression.py
working_dir: "/vllm-workspace/tests" # optional working_dir: "/vllm-workspace/tests" # optional
- label: Engine Test # 25min - label: Engine Test # 9min
timeout_in_minutes: 40 timeout_in_minutes: 15
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/engine - tests/engine
- tests/tokenization
- tests/test_sequence - tests/test_sequence
- tests/test_config - tests/test_config
- tests/test_logger - tests/test_logger
- tests/test_vllm_port - tests/test_vllm_port
commands: commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
# OOM in the CI unless we run this separately
- pytest -v -s tokenization
- label: V1 Test e2e + engine # 30min - label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45 timeout_in_minutes: 45
@ -820,14 +820,24 @@ steps:
commands: commands:
- pytest -v -s models/language/pooling_mteb_test - pytest -v -s models/language/pooling_mteb_test
- label: Multi-Modal Processor Test # 44min - label: Multi-Modal Processor Test (CPU)
timeout_in_minutes: 60
source_file_dependencies:
- vllm/
- tests/models/multimodal
no_gpu: true
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
- label: Multi-Modal Processor Test
timeout_in_minutes: 60 timeout_in_minutes: 60
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/models/multimodal - tests/models/multimodal
commands: commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/multimodal/processing - pytest -v -s models/multimodal/processing/test_tensor_schema.py
- label: Multi-Modal Models Test (Standard) # 60min - label: Multi-Modal Models Test (Standard) # 60min
timeout_in_minutes: 80 timeout_in_minutes: 80
@ -1303,11 +1313,11 @@ steps:
working_dir: "/vllm-workspace/" working_dir: "/vllm-workspace/"
num_gpus: 2 num_gpus: 2
commands: commands:
- pytest -v -s tests/compile/distributed/test_async_tp.py - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py
- pytest -v -s tests/compile/distributed/test_sequence_parallelism.py - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
- pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
- "pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
- pytest -v -s tests/distributed/test_sequence_parallel.py - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- pytest -v -s tests/v1/distributed/test_dbo.py - pytest -v -s tests/v1/distributed/test_dbo.py
@ -1360,4 +1370,4 @@ steps:
num_gpus: 2 num_gpus: 2
working_dir: "/vllm-workspace" working_dir: "/vllm-workspace"
commands: commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1

.github/CODEOWNERS
View File

@@ -149,6 +149,7 @@ mkdocs.yaml @hmellor
 /examples/*/pooling/ @noooop
 /tests/models/*/pooling* @noooop
 /tests/entrypoints/pooling @noooop
+/vllm/entrypoints/pooling @aarnphm @chaunceyjiang @noooop
 /vllm/config/pooler.py @noooop
 /vllm/pooling_params.py @noooop
 /vllm/model_executor/layers/pooler.py @noooop

View File

@@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # Only build Marlin kernels if we are building for at least some compatible archs.
   # Keep building Marlin for 9.0 as there are some group sizes and shapes that
   # are not supported by Machete yet.
-  # 9.0 for latest bf16 atomicAdd PTX
-  cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
+  # marlin arches for fp16 output
+  cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
+  # marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
+  cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
+  # marlin arches for fp8 input
+  # - sm80 doesn't support fp8 computation
+  # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
+  # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
+  cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")

   if (MARLIN_ARCHS)
     #
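With the split above, the FP8-activation Marlin kernels are only compiled when an FP8-capable arch (SM89 or SM120) ends up in CUDA_ARCHS. A minimal sketch of requesting such a build from source, assuming the usual TORCH_CUDA_ARCH_LIST override (the env var and values here are illustrative, not part of this diff):

```bash
# Build only for RTX 40x0 (SM89) and RTX 50x0 (SM120); MARLIN_FP8_ARCHS is then
# non-empty, so the sm89_kernel_*.cu FP8-activation sources get compiled in.
TORCH_CUDA_ARCH_LIST="8.9 12.0" pip install -e . --no-build-isolation
```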
@@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set(MARLIN_GEN_SCRIPT
       ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
     file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
+    list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
+    set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
-    message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
-    message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")
-    if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
-        OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH})
+    message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
+    message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
+    if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
+        OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
       execute_process(
         COMMAND ${CMAKE_COMMAND} -E env
           PYTHONPATH=$PYTHONPATH
-          ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
+          ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
         RESULT_VARIABLE marlin_generation_result
         OUTPUT_VARIABLE marlin_generation_result
         OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
@@ -387,15 +398,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
         "\nCheck the log for details: "
         "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
     else()
-      set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
-          CACHE STRING "Last run Marlin generate script hash" FORCE)
+      set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
+          CACHE STRING "Last run Marlin generate script hash and arch" FORCE)
       message(STATUS "Marlin generation completed successfully.")
     endif()
   else()
     message(STATUS "Marlin generation script has not changed, skipping generation.")
   endif()

-  file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu")
+  file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
   set_gencode_flags_for_srcs(
     SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
     CUDA_ARCHS "${MARLIN_ARCHS}")
@ -403,12 +414,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif() endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
if (MARLIN_FP8_ARCHS)
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_FP8_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC})
endif()
set(MARLIN_SRCS set(MARLIN_SRCS
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu") "csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
@@ -941,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     CUDA_ARCHS "${CUDA_ARCHS}")
   list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")

-  # 9.0 for latest bf16 atomicAdd PTX
-  cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
+  # moe marlin arches
+  # note that we always set `use_atomic_add=False` for moe marlin now,
+  # so we don't need 9.0 for bf16 atomicAdd PTX
+  cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
+  # moe marlin arches for fp8 input
+  # - sm80 doesn't support fp8 computation
+  # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
+  # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
+  cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")

   if (MARLIN_MOE_ARCHS)
     #
@ -952,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MOE_MARLIN_GEN_SCRIPT set(MOE_MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process( execute_process(
COMMAND ${CMAKE_COMMAND} -E env COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE moe_marlin_generation_result RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output OUTPUT_VARIABLE moe_marlin_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
@ -974,7 +1016,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: " "\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
else() else()
set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH} set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
CACHE STRING "Last run Marlin MOE generate script hash" FORCE) CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
message(STATUS "Marlin MOE generation completed successfully.") message(STATUS "Marlin MOE generation completed successfully.")
endif() endif()
@ -982,16 +1024,28 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin MOE generation script has not changed, skipping generation.") message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
endif() endif()
file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu") file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MOE_WNAA16_MARLIN_SRC}" SRCS "${MARLIN_MOE_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}") CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC} set_source_files_properties(${MARLIN_MOE_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif() endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) if (MARLIN_MOE_FP8_ARCHS)
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_FP8_SRC}"
CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_FP8_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
endif()
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
else() else()

View File

@@ -83,7 +83,7 @@ MIN_CACHE_HIT_PCT=0
 MAX_LATENCY_ALLOWED_MS=100000000000 # A very large number
 ```
-#### 2. Maximize Throughput with a Latency Requirement
+### 2. Maximize Throughput with a Latency Requirement
 - **Goal**: Find the best server parameters when P99 end-to-end latency must be below 500ms.
 - **Configuration**:
@@ -96,7 +96,7 @@ MIN_CACHE_HIT_PCT=0
 MAX_LATENCY_ALLOWED_MS=500
 ```
-#### 3. Maximize Throughput with Prefix Caching and Latency Requirements
+### 3. Maximize Throughput with Prefix Caching and Latency Requirements
 - **Goal**: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms.
 - **Configuration**:
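A minimal sketch of the case-3 settings described above; the variable names come from the surrounding doc, and the MIN_CACHE_HIT_PCT value is inferred from the stated 60% goal rather than quoted from it:

```bash
MIN_CACHE_HIT_PCT=60          # assume a 60% prefix cache hit rate
MAX_LATENCY_ALLOWED_MS=500    # P99 end-to-end latency must stay below 500 ms
```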

View File

@@ -620,7 +620,7 @@ def get_tokenizer(
         kwargs["use_fast"] = False
     if tokenizer_mode == "mistral":
         try:
-            from vllm.transformers_utils.tokenizer import MistralTokenizer
+            from vllm.tokenizers import MistralTokenizer
         except ImportError as e:
             raise ImportError(
                 "MistralTokenizer requires vllm package.\n"

View File

@@ -237,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
         b_q_weight=w_q,
         b_bias=None,
         b_scales=w_s,
+        a_scales=None,
         global_scale=None,
         b_zeros=w_zp,
         g_idx=g_idx,

View File

@@ -263,7 +263,7 @@ def bench_run(
     results.append(
         benchmark.Timer(
-            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501
+            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
@@ -273,7 +273,7 @@ def bench_run(
     results.append(
         benchmark.Timer(
-            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501
+            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,

View File

@ -495,7 +495,13 @@ function (define_extension_target MOD_NAME)
   set(SOABI_KEYWORD "")
 endif()
-if (ARG_USE_SABI)
+run_python(IS_FREETHREADED_PYTHON
+  "import sysconfig; print(1 if sysconfig.get_config_var(\"Py_GIL_DISABLED\") else 0)"
+  "Failed to determine whether interpreter is free-threaded")
+
+# Free-threaded Python doesn't yet support the stable ABI (see PEP 803/809),
+# so avoid using the stable ABI under free-threading only.
+if (ARG_USE_SABI AND NOT IS_FREETHREADED_PYTHON)
   Python_add_library(${MOD_NAME} MODULE USE_SABI ${ARG_USE_SABI} ${SOABI_KEYWORD} "${ARG_SOURCES}")
 else()
   Python_add_library(${MOD_NAME} MODULE ${SOABI_KEYWORD} "${ARG_SOURCES}")
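The same probe the build system now runs can be used standalone; printing 1 means the interpreter is free-threaded (Py_GIL_DISABLED) and USE_SABI will be skipped:

```bash
python3 -c 'import sysconfig; print(1 if sysconfig.get_config_var("Py_GIL_DISABLED") else 0)'
```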

View File

@@ -51,12 +51,13 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
     if (node_id != -1) {
       node_ids.insert(node_id);
     }
-    TORCH_WARN(node_id == mem_node_id, "CPU ", cpu_id, " is on NUMA node ",
-               node_id, ", but CPU ", omp_cpu_ids.front(),
-               " is on NUMA node ", mem_node_id,
-               ". All CPUs should be on the same NUMA node for optimal "
-               "performance. Memory will be bound to NUMA node ",
-               mem_node_id, ".");
+    if (node_id != mem_node_id) {
+      TORCH_WARN("CPU ", cpu_id, " is on NUMA node ", node_id, ", but CPU ",
+                 omp_cpu_ids.front(), " is on NUMA node ", mem_node_id,
+                 ". All CPUs should be on the same NUMA node for optimal "
+                 "performance. Memory will be bound to NUMA node ",
+                 mem_node_id, ".");
+    }
   }
   // Concatenate all node_ids into a single comma-separated string
   if (!node_ids.empty()) {
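To pick an OpenMP CPU range that stays on one node, and so never triggers the warning above, a standard way to inspect the CPU-to-NUMA-node mapping is the command below (a hedged suggestion, not part of this diff):

```bash
# Lists each NUMA node with its CPUs and memory; choose the OpenMP CPU range
# (e.g. via VLLM_CPU_OMP_THREADS_BIND or numactl) from a single node's CPU list.
numactl --hardware
```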

View File

@@ -1 +1,2 @@
-kernel_*.cu
+sm*_kernel_*.cu
+kernel_selector.h

View File

@@ -4,134 +4,282 @@ import glob
 import itertools
 import os
 import subprocess
+import sys

 import jinja2

-FILE_HEAD = """
-// auto generated by generate.py
-// clang-format off
+ARCHS = []
+SUPPORT_FP8 = False
+for arch in sys.argv[1].split(","):
+    arch = arch[: arch.index(".") + 2].replace(".", "")
+    arch = int(arch)
+    # only SM89 and SM120 fully support
+    # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
+    # SM90 and SM100 can use this PTX, but its simulated
+    # with FP16 MMA, so it cannot achieve any acceleration.
+    if arch in [89, 120]:
+        SUPPORT_FP8 = True
+
+FILE_HEAD_COMMENT = """
+// auto generated by generate_kernels.py
+// clang-format off
+""".lstrip()
+
+FILE_HEAD = (
+    FILE_HEAD_COMMENT
+    + """
 #include "kernel.h"
 #include "marlin_template.h"

 namespace MARLIN_NAMESPACE_NAME {
-""".strip()
+"""
+)
TEMPLATE = ( TEMPLATE = (
"template __global__ void Marlin<" "template __global__ void Marlin<"
"{{scalar_t}}, " "{{a_type_id}}, "
"{{w_type_id}}, " "{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, " "{{s_type_id}}, "
"{{threads}}, " "{{threads}}, "
"{{thread_m_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_n_blocks}}, " "{{thread_n_blocks}}, "
"{{thread_k_blocks}}, " "{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{m_block_size_8}}, "
"{{stages}}, " "{{stages}}, "
"{{group_blocks}}, " "{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );" "( MARLIN_KERNEL_PARAMS );"
) )
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case QUANT_CONFIGS = [
# = -1 : channelwise quantization # AWQ-INT4
# > 0 : group_size=16*group_blocks {
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] "b_type": "kU4",
DTYPES = ["fp16", "bf16"] "thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# AWQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
def remove_old_kernels(): def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename]) subprocess.call(["rm", "-f", filename])
filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename])
def generate_new_kernels(): def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
b_type = quant_config["b_type"]
all_group_blocks = quant_config["group_blocks"]
all_m_blocks = quant_config["thread_m_blocks"]
all_thread_configs = quant_config["thread_configs"]
for a_type, c_type in itertools.product(a_types, c_types):
if not SUPPORT_FP8 and a_type == "kFE4M3fn":
continue
if "16" in a_type and "16" in c_type and a_type != c_type:
continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
):
thread_k, thread_n, threads = thread_configs
if threads == 256:
# for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue
if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
continue
config = {
"threads": threads,
"s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "false",
}
result_dict[(a_type, b_type, c_type)].append(config)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = [] all_template_str_list = []
for config in config_list:
for group_blocks, m_blocks, thread_configs in itertools.product( s_type = config["s_type"]
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8",
"vllm::kU8B128",
]:
continue
if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
template_str = jinja2.Template(TEMPLATE).render( template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype, a_type_id=f"vllm::{a_type}.id()",
w_type_id=scalar_type + ".id()", b_type_id=f"vllm::{b_type}.id()",
s_type_id=s_type + ".id()", c_type_id=f"vllm::{c_type}.id()",
threads=threads, s_type_id=f"vllm::{s_type}.id()",
thread_m_blocks=max(m_blocks, 1), **config,
thread_n_blocks=n_blocks, )
thread_k_blocks=k_blocks, all_template_str_list.append(template_str)
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages", conditions = [
group_blocks=group_blocks, f"a_type == vllm::{a_type}",
is_zp_float=False, f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
) )
all_template_str_list.append(template_str) kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n" file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content) f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)
if __name__ == "__main__": if __name__ == "__main__":
remove_old_kernels() remove_old_kernels()
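A hedged illustration of how the generator is now driven and what it emits: the arch string mirrors the comma-joined CUDA_ARCHS_STR that CMake passes, and the listed file names are examples of the new naming scheme, not an exhaustive listing.

```bash
python3 csrc/quantization/gptq_marlin/generate_kernels.py "8.0+PTX,8.9,9.0+PTX,12.0"
# Expected outputs (illustrative):
#   sm80_kernel_float16_u4b8_float16.cu      # 16-bit-activation kernels, fp16 output
#   sm80_kernel_bfloat16_fe2m1f_bfloat16.cu  # 16-bit-activation kernels, bf16 output
#   sm89_kernel_fe4m3fn_u4b8_float16.cu      # FP8-activation kernels, only when SM89/SM120 is requested
#   kernel_selector.h                        # generated dispatch chain used at runtime
```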

View File

@@ -11,8 +11,9 @@
   const int4 *__restrict__ A, const int4 *__restrict__ B, \
   int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
   const int4 *__restrict__ b_bias_ptr, \
+  const float *__restrict__ a_scales_ptr, \
   const int4 *__restrict__ scales_ptr, \
-  const uint16_t *__restrict__ scale2_ptr, \
+  const uint16_t *__restrict__ global_scale_ptr, \
   const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
   const int32_t *__restrict__ sorted_token_ids_ptr, \
   const int32_t *__restrict__ expert_ids_ptr, \
@@ -20,12 +21,13 @@
   const float *__restrict__ topk_weights_ptr, int top_k, \
   bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
   int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
-  bool use_fp32_reduce, int max_shared_mem
+  bool use_fp32_reduce

 namespace MARLIN_NAMESPACE_NAME {
-template <typename scalar_t,          // compute dtype, half or nv_float16
-          const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
-          const vllm::ScalarTypeId s_type_id,  // weight scale ScalarType id
+template <const vllm::ScalarTypeId a_type_id,  // A ScalarType id
+          const vllm::ScalarTypeId b_type_id,  // B ScalarType id
+          const vllm::ScalarTypeId c_type_id,  // C ScalarType id
+          const vllm::ScalarTypeId s_type_id,  // B_SCALE ScalarType id
           const int threads,          // number of threads in a threadblock
           const int thread_m_blocks,  // number of 16x16 blocks in the m
                                       // dimension (batchsize) of the

File diff suppressed because it is too large.

View File

@ -37,39 +37,6 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int moe_block_size>
__global__ void permute_cols_kernel(
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr,
const int32_t* __restrict__ sorted_token_ids_ptr,
const int32_t* __restrict__ expert_ids_ptr,
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
int size_k, int top_k) {};
} // namespace marlin
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
// For a given "a" of size [M,K] performs a permutation of the K columns based // For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices. // on the given "perm" indices.
template <int moe_block_size> template <int moe_block_size>
@ -207,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp, bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float) { int is_zp_float, bool is_a_8bit) {
int pack_factor = 32 / num_bits; int pack_factor = 32 / num_bits;
// Get B size // Get B size
@ -217,8 +184,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4; int sh_block_meta_size = tb_m * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2; int sh_bias_size = tb_n * 2;
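The shared-memory accounting above is where 8-bit activations pay off: the A staging buffer drops from 2 bytes to 1 byte per element, while the packed-B and reduction buffers are unchanged. A minimal compile-time sketch of that budget (a sketch only; pipe_stages = 4, the 64x64x256 m/k/n tile and 4-bit weights are assumed values, not taken from this diff):

// Mirrors the sh_a/sh_b formulas from get_kernel_cache_size above.
constexpr int sh_a_bytes(int pipe_stages, int tb_m, int tb_k, bool is_a_8bit) {
  return pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
}
constexpr int sh_b_bytes(int pipe_stages, int tb_k, int tb_n, int pack_factor) {
  return pipe_stages * (tb_k * tb_n / pack_factor) * 4;
}
// Assumed: 4 pipeline stages, tb_m = 64, tb_k = 64, tb_n = 256, 4-bit weights.
static_assert(sh_a_bytes(4, 64, 64, /*is_a_8bit=*/false) == 32768, "fp16/bf16 A tile");
static_assert(sh_a_bytes(4, 64, 64, /*is_a_8bit=*/true) == 16384, "8-bit A tile is half the size");
static_assert(sh_b_bytes(4, 64, 256, 32 / 4) == 32768, "packed 4-bit B tile is unaffected");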
@ -250,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k, int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order, int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float, bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem) { int max_shared_mem, bool is_a_8bit) {
// Sanity // Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 || if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) { th_config.num_threads == -1) {
@ -273,188 +240,34 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
} }
// Check that pipeline fits into cache // Check that pipeline fits into cache
int cache_size = get_kernel_cache_size( int cache_size =
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); prob_n, prob_k, num_bits, group_size, has_act_order,
return cache_size + 512 <= max_shared_mem; is_k_full, has_zp, is_zp_float, is_a_8bit);
return cache_size <= max_shared_mem;
} }
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ MarlinFuncPtr get_marlin_kernel(
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ const vllm::ScalarType a_type, const vllm::ScalarType b_type,
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ const vllm::ScalarType c_type, const vllm::ScalarType s_type,
thread_n_blocks == THREAD_N_BLOCKS && \ int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
thread_k_blocks == THREAD_K_BLOCKS && \ bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
m_block_size_8 == M_BLOCK_SIZE_8 && \ int threads, bool is_zp_float) {
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ int num_bits = b_type.size_bits();
is_zp_float == IS_ZP_FLOAT) { \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
int thread_m_blocks, int thread_n_blocks,
int thread_k_blocks, bool m_block_size_8,
bool has_act_order, bool has_zp,
int group_blocks, int num_threads,
bool is_zp_float) {
int num_bits = q_type.size_bits();
auto kernel = MarlinDefault; auto kernel = MarlinDefault;
if (false) {
}
COMMON_GET_IF(vllm::kU4) #include "kernel_selector.h"
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
NVFP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel; return kernel;
} }
template <typename scalar_t> exec_config_t determine_exec_config(
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
int prob_n, int prob_k, int thread_m_blocks, const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
bool m_block_size_8, int num_bits, int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
int group_size, bool has_act_order, bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
bool is_k_full, bool has_zp, bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
bool is_zp_float, int max_shared_mem) { bool is_a_8bit) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1 thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs ? large_batch_thread_configs
@ -471,73 +284,69 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m, if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order, prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem)) { is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
is_a_8bit)) {
continue; continue;
} }
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
is_a_8bit);
int group_blocks = 0; int group_blocks = 0;
if (!has_act_order) { if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : (group_size / 16); group_blocks = group_size == -1 ? -1 : (group_size / 16);
} }
auto kernel = get_marlin_kernel<scalar_t>( auto kernel =
q_type, thread_m_blocks, th_config.thread_n / 16, get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, th_config.thread_n / 16, th_config.thread_k / 16,
group_blocks, th_config.num_threads, is_zp_float); m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
if (kernel == MarlinDefault) continue; if (kernel == MarlinDefault) continue;
if (thread_m_blocks > 1) { cudaFuncAttributes attr;
exec_cfg = {1, th_config}; cudaFuncGetAttributes(&attr, kernel);
break; int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
} else { int allow_count = min(device_max_reg_size / reg_size,
cudaFuncAttributes attr; max_shared_mem / (cache_size + 1536));
cudaFuncGetAttributes(&attr, kernel); if (thread_m_blocks == 1)
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
int allow_count = min(device_max_reg_size / reg_size,
max_shared_mem / (cache_size + 1024));
allow_count = max(min(allow_count, 4), 1); allow_count = max(min(allow_count, 4), 1);
if (allow_count > count) { else
count = allow_count; allow_count = max(min(allow_count, 2), 1);
exec_cfg = {count, th_config};
}; if (prob_n / th_config.thread_n * prob_m * top_k * 4 < sms * allow_count) {
allow_count =
max(prob_n / th_config.thread_n * prob_m * top_k * 4 / sms, 1);
} }
if (allow_count > count) {
count = allow_count;
exec_cfg = {count, th_config};
};
} }
return exec_cfg; return exec_cfg;
} }
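determine_exec_config now picks blocks_per_sm from the tighter of two occupancy bounds, registers and shared memory, instead of always launching one block per SM for large batches. A hedged, host-only sketch of that calculation; the register count, register-file size, shared-memory limit and cache footprint below are illustrative assumptions, and only the threads-times-4-bytes-per-register factor and the 1536-byte slack come from the code above:

#include <algorithm>
#include <cstdio>

int main() {
  const int device_max_reg_size = 65536 * 4;    // assumed: 64K 32-bit registers per SM
  const int max_shared_mem = 101376;            // assumed: ~99 KiB opt-in limit
  const int num_regs = 128, num_threads = 256;  // assumed cudaFuncAttributes result
  const int cache_size = 48000;                 // assumed pipeline footprint in bytes
  int reg_size = std::max(num_regs, 1) * num_threads * 4;
  int allow_count = std::min(device_max_reg_size / reg_size,
                             max_shared_mem / (cache_size + 1536));
  allow_count = std::max(std::min(allow_count, 4), 1);  // thread_m_blocks == 1 branch
  std::printf("blocks per SM: %d\n", allow_count);      // prints 2 with these numbers
  return 0;
}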
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm, void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
void* a_tmp, void* sorted_token_ids, void* expert_ids, void* perm, void* a_tmp, void* sorted_token_ids,
void* num_tokens_past_padded, void* topk_weights, void* expert_ids, void* num_tokens_past_padded,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, void* topk_weights, int moe_block_size, int num_experts,
int prob_m, int prob_n, int prob_k, void* workspace, int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
vllm::ScalarType const& q_type, bool has_bias, int prob_n, int prob_k, void* workspace,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups, vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
int group_size, int dev, cudaStream_t stream, int thread_k, vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce, bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
bool is_zp_float) { int num_groups, int group_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int blocks_per_sm,
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16); int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8; bool m_block_size_8 = moe_block_size == 8;
bool is_a_8bit = a_type.size_bits() == 8;
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
@ -563,14 +372,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
} }
} }
int num_bits = q_type.size_bits(); int num_bits = b_type.size_bits();
const int4* A_ptr = (const int4*)A; const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B; const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp; int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias; const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s; const float* a_s_ptr = (const float*)a_s;
const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const int4* zp_ptr = (const int4*)zp; const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx; const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm; const int* perm_ptr = (const int*)perm;
@ -618,22 +428,41 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0); TORCH_CHECK(max_shared_mem > 0);
int major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
"FP8 only support Ada Lovelace or newer GPUs.");
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
"Marlin W4A16 on other devices).");
}
// Set thread config // Set thread config
exec_config_t exec_cfg; exec_config_t exec_cfg;
thread_config_t thread_tfg; thread_config_t thread_tfg;
if (thread_k != -1 && thread_n != -1) { if (thread_k != -1 && thread_n != -1) {
thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; thread_tfg = thread_config_t{thread_k, thread_n, thread_k * thread_n / 64};
exec_cfg = exec_config_t{1, thread_tfg}; if (blocks_per_sm == -1) blocks_per_sm = 1;
exec_cfg = exec_config_t{blocks_per_sm, thread_tfg};
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n); " is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k); " is not divisible by thread_k = ", thread_k);
} else { } else {
// Auto config // Auto config
exec_cfg = determine_exec_config<scalar_t>( exec_cfg = determine_exec_config(
q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8, a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
max_shared_mem); has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
is_a_8bit);
thread_tfg = exec_cfg.tb_cfg; thread_tfg = exec_cfg.tb_cfg;
} }
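With thread_k/thread_n given explicitly, the thread count is now derived as thread_k * thread_n / 64 rather than a fixed default_threads. A quick sanity sketch showing that this reproduces the thread counts of the THREAD_CONFIGS table used by the kernel generator later in this commit:

static_assert(128 * 128 / 64 == 256, "128x128 tile -> 256 threads");
static_assert(64 * 256 / 64 == 256, "64x256 tile -> 256 threads");
static_assert(64 * 128 / 64 == 128, "64x128 tile -> 128 threads");
static_assert(128 * 64 / 64 == 128, "128x64 tile -> 128 threads");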
@ -647,22 +476,29 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int thread_k_blocks = thread_k / 16; int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16; int thread_n_blocks = thread_n / 16;
TORCH_CHECK( TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, prob_m, prob_n, prob_k, num_bits, group_size,
prob_n, prob_k, num_bits, group_size, has_act_order, has_act_order, is_k_full, has_zp, is_zp_float,
is_k_full, has_zp, is_zp_float, max_shared_mem), max_shared_mem, is_a_8bit),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks, "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k, ", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n, ", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ", ", num_threads = ", thread_tfg.num_threads, " for MKN = [",
prob_k, ", ", prob_n, "] and num_bits = ", num_bits, prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size, ", has_act_order = ", has_act_order, ", group_size = ", group_size,
", is_k_full = ", is_k_full, ", has_zp = ", has_zp, ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem); ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem = ", max_shared_mem);
auto kernel = get_marlin_kernel<scalar_t>( int sh_cache_size =
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
has_act_order, has_zp, group_blocks, num_threads, is_zp_float); prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);
if (kernel == MarlinDefault) { if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
@ -679,19 +515,20 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
// avoid ">>>" being formatted to "> > >" // avoid ">>>" being formatted to "> > >"
// clang-format off // clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>( kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem); prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
// clang-format on // clang-format on
} }
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
torch::Tensor moe_wna16_marlin_gemm( torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none, torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales, std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& a_scales_or_none,
std::optional<torch::Tensor> const& global_scale_or_none, std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
@ -699,11 +536,70 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) { bool is_zp_float, int64_t thread_k, int64_t thread_n,
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); int64_t blocks_per_sm) {
int pack_factor = 32 / b_q_type.size_bits(); vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
auto c_dtype = a.dtype();
if (a.scalar_type() == at::ScalarType::Half) {
a_type_id = vllm::kFloat16.id();
c_type_id = vllm::kFloat16.id();
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
a_type_id = vllm::kBFloat16.id();
c_type_id = vllm::kBFloat16.id();
} else {
c_dtype = b_scales.dtype();
if (b_scales.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
c_type_id = vllm::kBFloat16.id();
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
torch::Tensor c = c_or_none.value();
c_dtype = c.dtype();
if (c.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
TORCH_CHECK(false, "unsupported c dtype");
}
}
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
a_type_id = vllm::kFE4M3fn.id();
} else if (a.scalar_type() == at::ScalarType::Char) {
a_type_id = vllm::kS8.id();
} else {
TORCH_CHECK(false, "unsupported `a` scalar_type");
}
}
s_type_id = c_type_id;
if (b_type_id == vllm::kFE2M1f.id()) {
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
s_type_id = vllm::kFE4M3fn.id();
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
} else {
TORCH_CHECK(false,
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
int pack_factor = 32 / b_type.size_bits();
int num_experts = b_q_weight.size(0);
if (moe_block_size != 8) { if (moe_block_size != 8) {
TORCH_CHECK(moe_block_size % 16 == 0, TORCH_CHECK(moe_block_size % 16 == 0,
@ -745,19 +641,27 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as torch::Tensor a_scales;
// auto -1) auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
int thread_k = -1; auto options_fp32 =
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as torch::TensorOptions().dtype(at::kFloat).device(a.device());
// auto -1)
int thread_n = -1; if (a_scales_or_none.has_value()) {
a_scales = a_scales_or_none.value();
TORCH_CHECK(a_type.size_bits() == 8,
"a_scales can only be used for 8bit activation.");
} else {
a_scales = torch::empty({0}, options_fp32);
TORCH_CHECK(a_type.size_bits() != 8,
"the a_scales parameter must be passed for 8bit activation.");
}
// sms: number of SMs to use for the kernel // sms: number of SMs to use for the kernel
int sms = -1; int sms = -1;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
// Alloc buffers // Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c; torch::Tensor c;
if (c_or_none.has_value()) { if (c_or_none.has_value()) {
c = c_or_none.value(); c = c_or_none.value();
@ -774,8 +678,6 @@ torch::Tensor moe_wna16_marlin_gemm(
// Alloc C tmp buffer that is going to be used for the global reduce // Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp; torch::Tensor c_tmp;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce && !use_atomic_add) { if (use_fp32_reduce && !use_atomic_add) {
// max num of threadblocks is sms * 4 // max num of threadblocks is sms * 4
long max_c_tmp_size = min( long max_c_tmp_size = min(
@ -846,11 +748,11 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor global_scale; torch::Tensor global_scale;
if (global_scale_or_none.has_value()) { if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value(); global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format."); "global_scale can only be used for nvfp4 format.");
} else { } else {
global_scale = torch::empty({0}, options); global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format."); "the global_scale parameter must be passed for nvfp4 format.");
} }
@ -877,15 +779,15 @@ torch::Tensor moe_wna16_marlin_gemm(
bool has_zp = b_zeros.size(-1) > 0; bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) { if (has_zp) {
TORCH_CHECK( TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8, b_type == vllm::kU4 || b_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
} else { } else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, b_type == vllm::kS4 || b_type == vllm::kS8 ||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
"float4_e2m1f when " "b_type must be uint4b8, uint8b128, int4, int8, "
"has_zp = False. Got = ", "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
b_q_type.str()); b_type.str());
} }
if (has_zp && is_zp_float) { if (has_zp && is_zp_float) {
@ -929,71 +831,33 @@ torch::Tensor moe_wna16_marlin_gemm(
" is below min_workspace_size = ", min_workspace_size); " is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device(); int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<half>( TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), "scalar type of a_scales must be float");
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr, TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(), "scalar type of global_scale must be the same with c");
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), if (a_type.size_bits() == 16) {
sorted_token_ids.data_ptr(), expert_ids.data_ptr(), TORCH_CHECK(
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), a.scalar_type() == c.scalar_type(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, "scalar type of a must be the same with c for 16 bit activation");
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false,
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
} }
MARLIN_NAMESPACE_NAME::marlin_mm(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
is_zp_float);
return c; return c;
} }
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm); m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
} }

View File

@ -63,16 +63,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor? b_bias_or_none," "Tensor! b_q_weight, Tensor? b_bias_or_none,"
"Tensor! b_scales, Tensor? global_scale, Tensor? " "Tensor! b_scales, Tensor? a_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none," "b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids," "Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded," "Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, " "Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id," "bool mul_topk_weights, bool is_ep, int b_type_id,"
"int size_m, int size_n, int size_k," "int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add," "bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor"); "bool use_fp32_reduce, bool is_zp_float,"
"int thread_k, int thread_n, int blocks_per_sm) -> Tensor");
m.def( m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "

View File

@ -15,6 +15,8 @@
*/ */
#include <torch/all.h> #include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
@ -32,23 +34,34 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& alpha); torch::Tensor const& alpha);
#endif #endif
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
torch::Tensor const& B, torch::Tensor const& A_sf, const torch::Tensor& B, const torch::Tensor& A_sf,
torch::Tensor const& B_sf, const torch::Tensor& B_sf,
torch::Tensor const& alpha) { const torch::Tensor& alpha) {
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 // Make sure we're on A's device.
return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); const c10::cuda::OptionalCUDAGuard device_guard(device_of(A));
#elif defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 const int32_t sm = get_sm_version_num();
return cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) {
cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 mm kernel, vLLM should " #if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
"be compiled using CUDA 12.8 and target " if (sm >= 120 && sm < 130) {
"compute capability 100 or above."); cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
} }
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) { bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
int runtimeVersion; int runtimeVersion;
cudaRuntimeGetVersion(&runtimeVersion); cudaRuntimeGetVersion(&runtimeVersion);
return cuda_device_capability >= 100 && runtimeVersion >= 12080; return cuda_device_capability >= 100 && runtimeVersion >= 12080;
} }

View File

@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
#pragma unroll #pragma unroll
for (int k_idx = 0; k_idx < 2; ++k_idx) { for (int k_idx = 0; k_idx < 2; ++k_idx) {
FType low16 = FType low16 = MarlinScalarType2<FType>::float2num(
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]); C_frag[m_idx][n_idx][k_idx * 2]);
FType high16 = FType high16 = MarlinScalarType2<FType>::float2num(
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]); C_frag[m_idx][n_idx][k_idx * 2 + 1]);
uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) | uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
(reinterpret_cast<uint32_t&>(high16) << 16); (reinterpret_cast<uint32_t&>(high16) << 16);
int sts_offset = int sts_offset =

View File

@ -8,7 +8,7 @@
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <iostream> #include <iostream>
#include "../gptq_marlin/marlin_dtypes.cuh" #include "../gptq_marlin/marlin_dtypes.cuh"
using marlin::ScalarType; using marlin::MarlinScalarType2;
namespace allspark { namespace allspark {
@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix; int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
for (int i = 0; i < n_mat; ++i) { for (int i = 0; i < n_mat; ++i) {
sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]); sum += MarlinScalarType2<FType>::num2float(C_split[idx + i * matrix_size]);
} }
C[idx] = ScalarType<FType>::float2num(sum); C[idx] = MarlinScalarType2<FType>::float2num(sum);
} }
template <typename FType> template <typename FType>

View File

@ -1 +1,2 @@
kernel_*.cu sm*_kernel_*.cu
kernel_selector.h

View File

@ -4,14 +4,16 @@
namespace marlin { namespace marlin {
template <int const num_threads, int const num_bits> template <int const num_threads, int const num_bits, bool is_a_8bit>
__global__ void awq_marlin_repack_kernel( __global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) { int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits; constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
int n_tiles = size_n / tile_n_size; constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
int k_tiles = size_k / target_tile_k_size;
int n_tiles = size_n / target_tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x); int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles; auto start_k_tile = blockIdx.x * block_k_tiles;
@ -33,10 +35,10 @@ __global__ void awq_marlin_repack_kernel(
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
constexpr int tile_n_ints = tile_n_size / pack_factor; constexpr int tile_n_ints = target_tile_n_size / pack_factor;
constexpr int stage_n_threads = tile_n_ints / 4; constexpr int stage_n_threads = tile_n_ints / 4;
constexpr int stage_k_threads = tile_k_size; constexpr int stage_k_threads = target_tile_k_size;
constexpr int stage_size = stage_k_threads * stage_n_threads; constexpr int stage_size = stage_k_threads * stage_n_threads;
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
@ -45,7 +47,7 @@ __global__ void awq_marlin_repack_kernel(
return; return;
} }
int first_n = n_tile_id * tile_n_size; int first_n = n_tile_id * target_tile_n_size;
int first_n_packed = first_n / pack_factor; int first_n_packed = first_n / pack_factor;
int4* sh_ptr = sh + stage_size * pipe; int4* sh_ptr = sh + stage_size * pipe;
@ -54,7 +56,7 @@ __global__ void awq_marlin_repack_kernel(
auto k_id = threadIdx.x / stage_n_threads; auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads; auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size; int first_k = k_tile_id * target_tile_k_size;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>( reinterpret_cast<int4 const*>(
@ -78,11 +80,11 @@ __global__ void awq_marlin_repack_kernel(
} }
int tc_col = th_id / 4; int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2; int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
constexpr int tc_offsets[4] = {0, 1, 8, 9}; constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col; int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
int cur_n_packed = cur_n / pack_factor; int cur_n_packed = cur_n / pack_factor;
int cur_n_pos = cur_n % pack_factor; int cur_n_pos = cur_n % pack_factor;
@ -105,23 +107,50 @@ __global__ void awq_marlin_repack_kernel(
uint32_t vals[8]; uint32_t vals[8];
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i]; if constexpr (is_a_8bit) {
int cur_elem = tc_row + i;
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; int packed_src_0 =
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
sh_stride * cur_elem]; sh_stride * cur_elem];
int packed_src_1 =
sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
sh_stride * (cur_elem + 16)];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
} else {
int cur_elem = tc_row + tc_offsets[i];
int packed_src_0 =
sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
sh_stride * cur_elem];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
}
} }
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; constexpr int tile_size =
target_tile_k_size * target_tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of: // Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) { if constexpr (!is_a_8bit && num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else if constexpr (is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
@ -138,8 +167,9 @@ __global__ void awq_marlin_repack_kernel(
uint32_t res2 = 0; uint32_t res2 = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8); const int ii = is_a_8bit ? i : pack_idx[i];
res2 |= vals[4 + pack_idx[i]] << (i * 8); res1 |= vals[ii] << (i * 8);
res2 |= vals[4 + ii] << (i * 8);
} }
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
@ -176,18 +206,21 @@ __global__ void awq_marlin_repack_kernel(
} // namespace marlin } // namespace marlin
#define CALL_IF(NUM_BITS) \ #define CALL_IF(NUM_BITS, IS_A_8BIT) \
else if (num_bits == NUM_BITS) { \ else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \ marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ IS_A_8BIT>, \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \ marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
b_q_weight_ptr, out_ptr, size_k, size_n); \ IS_A_8BIT> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) { int64_t size_n, int64_t num_bits,
bool is_a_8bit) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size); " is not divisible by tile_k_size = ", marlin::tile_k_size);
@ -238,10 +271,13 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
if (false) { if (false) {
} }
CALL_IF(4) CALL_IF(4, false)
CALL_IF(8) CALL_IF(8, false)
CALL_IF(4, true)
CALL_IF(8, true)
else { else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", is_a_8bit = ", is_a_8bit);
} }
return out; return out;
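The is_a_8bit path of awq_marlin_repack reshapes the repack tile rather than resizing it: relative to the 16x64 Marlin tile the compatibility checks above refer to, k doubles and n halves, so each tile still packs the same number of weights (presumably to line up with the k=32 fragments of the 8-bit MMA path; that motivation is an inference, not stated in the diff). A compile-time sketch:

constexpr int tile_k_size = 16;  // per the "marlin tile of 16x64" check above
constexpr int tile_n_size = 64;
template <bool is_a_8bit>
struct repack_tile {
  static constexpr int k = tile_k_size * (is_a_8bit ? 2 : 1);
  static constexpr int n = tile_n_size / (is_a_8bit ? 2 : 1);
};
static_assert(repack_tile<true>::k == 32 && repack_tile<true>::n == 32,
              "8-bit-A repack tile is 32x32");
static_assert(repack_tile<true>::k * repack_tile<true>::n ==
                  repack_tile<false>::k * repack_tile<false>::n,
              "same number of packed weights per tile");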

View File

@ -470,6 +470,50 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
frag_b[0] = __hmul2(frag_b[0], bias_reg); frag_b[0] = __hmul2(frag_b[0], bias_reg);
} }
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>(
int q, __nv_fp8x4_e4m3* frag_b) {
// Constants for FP4 (E2M1) and FP8 (E4M3) formats
constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4;
constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70707070;
// Extract and shift FP4 values to FP8 format
int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
// Note1: reverse indexing is intentional because weights are permuted
// Note2: when dequant to 8bit type, we write to `frag_b[2]` instead of
// `frag_b[1]` to fit the layout of tensorcore
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
template <>
__device__ inline void dequant<int32_t, vllm::kU4B8.id(), true>(
int q, int32_t* frag_b) {
constexpr int repeated_zp = 0x08080808;
constexpr int MASK = 0x80808080;
frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
q >>= 4;
frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
}
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>(
int q, __nv_fp8x4_e4m3* frag_b) {
int s = q & 0x08080808;
int Out1 = ((q & 0x07070707) | (s << 4)) + (s >> 3);
q >>= 4;
s = q & 0x08080808;
int Out2 = ((q & 0x07070707) | (s << 4)) + (s >> 3);
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
template <typename scalar_t2, vllm::ScalarTypeId s_type_id> template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
@ -515,6 +559,49 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
// Note: reverse indexing is intentional because weights are permuted // Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1); frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2); frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
};
// subtract the zero point in quantized format and then dequantize
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
bool skip_flop = false>
__device__ inline void sub_zp_and_dequant(int q, scalar_t2* frag_b, int zp);
template <>
__device__ inline void sub_zp_and_dequant<int32_t, vllm::kU4.id(), true>(
int q, int32_t* frag_b, int zp) {
// INT4 with zp -> INT8
// see https://github.com/vllm-project/vllm/pull/24722
int repeated_zp = 0x01010101 * zp;
int MASK = 0x80808080;
frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
q >>= 4;
frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
}
template <>
__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(),
true>(int q, __nv_fp8x4_e4m3* frag_b,
int zp) {
// INT4 with zp -> FP8
// see https://github.com/vllm-project/vllm/pull/24722
uint32_t u_q = *reinterpret_cast<uint32_t*>(&q);
uint32_t u_zp = *reinterpret_cast<uint32_t*>(&zp);
uint32_t u_zp1 = u_zp + 1;
uint32_t repeated_zp = 0x01010101 * u_zp;
uint32_t q0, s;
q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
s = (q0 + repeated_zp) & 0x80808080;
uint32_t Out1 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
u_q >>= 4;
q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
s = (q0 + repeated_zp) & 0x80808080;
uint32_t Out2 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
} }
#endif #endif
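Both int4 dequant variants added above use the same SWAR trick: OR-ing bit 7 into every byte before the packed subtraction keeps borrows from crossing byte lanes, and XOR-ing bit 7 afterwards restores the signed two's-complement result. A standalone host-side check of the fixed zero-point case (zp = 8, as in the kU4B8 specialization); this sketch is for illustration and is not part of the diff:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t n = 0; n < 16; ++n) {
    uint32_t q = n * 0x11111111u;  // the same nibble replicated in every 4-bit slot
    uint32_t r = ((q & 0x0F0F0F0Fu) | 0x80808080u) - 0x08080808u;
    r ^= 0x80808080u;
    for (int b = 0; b < 4; ++b)  // every byte now holds (n - 8) as a signed int8
      assert(static_cast<int8_t>((r >> (8 * b)) & 0xFFu) == static_cast<int>(n) - 8);
  }
  return 0;
}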

View File

@ -4,141 +4,292 @@ import glob
import itertools import itertools
import os import os
import subprocess import subprocess
import sys
import jinja2 import jinja2
FILE_HEAD = """ ARCHS = []
// auto generated by generate.py SUPPORT_FP8 = False
// clang-format off for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it's simulated
# with FP16 MMA, so it provides no acceleration there.
if arch in [89, 120]:
SUPPORT_FP8 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off
""".lstrip()
FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
""".strip() """
)
TEMPLATE = ( TEMPLATE = (
"template __global__ void Marlin<" "template __global__ void Marlin<"
"{{scalar_t}}, " "{{a_type_id}}, "
"{{w_type_id}}, " "{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, " "{{s_type_id}}, "
"{{threads}}, " "{{threads}}, "
"{{thread_m_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_n_blocks}}, " "{{thread_n_blocks}}, "
"{{thread_k_blocks}}, " "{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{m_block_size_8}}, "
"{{stages}}, " "{{stages}}, "
"{{group_blocks}}, " "{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );" "( MARLIN_KERNEL_PARAMS );"
) )
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case QUANT_CONFIGS = [
# = -1 : channelwise quantization # AWQ-INT4
# > 0 : group_size=16*group_blocks {
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8] "b_type": "kU4",
DTYPES = ["fp16", "bf16"] "thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# HQQ
{
"a_type": ["kFloat16"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [4],
"is_zp_float": True,
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
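For reference, the group_blocks values above follow the existing convention (0 = act-order, -1 = channelwise, otherwise group_size = 16 * group_blocks); a small illustrative helper spelling this out:

def describe_group_blocks(group_blocks: int) -> str:
    # 0  -> activation-order (g_idx) quantization
    # -1 -> channelwise quantization
    # n  -> group_size of 16 * n (e.g. 8 -> 128, 1 -> 16 for NVFP4, 2 -> 32 for MXFP4)
    if group_blocks == 0:
        return "act-order"
    if group_blocks == -1:
        return "channelwise"
    return f"group_size={16 * group_blocks}"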
def remove_old_kernels(): def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename]) subprocess.call(["rm", "-f", filename])
filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename])
def generate_new_kernels(): def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
b_type = quant_config["b_type"]
is_zp_float = quant_config.get("is_zp_float", False)
all_group_blocks = quant_config["group_blocks"]
all_m_blocks = quant_config["thread_m_blocks"]
all_thread_configs = quant_config["thread_configs"]
for a_type, c_type in itertools.product(a_types, c_types):
if not SUPPORT_FP8 and a_type == "kFE4M3fn":
continue
if "16" in a_type and "16" in c_type and a_type != c_type:
continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
):
thread_k, thread_n, threads = thread_configs
if threads == 256:
# for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue
if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
continue
config = {
"threads": threads,
"s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "true" if is_zp_float else "false",
}
result_dict[(a_type, b_type, c_type)].append(config)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = [] all_template_str_list = []
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)
for group_blocks, m_blocks, thread_configs in itertools.product( conditions = [
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS f"a_type == vllm::{a_type}",
): f"b_type == vllm::{b_type}",
# act order case only support gptq-int4 and gptq-int8 f"c_type == vllm::{c_type}",
if group_blocks == 0 and scalar_type not in [ f"s_type == vllm::{s_type}",
"vllm::kU4B8", f"threads == {config['threads']}",
"vllm::kU8B128", f"thread_m_blocks == {config['thread_m_blocks']}",
]: f"thread_n_blocks == {config['thread_n_blocks']}",
continue f"thread_k_blocks == {config['thread_k_blocks']}",
if thread_configs[2] == 256: f"m_block_size_8 == {config['m_block_size_8']}",
# for small batch (m_blocks == 1), we only need (128, 128, 256) f"group_blocks == {config['group_blocks']}",
# for large batch (m_blocks > 1), we only need (64, 256, 256) f"is_zp_float == {config['is_zp_float']}",
if m_blocks <= 1 and thread_configs[0] != 128: ]
continue conditions = " && ".join(conditions)
if m_blocks > 1 and thread_configs[0] != 64:
continue
# we only support channelwise quantization and group_size == 128 if kernel_selector_str == FILE_HEAD_COMMENT:
# for fp8 kernel_selector_str += f"if ({conditions})\n kernel = "
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: else:
continue kernel_selector_str += f"else if ({conditions})\n kernel = "
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue
k_blocks = thread_configs[0] // 16 kernel_template2 = (
n_blocks = thread_configs[1] // 16 "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
threads = thread_configs[2] "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" kernel_selector_str += (
jinja2.Template(kernel_template2).render(
is_zp_float_list = [False] a_type_id=f"vllm::{a_type}.id()",
if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4: b_type_id=f"vllm::{b_type}.id()",
# HQQ (is_zp_float = true) only supports c_type_id=f"vllm::{c_type}.id()",
# 4bit quantization and fp16 s_type_id=f"vllm::{s_type}.id()",
is_zp_float_list.append(True) **config,
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
for is_zp_float in is_zp_float_list:
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
s_type_id=s_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
group_blocks=group_blocks,
is_zp_float=is_zp_float,
) )
+ "\n"
all_template_str_list.append(template_str) )
file_content = FILE_HEAD + "\n\n" file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content) f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)
if __name__ == "__main__": if __name__ == "__main__":
remove_old_kernels() remove_old_kernels()
View File
@ -53,7 +53,7 @@ torch::Tensor gptq_marlin_gemm(
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace, std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) { bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false, TORCH_CHECK_NOT_IMPLEMENTED(false,
@ -243,204 +243,29 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float); has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size + 512 <= max_shared_mem; return cache_size <= max_shared_mem;
} }
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ MarlinFuncPtr get_marlin_kernel(
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ const vllm::ScalarType a_type, const vllm::ScalarType b_type,
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ const vllm::ScalarType c_type, const vllm::ScalarType s_type,
thread_n_blocks == THREAD_N_BLOCKS && \ int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
thread_k_blocks == THREAD_K_BLOCKS && \ bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
m_block_size_8 == M_BLOCK_SIZE_8 && \ int threads, bool is_zp_float) {
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ int num_bits = b_type.size_bits();
is_zp_float == IS_ZP_FLOAT) { \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 4, 8, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 4, 8, 128)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
int thread_m_blocks, int thread_n_blocks,
int thread_k_blocks, bool m_block_size_8,
bool has_act_order, bool has_zp,
int group_blocks, int num_threads,
bool is_zp_float) {
int num_bits = q_type.size_bits();
auto kernel = MarlinDefault; auto kernel = MarlinDefault;
if (false) {
}
COMMON_GET_IF(vllm::kU4) #include "kernel_selector.h"
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
NVFP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, half>::value) {
if (false) {
}
FZP_GET_IF(vllm::kU4)
}
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel; return kernel;
} }
template <typename scalar_t> exec_config_t determine_exec_config(
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
int prob_n, int prob_k, int thread_m_blocks, const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
bool m_block_size_8, int num_bits, int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
int group_size, bool has_act_order, int num_bits, int group_size, bool has_act_order, bool is_k_full,
bool is_k_full, bool has_zp, bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
bool is_zp_float, int max_shared_mem,
int sms) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1 thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs ? large_batch_thread_configs
@ -455,7 +280,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, num_bits, group_size, has_act_order, is_k_full, has_zp,
is_zp_float, max_shared_mem)) { is_zp_float, max_shared_mem - 512)) {
continue; continue;
} }
@ -468,10 +293,11 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
group_blocks = group_size == -1 ? -1 : group_size / 16; group_blocks = group_size == -1 ? -1 : group_size / 16;
} }
auto kernel = get_marlin_kernel<scalar_t>( auto kernel =
q_type, thread_m_blocks, th_config.thread_n / 16, get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, th_config.thread_n / 16, th_config.thread_k / 16,
group_blocks, th_config.num_threads, is_zp_float); m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
if (kernel == MarlinDefault) continue; if (kernel == MarlinDefault) continue;
@ -485,28 +311,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
return exec_cfg; return exec_cfg;
} }
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm, void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
void* a_tmp, int prob_m, int prob_n, int prob_k, int lda, void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k,
void* workspace, vllm::ScalarType const& q_type, bool has_bias, int lda, void* workspace, vllm::ScalarType const& a_type,
vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
vllm::ScalarType const& s_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups, bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k_init, int group_size, int dev, cudaStream_t stream, int thread_k_init,
int thread_n_init, int sms, bool use_atomic_add, int thread_n_init, int sms, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) { bool use_fp32_reduce, bool is_zp_float) {
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
@ -531,19 +345,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
} }
} }
int num_bits = q_type.size_bits(); int num_bits = b_type.size_bits();
const int4* A_ptr = (const int4*)A; const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B; const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp; int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias; const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s; const float* a_s_ptr = (const float*)a_s;
const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const int4* zp_ptr = (const int4*)zp; const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx; const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm; const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp; int4* a_tmp_ptr = (int4*)a_tmp;
int* locks = (int*)workspace; int* locks = (int*)workspace;
if (has_act_order) { if (has_act_order) {
@ -568,6 +384,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0); TORCH_CHECK(max_shared_mem > 0);
int major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
"Marlin W4A16 on other devices).");
}
int max_par = 16; int max_par = 16;
if (prob_n <= 4096) max_par = 16 * 8; if (prob_n <= 4096) max_par = 16 * 8;
int max_shared_mem_new = max_shared_mem; int max_shared_mem_new = max_shared_mem;
@ -583,7 +414,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int thread_n = thread_n_init; int thread_n = thread_n_init;
int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks);
int m_block_size_8 = prob_m_split <= 8; int m_block_size_8 = prob_m_split <= 8 && a_type.size_bits() == 16;
// Set thread config // Set thread config
exec_config_t exec_cfg; exec_config_t exec_cfg;
@ -597,11 +428,25 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
" is not divisible by thread_k = ", thread_k); " is not divisible by thread_k = ", thread_k);
} else { } else {
// Auto config // Auto config
exec_cfg = determine_exec_config<scalar_t>( exec_cfg = determine_exec_config(
q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8, a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
max_shared_mem, sms); is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
thread_tfg = exec_cfg.tb_cfg; thread_tfg = exec_cfg.tb_cfg;
if (thread_tfg.thread_n != -1) {
if (prob_n / thread_tfg.thread_n *
div_ceil(prob_m_split, thread_m_blocks * 16) * 4 <=
sms) {
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem_new)) {
thread_tfg = {128, 64, 128};
exec_cfg = {1, thread_tfg};
}
}
}
if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) {
max_thread_m_blocks--; max_thread_m_blocks--;
continue; continue;
@ -632,10 +477,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem_new = ", max_shared_mem_new); ", max_shared_mem_new = ", max_shared_mem_new);
auto kernel = get_marlin_kernel<scalar_t>( auto kernel = get_marlin_kernel(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
m_block_size_8, has_act_order, has_zp, group_blocks, num_threads, thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
is_zp_float); num_threads, is_zp_float);
if (kernel == MarlinDefault) { if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
@ -657,13 +502,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
// avoid ">>>" being formatted to "> > >" // avoid ">>>" being formatted to "> > >"
// clang-format off // clang-format off
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>( kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr,
g_idx_ptr, num_groups, g_idx_ptr, num_groups,
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add, prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
use_fp32_reduce, max_shared_mem_new); use_fp32_reduce, max_shared_mem_new);
// clang-format on // clang-format on
A_ptr += prob_m_split * (lda / 8); bool is_a_8bit = a_type.size_bits() == 8;
A_ptr += prob_m_split * (lda / (is_a_8bit ? 16 : 8));
a_s_ptr += prob_m_split;
C_ptr += prob_m_split * (prob_n / 8); C_ptr += prob_m_split * (prob_n / 8);
rest_m -= prob_m_split; rest_m -= prob_m_split;
} }
@ -675,15 +522,73 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none, torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales, std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& a_scales_or_none,
std::optional<torch::Tensor> const& global_scale_or_none, std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace, std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) { bool is_zp_float) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
int pack_factor = 32 / b_q_type.size_bits();
auto c_dtype = a.dtype();
if (a.scalar_type() == at::ScalarType::Half) {
a_type_id = vllm::kFloat16.id();
c_type_id = vllm::kFloat16.id();
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
a_type_id = vllm::kBFloat16.id();
c_type_id = vllm::kBFloat16.id();
} else {
c_dtype = b_scales.dtype();
if (b_scales.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
c_type_id = vllm::kBFloat16.id();
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
torch::Tensor c = c_or_none.value();
c_dtype = c.dtype();
if (c.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
TORCH_CHECK(false, "unsupported c dtype");
}
}
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
a_type_id = vllm::kFE4M3fn.id();
} else if (a.scalar_type() == at::ScalarType::Char) {
a_type_id = vllm::kS8.id();
} else {
TORCH_CHECK(false, "unsupported `a` scalar_type");
}
}
s_type_id = c_type_id;
if (b_type_id == vllm::kFE2M1f.id()) {
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
s_type_id = vllm::kFE4M3fn.id();
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
} else {
TORCH_CHECK(false,
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
int pack_factor = 32 / b_type.size_bits();
// Verify A // Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@ -721,6 +626,21 @@ torch::Tensor gptq_marlin_gemm(
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
torch::Tensor a_scales;
auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (a_scales_or_none.has_value()) {
a_scales = a_scales_or_none.value();
TORCH_CHECK(a_type.size_bits() == 8,
"a_scales can only be used for 8bit activation.");
} else {
a_scales = torch::empty({0}, options_fp32);
TORCH_CHECK(a_type.size_bits() != 8,
"the a_scales parameter must be passed for 8bit activation.");
}
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1) // auto -1)
int thread_k = -1; int thread_k = -1;
@ -733,7 +653,6 @@ torch::Tensor gptq_marlin_gemm(
// Alloc buffers // Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c; torch::Tensor c;
if (c_or_none.has_value()) { if (c_or_none.has_value()) {
c = c_or_none.value(); c = c_or_none.value();
@ -750,8 +669,6 @@ torch::Tensor gptq_marlin_gemm(
// Alloc C tmp buffer that is going to be used for the global reduce // Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp; torch::Tensor c_tmp;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce) { if (use_fp32_reduce) {
int max_m_block_size = (size_m + 16 - 1) / 16 * 16; int max_m_block_size = (size_m + 16 - 1) / 16 * 16;
max_m_block_size = min(max_m_block_size, 64); max_m_block_size = min(max_m_block_size, 64);
@ -821,11 +738,11 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor global_scale; torch::Tensor global_scale;
if (global_scale_or_none.has_value()) { if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value(); global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format."); "global_scale can only be used for nvfp4 format.");
} else { } else {
global_scale = torch::empty({0}, options); global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format."); "the global_scale parameter must be passed for nvfp4 format.");
} }
@ -852,15 +769,15 @@ torch::Tensor gptq_marlin_gemm(
bool has_zp = b_zeros.size(-1) > 0; bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) { if (has_zp) {
TORCH_CHECK( TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8, b_type == vllm::kU4 || b_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
} else { } else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, b_type == vllm::kS4 || b_type == vllm::kS8 ||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
"float4_e2m1f when " "b_type must be uint4b8, uint8b128, int4, int8, "
"has_zp = False. Got = ", "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
b_q_type.str()); b_type.str());
} }
if (has_zp && is_zp_float) { if (has_zp && is_zp_float) {
@ -902,59 +819,27 @@ torch::Tensor gptq_marlin_gemm(
" is below min_workspace_size = ", min_workspace_size); " is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device(); int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
marlin::marlin_mm<half>( TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), "scalar type of a_scales must be float");
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr, TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(), "scalar type of global_scale must be the same with c");
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, if (a_type.size_bits() == 16) {
a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order, TORCH_CHECK(
is_k_full, has_zp, num_groups, group_size, dev, a.scalar_type() == c.scalar_type(),
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, "scalar type of a must be the same with c for 16 bit activation");
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
} }
marlin::marlin_mm(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0),
workspace.data_ptr(), a_type, b_type, c_type, s_type, has_bias,
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
return c; return c;
} }
View File
@ -4,15 +4,18 @@
namespace marlin { namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm,
bool is_a_8bit>
__global__ void gptq_marlin_repack_kernel( __global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) { int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits; constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
int n_tiles = size_n / tile_n_size; constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
int k_tiles = size_k / target_tile_k_size;
int n_tiles = size_n / target_tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x); int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles; auto start_k_tile = blockIdx.x * block_k_tiles;
@ -34,7 +37,7 @@ __global__ void gptq_marlin_repack_kernel(
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
constexpr int perm_size = tile_k_size / 4; constexpr int perm_size = target_tile_k_size / 4;
int4* sh_perm_ptr = sh; int4* sh_perm_ptr = sh;
int4* sh_pipe_ptr = sh_perm_ptr; int4* sh_pipe_ptr = sh_perm_ptr;
@ -42,14 +45,14 @@ __global__ void gptq_marlin_repack_kernel(
sh_pipe_ptr += perm_size; sh_pipe_ptr += perm_size;
} }
constexpr int tile_ints = tile_k_size / pack_factor; constexpr int tile_ints = target_tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4; constexpr int stage_n_threads = target_tile_n_size / 4;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; constexpr int stage_k_threads = has_perm ? target_tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads; constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) { auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4; int first_k_int4 = (k_tile_id * target_tile_k_size) / 4;
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr); int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
@ -65,7 +68,7 @@ __global__ void gptq_marlin_repack_kernel(
return; return;
} }
int first_n = n_tile_id * tile_n_size; int first_n = n_tile_id * target_tile_n_size;
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
@ -91,7 +94,7 @@ __global__ void gptq_marlin_repack_kernel(
auto k_id = threadIdx.x / stage_n_threads; auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads; auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size; int first_k = k_tile_id * target_tile_k_size;
int first_k_packed = first_k / pack_factor; int first_k_packed = first_k / pack_factor;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
@ -117,13 +120,13 @@ __global__ void gptq_marlin_repack_kernel(
} }
int tc_col = th_id / 4; int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2; int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
constexpr int tc_offsets[4] = {0, 1, 8, 9}; constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col; int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
constexpr int sh_stride = 64; constexpr int sh_stride = target_tile_n_size;
constexpr uint32_t mask = (1 << num_bits) - 1; constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
@ -134,6 +137,7 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t vals[8]; uint32_t vals[8];
if constexpr (has_perm) { if constexpr (has_perm) {
static_assert(!is_a_8bit);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i]; int k_idx = tc_row + tc_offsets[i];
@ -156,28 +160,49 @@ __global__ void gptq_marlin_repack_kernel(
#pragma unroll #pragma unroll
for (int i = 0; i < tile_ints; i++) { for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; if constexpr (is_a_8bit) {
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; b1_vals[i] =
sh_stage_int_ptr[cur_n + sh_stride * i + (warp_id % 2) * 8];
} else {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
} }
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i]; int cur_elem = tc_row + (is_a_8bit ? i : tc_offsets[i]);
int cur_int = cur_elem / pack_factor; int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor; int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; if constexpr (is_a_8bit)
vals[4 + i] =
(b1_vals[cur_int + tile_ints / 2] >> (cur_pos * num_bits)) & mask;
else
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
} }
} }
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; constexpr int tile_size =
target_tile_k_size * target_tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of: // Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) { if constexpr (!is_a_8bit && num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else if constexpr (is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
@ -194,8 +219,9 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t res2 = 0; uint32_t res2 = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8); const int ii = is_a_8bit ? i : pack_idx[i];
res2 |= vals[4 + pack_idx[i]] << (i * 8); res1 |= vals[ii] << (i * 8);
res2 |= vals[4 + ii] << (i * 8);
} }
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
@ -236,21 +262,22 @@ __global__ void gptq_marlin_repack_kernel(
} // namespace marlin } // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \ #define CALL_IF(NUM_BITS, HAS_PERM, IS_A_8BIT) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ else if (num_bits == NUM_BITS && has_perm == HAS_PERM && \
is_a_8bit == IS_A_8BIT) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \ HAS_PERM, IS_A_8BIT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \ HAS_PERM, IS_A_8BIT> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \ <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits) { int64_t num_bits, bool is_a_8bit) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size); " is not divisible by tile_k_size = ", marlin::tile_k_size);
@ -309,13 +336,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
if (false) { if (false) {
} }
CALL_IF(4, false) CALL_IF(4, false, false)
CALL_IF(4, true) CALL_IF(4, true, false)
CALL_IF(8, false) CALL_IF(8, false, false)
CALL_IF(8, true) CALL_IF(8, true, false)
CALL_IF(4, false, true)
CALL_IF(8, false, true)
else { else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm); ", has_perm = ", has_perm, ", is_a_8bit = ", is_a_8bit);
} }
return out; return out;
View File
@ -11,17 +11,19 @@
const int4 *__restrict__ A, const int4 *__restrict__ B, \ const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \ const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \ const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \ const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
int max_shared_mem int max_shared_mem
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16 template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId s_type_id, // weight ScalarType id const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
View File
@ -55,6 +55,45 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
// No support for async // No support for async
#else #else
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 4;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 8;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) { bool pred = true) {
const int BYTES = 16; const int BYTES = 16;
View File
@ -2,8 +2,10 @@
#ifndef _data_types_cuh #ifndef _data_types_cuh
#define _data_types_cuh #define _data_types_cuh
#include "marlin.cuh" #include "marlin.cuh"
#include "core/scalar_type.hpp"
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp8.h>
#ifndef MARLIN_NAMESPACE_NAME #ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin #define MARLIN_NAMESPACE_NAME marlin
@ -11,14 +13,16 @@
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t> template <long scalar_type_id>
class ScalarType {}; class MarlinScalarType {};
template <> template <>
class ScalarType<half> { class MarlinScalarType<vllm::kFloat16.id()> {
public: public:
using scalar_t = half; using scalar_t = half;
using scalar_t2 = half2; using scalar_t2 = half2;
using scalar_t4 = half2;
using scalar_32bit_t = half2;
// Matrix fragments for tensor core instructions; their precise layout is // Matrix fragments for tensor core instructions; their precise layout is
// documented here: // documented here:
@ -27,6 +31,7 @@ class ScalarType<half> {
using FragB = Vec<half2, 2>; using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; using FragS = Vec<half2, 1>;
using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;
using FragZP = Vec<half2, 4>; using FragZP = Vec<half2, 4>;
static __device__ float inline num2float(const half x) { static __device__ float inline num2float(const half x) {
@ -44,18 +49,25 @@ class ScalarType<half> {
static __host__ __device__ half inline float2num(const float x) { static __host__ __device__ half inline float2num(const float x) {
return __float2half(x); return __float2half(x);
} }
static __host__ __device__ float2 inline num22float2(const half2 x) {
return __half22float2(x);
}
}; };
template <> template <>
class ScalarType<nv_bfloat16> { class MarlinScalarType<vllm::kBFloat16.id()> {
public: public:
using scalar_t = nv_bfloat16; using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162; using scalar_t2 = nv_bfloat162;
using scalar_t4 = nv_bfloat162;
using scalar_32bit_t = nv_bfloat162;
using FragA = Vec<nv_bfloat162, 4>; using FragA = Vec<nv_bfloat162, 4>;
using FragB = Vec<nv_bfloat162, 2>; using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>; using FragS = Vec<nv_bfloat162, 1>;
using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;
using FragZP = Vec<nv_bfloat162, 4>; using FragZP = Vec<nv_bfloat162, 4>;
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
@ -75,9 +87,63 @@ class ScalarType<nv_bfloat16> {
static __host__ __device__ nv_bfloat16 inline float2num(const float x) { static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x); return __float2bfloat16(x);
} }
static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
return __bfloat1622float2(x);
}
#endif #endif
}; };
template <>
class MarlinScalarType<vllm::kFE4M3fn.id()> {
public:
using scalar_t = __nv_fp8_e4m3;
using scalar_t2 = __nv_fp8x2_e4m3;
using scalar_t4 = __nv_fp8x4_e4m3;
using scalar_32bit_t = __nv_fp8x4_e4m3;
using FragA = Vec<__nv_fp8x4_e4m3, 4>;
using FragB = Vec<__nv_fp8x4_e4m3, 2>;
using FragC = Vec<float, 4>;
using FragZP = Vec<__nv_fp8x2_e4m3, 4>;
static __host__ __device__
float2 inline num22float2(const __nv_fp8x2_e4m3 x) {
return (float2)x;
}
};
template <>
class MarlinScalarType<vllm::kS8.id()> {
public:
using scalar_t = int8_t;
using scalar_t2 = int16_t;
using scalar_t4 = int32_t;
using scalar_32bit_t = int32_t;
using FragA = Vec<int32_t, 4>;
using FragB = Vec<int32_t, 2>;
using FragC = Vec<float, 4>;
using FragZP = Vec<int16_t, 4>;
};
template <typename scalar_t>
class MarlinScalarType2 {};
template <>
class MarlinScalarType2<half> : public MarlinScalarType<vllm::kFloat16.id()> {};
template <>
class MarlinScalarType2<nv_bfloat16>
: public MarlinScalarType<vllm::kBFloat16.id()> {};
template <>
class MarlinScalarType2<__nv_fp8_e4m3>
: public MarlinScalarType<vllm::kFE4M3fn.id()> {};
template <>
class MarlinScalarType2<int8_t> : public MarlinScalarType<vllm::kS8.id()> {};
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
#endif #endif
View File
@ -0,0 +1,106 @@
#include "marlin.cuh"
#include "core/registration.h"
// for non-zero-point formats only (e.g. GPTQ)
__global__ void marlin_int4_fp8_preprocess_kernel_without_zp(
// qweight: (size_k * size_n // 8,)
const int32_t* __restrict__ qweight,
// output: same shape as qweight
int32_t* __restrict__ output) {
int32_t val = qweight[blockIdx.x * 32 + threadIdx.x];
int32_t new_val = 0;
#pragma unroll
for (int32_t i = 0; i < 8; i++) {
int32_t single_val = val & 0xF;
single_val = single_val >= 8 ? single_val - 8 : 15 - single_val;
new_val |= single_val << (i * 4);
val >>= 4;
}
output[blockIdx.x * 32 + threadIdx.x] = new_val;
}
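A quick Python check (illustrative only) of the per-nibble remap performed by the kernel above for the no-zero-point layout:

def remap(v: int) -> int:
    # mirrors: single_val >= 8 ? single_val - 8 : 15 - single_val
    return v - 8 if v >= 8 else 15 - v

assert [remap(v) for v in range(16)] == [15, 14, 13, 12, 11, 10, 9, 8,
                                         0, 1, 2, 3, 4, 5, 6, 7]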
// for awq format only (with zp and with awq weight layout)
__global__ void marlin_int4_fp8_preprocess_kernel_awq(
// AWQ qweight: (size_k, size_n // 8)
const int32_t* __restrict__ qweight,
// output: same shape as qweight
int32_t* __restrict__ output,
// AWQ zeros: (size_k // group_size, size_n // 8)
const int32_t* __restrict__ qzeros, int32_t size_n, int32_t size_k,
int32_t group_size) {
int32_t val =
qweight[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y];
int32_t zero =
qzeros[(blockIdx.x * 32 + threadIdx.x) / group_size * size_n / 8 +
blockIdx.y];
int32_t new_val = 0;
#pragma unroll
for (int32_t i = 0; i < 8; i++) {
int32_t single_val = val & 0xF;
int32_t single_zero = zero & 0xF;
single_val =
single_val >= single_zero ? single_val - single_zero : 15 - single_val;
new_val |= single_val << (i * 4);
val >>= 4;
zero >>= 4;
}
output[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y] = new_val;
}
torch::Tensor marlin_int4_fp8_preprocess(
torch::Tensor& qweight, std::optional<torch::Tensor> qzeros_or_none,
bool inplace) {
TORCH_CHECK(qweight.device().is_cuda(), "qweight is not on GPU");
TORCH_CHECK(qweight.scalar_type() == at::ScalarType::Int,
"qweight.dtype != torch.int32");
const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
torch::Tensor output = inplace ? qweight : torch::empty_like(qweight);
if (!qzeros_or_none.has_value()) {
TORCH_CHECK(qweight.numel() * 8 % 256 == 0,
"qweight.numel() * 8 % 256 != 0");
int blocks = qweight.numel() * 8 / 256;
marlin_int4_fp8_preprocess_kernel_without_zp<<<blocks, 32>>>(
(const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr());
} else {
int32_t size_k = qweight.size(0);
int32_t size_n = qweight.size(1) * 8;
torch::Tensor qzeros = qzeros_or_none.value();
TORCH_CHECK(size_k % 32 == 0, "size_k % 32 != 0");
TORCH_CHECK(qzeros.device().is_cuda(), "qzeros is not on GPU");
TORCH_CHECK(qzeros.scalar_type() == at::ScalarType::Int,
"qweight.dtype != torch.int32");
TORCH_CHECK(device_of(qweight) == device_of(qzeros),
"qzeros is not on the same device with qweight");
int32_t group_size = qweight.size(0) / qzeros.size(0);
TORCH_CHECK(qweight.size(1) == qzeros.size(1),
"qweight.size(1) != qzeros.size(1)");
TORCH_CHECK(qweight.size(0) % qzeros.size(0) == 0,
"qweight.size(0) % qzeros.size(0) != 0");
TORCH_CHECK(group_size % 8 == 0, "group_size % 8 != 0");
dim3 blocks(size_k / 32, size_n / 8);
marlin_int4_fp8_preprocess_kernel_awq<<<blocks, 32>>>(
(const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr(),
(const int32_t*)qzeros.data_ptr(), size_n, size_k, group_size);
}
return output;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_int4_fp8_preprocess", &marlin_int4_fp8_preprocess);
}
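
For readers following the bit manipulation, below is a rough PyTorch sketch (not part of this commit; the function name is illustrative) of the per-`int32` remap performed by `marlin_int4_fp8_preprocess_kernel_without_zp`. It works on the unsigned 32-bit view to sidestep signed-shift overflow; the AWQ variant is analogous but uses each group's zero point instead of the constant 8.

```python
# Reference sketch only: mirrors the non-zero-point remap above in plain
# PyTorch. Each int32 packs eight 4-bit values; each nibble v becomes
# v - 8 when v >= 8, otherwise 15 - v.
import torch

def int4_fp8_preprocess_without_zp_ref(qweight: torch.Tensor) -> torch.Tensor:
    assert qweight.dtype == torch.int32
    # Work on the unsigned 32-bit view (as int64) to avoid signed-shift overflow.
    val = qweight.to(torch.int64) & 0xFFFFFFFF
    out = torch.zeros_like(val)
    for i in range(8):
        nibble = (val >> (i * 4)) & 0xF
        mapped = torch.where(nibble >= 8, nibble - 8, 15 - nibble)
        out |= mapped << (i * 4)
    # Reinterpret the low 32 bits as signed int32, matching the kernel output.
    return torch.where(out >= 2**31, out - 2**32, out).to(torch.int32)
```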

File diff suppressed because it is too large


@ -298,9 +298,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ. // gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def( ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor? b_bias_or_none," "Tensor? b_bias_or_none,Tensor b_scales, "
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " "Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " "Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor"); "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file // conditionally compiled so impl registration is in source file
@ -308,13 +309,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin repack from GPTQ. // gptq_marlin repack from GPTQ.
ops.def( ops.def(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, " "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor"); "SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
// conditionally compiled so impl registrations are in source file // conditionally compiled so impl registrations are in source file
// awq_marlin repack from AWQ. // awq_marlin repack from AWQ.
ops.def( ops.def(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor"); "SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
// conditionally compiled so impl registrations are in source file
// preprocess W-int4A-fp8 weight for marlin kernel
ops.def(
"marlin_int4_fp8_preprocess(Tensor qweight, "
"Tensor? qzeros_or_none, bool inplace) -> Tensor");
// conditionally compiled so impl registrations are in source file // conditionally compiled so impl registrations are in source file
// CUTLASS w4a8 GEMM // CUTLASS w4a8 GEMM


@ -119,6 +119,7 @@ FROM base AS vllm-test-deps
WORKDIR /workspace/vllm WORKDIR /workspace/vllm
# TODO: Update to 2.9.0 when there is a new build for intel_extension_for_pytorch for that version
RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
cp requirements/test.in requirements/cpu-test.in && \ cp requirements/test.in requirements/cpu-test.in && \
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \ sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
@ -131,6 +132,9 @@ RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
esac; \ esac; \
}; \ }; \
remove_packages_not_supported_on_aarch64 && \ remove_packages_not_supported_on_aarch64 && \
sed -i 's/^torch==.*/torch==2.9.1/g' requirements/cpu-test.in && \
sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \


@ -65,6 +65,8 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
# Centralized v1 package - copied to both test and final stages
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1
# ----------------------- # -----------------------
# Test vLLM image # Test vLLM image
@ -88,10 +90,22 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
# install development dependencies (for testing) # install development dependencies (for testing)
RUN cd /vllm-workspace \ RUN cd /vllm-workspace \
&& rm -rf vllm \
&& python3 -m pip install -e tests/vllm_test_utils \ && python3 -m pip install -e tests/vllm_test_utils \
&& python3 -m pip install pytest-shard && python3 -m pip install pytest-shard
# enable fast downloads from hf (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system hf_transfer
ENV HF_HUB_ENABLE_HF_TRANSFER=1
# Copy in the v1 package
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
# Source code is used in the `python_only_compile.sh` test
# We hide it inside `src/` so that this source code
# will not be imported by other tests
RUN mkdir src && mv vllm src/vllm
# ----------------------- # -----------------------
# Final vLLM image # Final vLLM image
FROM base AS final FROM base AS final
@ -116,6 +130,9 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
&& pip uninstall -y vllm \ && pip uninstall -y vllm \
&& uv pip install --system *.whl && uv pip install --system *.whl
# Copy in the v1 package
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
ARG COMMON_WORKDIR ARG COMMON_WORKDIR
# Copy over the benchmark scripts as well # Copy over the benchmark scripts as well


@ -5,6 +5,8 @@ ARG PYTORCH_BRANCH="1c57644d"
ARG PYTORCH_VISION_BRANCH="v0.23.0" ARG PYTORCH_VISION_BRANCH="v0.23.0"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
ARG FA_BRANCH="0e60e394" ARG FA_BRANCH="0e60e394"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="59bd8ff2" ARG AITER_BRANCH="59bd8ff2"
@ -23,6 +25,7 @@ ENV AITER_ROCM_ARCH=gfx942;gfx950
ENV HSA_NO_SCRATCH_RECLAIM=1 ENV HSA_NO_SCRATCH_RECLAIM=1
ARG PYTHON_VERSION=3.12 ARG PYTHON_VERSION=3.12
ENV PYTHON_VERSION=${PYTHON_VERSION}
RUN mkdir -p /app RUN mkdir -p /app
WORKDIR /app WORKDIR /app
@ -45,6 +48,7 @@ RUN apt-get update -y \
&& python3 --version && python3 -m pip --version && python3 --version && python3 -m pip --version
RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython
RUN apt-get update && apt-get install -y libjpeg-dev libsox-dev libsox-fmt-all sox && rm -rf /var/lib/apt/lists/*
FROM base AS build_triton FROM base AS build_triton
ARG TRITON_BRANCH ARG TRITON_BRANCH
@ -66,11 +70,14 @@ RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
FROM base AS build_pytorch FROM base AS build_pytorch
ARG PYTORCH_BRANCH ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_AUDIO_BRANCH
ARG PYTORCH_REPO ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO ARG PYTORCH_VISION_REPO
ARG PYTORCH_AUDIO_REPO
RUN git clone ${PYTORCH_REPO} pytorch RUN git clone ${PYTORCH_REPO} pytorch
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \ RUN cd pytorch && git checkout ${PYTORCH_BRANCH} \
pip install -r requirements.txt && git submodule update --init --recursive \ && pip install -r requirements.txt && git submodule update --init --recursive \
&& python3 tools/amd_build/build_amd.py \ && python3 tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \ && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl && pip install dist/*.whl
@ -78,8 +85,15 @@ RUN git clone ${PYTORCH_VISION_REPO} vision
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \ RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
&& python3 setup.py bdist_wheel --dist-dir=dist \ && python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl && pip install dist/*.whl
RUN git clone ${PYTORCH_AUDIO_REPO} audio
RUN cd audio && git checkout ${PYTORCH_AUDIO_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
&& cp /app/vision/dist/*.whl /app/install && cp /app/vision/dist/*.whl /app/install \
&& cp /app/audio/dist/*.whl /app/install
FROM base AS build_fa FROM base AS build_fa
ARG FA_BRANCH ARG FA_BRANCH
@ -130,6 +144,8 @@ ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO ARG PYTORCH_VISION_REPO
ARG PYTORCH_AUDIO_BRANCH
ARG PYTORCH_AUDIO_REPO
ARG FA_BRANCH ARG FA_BRANCH
ARG FA_REPO ARG FA_REPO
ARG AITER_BRANCH ARG AITER_BRANCH
@ -141,7 +157,9 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
&& echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \ && echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_AUDIO_BRANCH: ${PYTORCH_AUDIO_BRANCH}" >> /app/versions.txt \
&& echo "PYTORCH_AUDIO_REPO: ${PYTORCH_AUDIO_REPO}" >> /app/versions.txt \
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \ && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \ && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt


@ -52,6 +52,11 @@ nav:
- Plugins: - Plugins:
- design/*plugin*.md - design/*plugin*.md
- design/* - design/*
- Benchmarking:
- benchmarking/README.md
- benchmarking/cli.md
- benchmarking/sweeps.md
- benchmarking/dashboard.md
- API Reference: - API Reference:
- api/README.md - api/README.md
- api/vllm - api/vllm


@ -0,0 +1,7 @@
# Benchmark Suites
vLLM provides comprehensive benchmarking tools for performance testing and evaluation:
- **[Benchmark CLI](./cli.md)**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing.
- **[Parameter Sweeps](./sweeps.md)**: Automate `vllm bench` runs for multiple configurations, useful for [optimization and tuning](../configuration/optimization.md).
- **[Performance Dashboard](./dashboard.md)**: Automated CI that publishes benchmarks on each commit.


@ -1,22 +1,10 @@
--- # Benchmark CLI
toc_depth: 4
---
# Benchmark Suites This section guides you through running benchmark tests with the extensive datasets supported on vLLM.
vLLM provides comprehensive benchmarking tools for performance testing and evaluation: It's a living document, updated as new features and datasets become available.
- **[Benchmark CLI](#benchmark-cli)**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing ## Dataset Overview
- **[Parameter sweeps](#parameter-sweeps)**: Automate `vllm bench` runs for multiple configurations
- **[Performance benchmarks](#performance-benchmarks)**: Automated CI benchmarks for development
## Benchmark CLI
This section guides you through running benchmark tests with the extensive
datasets supported on vLLM. It's a living document, updated as new features and datasets
become available.
### Dataset Overview
<style> <style>
th { th {
@ -59,9 +47,9 @@ Legend:
--dataset-path /datasets/VisionArena-Chat/ --hf-name lmarena-ai/VisionArena-Chat --dataset-path /datasets/VisionArena-Chat/ --hf-name lmarena-ai/VisionArena-Chat
``` ```
### Examples ## Examples
#### 🚀 Online Benchmark ### 🚀 Online Benchmark
<details class="admonition abstract" markdown="1"> <details class="admonition abstract" markdown="1">
<summary>Show more</summary> <summary>Show more</summary>
@ -112,7 +100,7 @@ P99 ITL (ms): 8.39
================================================== ==================================================
``` ```
##### Custom Dataset #### Custom Dataset
If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl
@ -145,7 +133,7 @@ vllm bench serve --port 9001 --save-result --save-detailed \
You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`.
##### VisionArena Benchmark for Vision Language Models #### VisionArena Benchmark for Vision Language Models
```bash ```bash
# need a model with vision capability here # need a model with vision capability here
@ -163,7 +151,7 @@ vllm bench serve \
--num-prompts 1000 --num-prompts 1000
``` ```
##### InstructCoder Benchmark with Speculative Decoding #### InstructCoder Benchmark with Speculative Decoding
``` bash ``` bash
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
@ -180,7 +168,7 @@ vllm bench serve \
--num-prompts 2048 --num-prompts 2048
``` ```
##### Spec Bench Benchmark with Speculative Decoding #### Spec Bench Benchmark with Speculative Decoding
``` bash ``` bash
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
@ -217,7 +205,7 @@ vllm bench serve \
--spec-bench-category "summarization" --spec-bench-category "summarization"
``` ```
##### Other HuggingFaceDataset Examples #### Other HuggingFaceDataset Examples
```bash ```bash
vllm serve Qwen/Qwen2-VL-7B-Instruct vllm serve Qwen/Qwen2-VL-7B-Instruct
@ -283,7 +271,7 @@ vllm bench serve \
--blazedit-max-distance 0.99 --blazedit-max-distance 0.99
``` ```
##### Running With Sampling Parameters #### Running With Sampling Parameters
When using OpenAI-compatible backends such as `vllm`, optional sampling When using OpenAI-compatible backends such as `vllm`, optional sampling
parameters can be specified. Example client command: parameters can be specified. Example client command:
@ -301,7 +289,7 @@ vllm bench serve \
--num-prompts 10 --num-prompts 10
``` ```
##### Running With Ramp-Up Request Rate #### Running With Ramp-Up Request Rate
The benchmark tool also supports ramping up the request rate over the The benchmark tool also supports ramping up the request rate over the
duration of the benchmark run. This can be useful for stress testing the duration of the benchmark run. This can be useful for stress testing the
@ -318,11 +306,11 @@ The following arguments can be used to control the ramp-up:
- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark. - `--ramp-up-start-rps`: The request rate at the beginning of the benchmark.
- `--ramp-up-end-rps`: The request rate at the end of the benchmark. - `--ramp-up-end-rps`: The request rate at the end of the benchmark.
##### Load Pattern Configuration #### Load Pattern Configuration
vLLM's benchmark serving script provides sophisticated load pattern simulation capabilities through three key parameters that control request generation and concurrency behavior: vLLM's benchmark serving script provides sophisticated load pattern simulation capabilities through three key parameters that control request generation and concurrency behavior:
###### Load Pattern Control Parameters ##### Load Pattern Control Parameters
- `--request-rate`: Controls the target request generation rate (requests per second). Set to `inf` for maximum throughput testing or finite values for controlled load simulation. - `--request-rate`: Controls the target request generation rate (requests per second). Set to `inf` for maximum throughput testing or finite values for controlled load simulation.
- `--burstiness`: Controls traffic variability using a Gamma distribution (range: > 0). Lower values create bursty traffic, higher values create uniform traffic. - `--burstiness`: Controls traffic variability using a Gamma distribution (range: > 0). Lower values create bursty traffic, higher values create uniform traffic.
@ -387,7 +375,7 @@ Using KV cache metrics for load pattern configuration:
</details> </details>
#### 📈 Offline Throughput Benchmark ### 📈 Offline Throughput Benchmark
<details class="admonition abstract" markdown="1"> <details class="admonition abstract" markdown="1">
<summary>Show more</summary> <summary>Show more</summary>
@ -408,7 +396,7 @@ Total num prompt tokens: 5014
Total num output tokens: 1500 Total num output tokens: 1500
``` ```
##### VisionArena Benchmark for Vision Language Models #### VisionArena Benchmark for Vision Language Models
```bash ```bash
vllm bench throughput \ vllm bench throughput \
@ -428,7 +416,7 @@ Total num prompt tokens: 14527
Total num output tokens: 1280 Total num output tokens: 1280
``` ```
##### InstructCoder Benchmark with Speculative Decoding #### InstructCoder Benchmark with Speculative Decoding
``` bash ``` bash
VLLM_WORKER_MULTIPROC_METHOD=spawn \ VLLM_WORKER_MULTIPROC_METHOD=spawn \
@ -451,7 +439,7 @@ Total num prompt tokens: 261136
Total num output tokens: 204800 Total num output tokens: 204800
``` ```
##### Other HuggingFaceDataset Examples #### Other HuggingFaceDataset Examples
`lmms-lab/LLaVA-OneVision-Data`: `lmms-lab/LLaVA-OneVision-Data`:
@ -509,20 +497,20 @@ vllm bench throughput \
</details> </details>
#### 🛠️ Structured Output Benchmark ### 🛠️ Structured Output Benchmark
<details class="admonition abstract" markdown="1"> <details class="admonition abstract" markdown="1">
<summary>Show more</summary> <summary>Show more</summary>
Benchmark the performance of structured output generation (JSON, grammar, regex). Benchmark the performance of structured output generation (JSON, grammar, regex).
##### Server Setup #### Server Setup
```bash ```bash
vllm serve NousResearch/Hermes-3-Llama-3.1-8B vllm serve NousResearch/Hermes-3-Llama-3.1-8B
``` ```
##### JSON Schema Benchmark #### JSON Schema Benchmark
```bash ```bash
python3 benchmarks/benchmark_serving_structured_output.py \ python3 benchmarks/benchmark_serving_structured_output.py \
@ -534,7 +522,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \
--num-prompts 1000 --num-prompts 1000
``` ```
##### Grammar-based Generation Benchmark #### Grammar-based Generation Benchmark
```bash ```bash
python3 benchmarks/benchmark_serving_structured_output.py \ python3 benchmarks/benchmark_serving_structured_output.py \
@ -546,7 +534,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \
--num-prompts 1000 --num-prompts 1000
``` ```
##### Regex-based Generation Benchmark #### Regex-based Generation Benchmark
```bash ```bash
python3 benchmarks/benchmark_serving_structured_output.py \ python3 benchmarks/benchmark_serving_structured_output.py \
@ -557,7 +545,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \
--num-prompts 1000 --num-prompts 1000
``` ```
##### Choice-based Generation Benchmark #### Choice-based Generation Benchmark
```bash ```bash
python3 benchmarks/benchmark_serving_structured_output.py \ python3 benchmarks/benchmark_serving_structured_output.py \
@ -568,7 +556,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \
--num-prompts 1000 --num-prompts 1000
``` ```
##### XGrammar Benchmark Dataset #### XGrammar Benchmark Dataset
```bash ```bash
python3 benchmarks/benchmark_serving_structured_output.py \ python3 benchmarks/benchmark_serving_structured_output.py \
@ -581,14 +569,14 @@ python3 benchmarks/benchmark_serving_structured_output.py \
</details> </details>
#### 📚 Long Document QA Benchmark ### 📚 Long Document QA Benchmark
<details class="admonition abstract" markdown="1"> <details class="admonition abstract" markdown="1">
<summary>Show more</summary> <summary>Show more</summary>
Benchmark the performance of long document question-answering with prefix caching. Benchmark the performance of long document question-answering with prefix caching.
##### Basic Long Document QA Test #### Basic Long Document QA Test
```bash ```bash
python3 benchmarks/benchmark_long_document_qa_throughput.py \ python3 benchmarks/benchmark_long_document_qa_throughput.py \
@ -600,7 +588,7 @@ python3 benchmarks/benchmark_long_document_qa_throughput.py \
--repeat-count 5 --repeat-count 5
``` ```
##### Different Repeat Modes #### Different Repeat Modes
```bash ```bash
# Random mode (default) - shuffle prompts randomly # Random mode (default) - shuffle prompts randomly
@ -633,14 +621,14 @@ python3 benchmarks/benchmark_long_document_qa_throughput.py \
</details> </details>
#### 🗂️ Prefix Caching Benchmark ### 🗂️ Prefix Caching Benchmark
<details class="admonition abstract" markdown="1"> <details class="admonition abstract" markdown="1">
<summary>Show more</summary> <summary>Show more</summary>
Benchmark the efficiency of automatic prefix caching. Benchmark the efficiency of automatic prefix caching.
##### Fixed Prompt with Prefix Caching #### Fixed Prompt with Prefix Caching
```bash ```bash
python3 benchmarks/benchmark_prefix_caching.py \ python3 benchmarks/benchmark_prefix_caching.py \
@ -651,7 +639,7 @@ python3 benchmarks/benchmark_prefix_caching.py \
--input-length-range 128:256 --input-length-range 128:256
``` ```
##### ShareGPT Dataset with Prefix Caching #### ShareGPT Dataset with Prefix Caching
```bash ```bash
# download dataset # download dataset
@ -682,14 +670,14 @@ vllm bench serve \
</details> </details>
#### ⚡ Request Prioritization Benchmark ### ⚡ Request Prioritization Benchmark
<details class="admonition abstract" markdown="1"> <details class="admonition abstract" markdown="1">
<summary>Show more</summary> <summary>Show more</summary>
Benchmark the performance of request prioritization in vLLM. Benchmark the performance of request prioritization in vLLM.
##### Basic Prioritization Test #### Basic Prioritization Test
```bash ```bash
python3 benchmarks/benchmark_prioritization.py \ python3 benchmarks/benchmark_prioritization.py \
@ -700,7 +688,7 @@ python3 benchmarks/benchmark_prioritization.py \
--scheduling-policy priority --scheduling-policy priority
``` ```
##### Multiple Sequences per Prompt #### Multiple Sequences per Prompt
```bash ```bash
python3 benchmarks/benchmark_prioritization.py \ python3 benchmarks/benchmark_prioritization.py \
@ -714,14 +702,14 @@ python3 benchmarks/benchmark_prioritization.py \
</details> </details>
#### 👁️ Multi-Modal Benchmark ### 👁️ Multi-Modal Benchmark
<details class="admonition abstract" markdown="1"> <details class="admonition abstract" markdown="1">
<summary>Show more</summary> <summary>Show more</summary>
Benchmark the performance of multi-modal requests in vLLM. Benchmark the performance of multi-modal requests in vLLM.
##### Images (ShareGPT4V) #### Images (ShareGPT4V)
Start vLLM: Start vLLM:
@ -747,7 +735,7 @@ vllm bench serve \
--endpoint /v1/chat/completions --endpoint /v1/chat/completions
``` ```
##### Videos (ShareGPT4Video) #### Videos (ShareGPT4Video)
Start vLLM: Start vLLM:
@ -773,7 +761,7 @@ vllm bench serve \
--endpoint /v1/chat/completions --endpoint /v1/chat/completions
``` ```
##### Synthetic Random Images (random-mm) #### Synthetic Random Images (random-mm)
Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets. Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.
@ -846,14 +834,14 @@ This should be seen as an edge case, and if this behavior can be avoided by sett
</details> </details>
#### Embedding Benchmark ### Embedding Benchmark
Benchmark the performance of embedding requests in vLLM. Benchmark the performance of embedding requests in vLLM.
<details class="admonition abstract" markdown="1"> <details class="admonition abstract" markdown="1">
<summary>Show more</summary> <summary>Show more</summary>
##### Text Embeddings #### Text Embeddings
Unlike generative models which use Completions API or Chat Completions API, Unlike generative models which use Completions API or Chat Completions API,
you should set `--backend openai-embeddings` and `--endpoint /v1/embeddings` to use the Embeddings API. you should set `--backend openai-embeddings` and `--endpoint /v1/embeddings` to use the Embeddings API.
@ -879,7 +867,7 @@ vllm bench serve \
--dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json --dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json
``` ```
##### Multi-modal Embeddings #### Multi-modal Embeddings
Unlike generative models which use Completions API or Chat Completions API, Unlike generative models which use Completions API or Chat Completions API,
you should set `--endpoint /v1/embeddings` to use the Embeddings API. The backend to use depends on the model: you should set `--endpoint /v1/embeddings` to use the Embeddings API. The backend to use depends on the model:
@ -944,7 +932,7 @@ vllm bench serve \
</details> </details>
#### Reranker Benchmark ### Reranker Benchmark
Benchmark the performance of rerank requests in vLLM. Benchmark the performance of rerank requests in vLLM.
@ -988,222 +976,3 @@ to account for the extra prompt which is the query. The token accounting to repo
throughput numbers correctly is also adjusted. throughput numbers correctly is also adjusted.
</details> </details>
## Parameter Sweeps
### Online Benchmark
[`vllm/benchmarks/sweep/serve.py`](../../vllm/benchmarks/sweep/serve.py) automatically starts `vllm serve` and runs `vllm bench serve` to evaluate vLLM over multiple configurations.
Follow these steps to run the script:
1. Construct the base command to `vllm serve`, and pass it to the `--serve-cmd` option.
2. Construct the base command to `vllm bench serve`, and pass it to the `--bench-cmd` option.
3. (Optional) If you would like to vary the settings of `vllm serve`, create a new JSON file and populate it with the parameter combinations you want to test. Pass the file path to `--serve-params`.
- Example: Tuning `--max-num-seqs` and `--max-num-batched-tokens`:
```json
[
{
"max_num_seqs": 32,
"max_num_batched_tokens": 1024
},
{
"max_num_seqs": 64,
"max_num_batched_tokens": 1024
},
{
"max_num_seqs": 64,
"max_num_batched_tokens": 2048
},
{
"max_num_seqs": 128,
"max_num_batched_tokens": 2048
},
{
"max_num_seqs": 128,
"max_num_batched_tokens": 4096
},
{
"max_num_seqs": 256,
"max_num_batched_tokens": 4096
}
]
```
4. (Optional) If you would like to vary the settings of `vllm bench serve`, create a new JSON file and populate it with the parameter combinations you want to test. Pass the file path to `--bench-params`.
- Example: Using different input/output lengths for random dataset:
```json
[
{
"random_input_len": 128,
"random_output_len": 32
},
{
"random_input_len": 256,
"random_output_len": 64
},
{
"random_input_len": 512,
"random_output_len": 128
}
]
```
5. Determine where you want to save the results, and pass that to `--output-dir`.
Example command:
```bash
vllm bench sweep serve \
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
--serve-params benchmarks/serve_hparams.json \
--bench-params benchmarks/bench_hparams.json \
-o benchmarks/results
```
!!! important
If both `--serve-params` and `--bench-params` are passed, the script will iterate over the Cartesian product between them.
You can use `--dry-run` to preview the commands to be run.
We only start the server once for each `--serve-params`, and keep it running for multiple `--bench-params`.
Between each benchmark run, we call the `/reset_prefix_cache` and `/reset_mm_cache` endpoints to get a clean slate for the next run.
In case you are using a custom `--serve-cmd`, you can override the commands used for resetting the state by setting `--after-bench-cmd`.
!!! note
By default, each parameter combination is run 3 times to make the results more reliable. You can adjust the number of runs by setting `--num-runs`.
!!! tip
You can use the `--resume` option to continue the parameter sweep if one of the runs failed.
### SLA Auto-Tuner
[`vllm/benchmarks/sweep/serve_sla.py`](../../vllm/benchmarks/sweep/serve_sla.py) is a wrapper over [`vllm/benchmarks/sweep/serve.py`](../../vllm/benchmarks/sweep/serve.py) that tunes either the request rate or concurrency (choose using `--sla-variable`) in order to satisfy the SLA constraints given by `--sla-params`.
For example, to ensure E2E latency within different target values for 99% of requests:
```json
[
{
"p99_e2el_ms": "<=200"
},
{
"p99_e2el_ms": "<=500"
},
{
"p99_e2el_ms": "<=1000"
},
{
"p99_e2el_ms": "<=2000"
}
]
```
Example command:
```bash
vllm bench sweep serve_sla \
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
--serve-params benchmarks/serve_hparams.json \
--bench-params benchmarks/bench_hparams.json \
--sla-params benchmarks/sla_hparams.json \
--sla-variable max_concurrency \
-o benchmarks/results
```
The algorithm for adjusting the SLA variable is as follows:
1. Run the benchmark with infinite QPS, and use the corresponding metrics to determine the initial value of the variable.
- For example, the initial request rate is set to the concurrency under infinite QPS.
2. If the SLA is still satisfied, keep doubling the value until the SLA is no longer satisfied. This gives a relatively narrow window that contains the point where the SLA is barely satisfied.
3. Apply binary search over the window to find the maximum value that still satisfies the SLA.
!!! important
SLA tuning is applied over each combination of `--serve-params`, `--bench-params`, and `--sla-params`.
For a given combination of `--serve-params` and `--bench-params`, we share the benchmark results across `--sla-params` to avoid rerunning benchmarks with the same SLA variable value.
### Visualizer
[`vllm/benchmarks/sweep/plot.py`](../../vllm/benchmarks/sweep/plot.py) can be used to plot performance curves from parameter sweep results.
Example command:
```bash
vllm bench sweep plot benchmarks/results/<timestamp> \
--var-x max_concurrency \
--row-by random_input_len \
--col-by random_output_len \
--curve-by api_server_count,max_num_batched_tokens \
--filter-by 'max_concurrency<=1024'
```
!!! tip
You can use `--dry-run` to preview the figures to be plotted.
## Performance Benchmarks
The performance benchmarks are used for development to confirm whether new changes improve performance under various workloads. They are triggered on every commit with both the `perf-benchmarks` and `ready` labels, and when a PR is merged into vLLM.
### Manually Trigger the benchmark
Use [vllm-ci-test-repo images](https://gallery.ecr.aws/q9t5s3a7/vllm-ci-test-repo) with vLLM benchmark suite.
For CPU environment, please use the image with "-cpu" postfix.
Here is an example for docker run command for CPU.
```bash
docker run -it --entrypoint /bin/bash -v /data/huggingface:/root/.cache/huggingface -e HF_TOKEN='' --shm-size=16g --name vllm-cpu-ci public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:1da94e673c257373280026f75ceb4effac80e892-cpu
```
Then, run below command inside the docker instance.
```bash
bash .buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh
```
When run, benchmark script generates results under **benchmark/results** folder, along with the benchmark_results.md and benchmark_results.json.
#### Runtime environment variables
- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0.
- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file).
- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file).
- `THROUGHPUT_JSON`: JSON file to use for the throughout tests. Default value is empty string (use default file).
- `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string.
- `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string.
For more results visualization, check the [visualizing the results](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md#visualizing-the-results).
The latest performance results are hosted on the public [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm).
More information on the performance benchmarks and their parameters can be found in [Benchmark README](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md) and [performance benchmark description](../../.buildkite/performance-benchmarks/performance-benchmarks-descriptions.md).
### Continuous Benchmarking
The continuous benchmarking provides automated performance monitoring for vLLM across different models and GPU devices. This helps track vLLM's performance characteristics over time and identify any performance regressions or improvements.
#### How It Works
The continuous benchmarking is triggered via a [GitHub workflow CI](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) in the PyTorch infrastructure repository, which runs automatically every 4 hours. The workflow executes three types of performance tests:
- **Serving tests**: Measure request handling and API performance
- **Throughput tests**: Evaluate token generation rates
- **Latency tests**: Assess response time characteristics
#### Benchmark Configuration
The benchmarking currently runs on a predefined set of models configured in the [vllm-benchmarks directory](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks). To add new models for benchmarking:
1. Navigate to the appropriate GPU directory in the benchmarks configuration
2. Add your model specifications to the corresponding configuration files
3. The new models will be included in the next scheduled benchmark run
#### Viewing Results
All continuous benchmarking results are automatically published to the public [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm).


@ -0,0 +1,58 @@
# Performance Dashboard
The performance dashboard is used to confirm whether new changes improve/degrade performance under various workloads.
It is updated by triggering benchmark runs on every commit with both the `perf-benchmarks` and `ready` labels, and when a PR is merged into vLLM.
The results are automatically published to the public [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm).
## Manually Trigger the benchmark
Use [vllm-ci-test-repo images](https://gallery.ecr.aws/q9t5s3a7/vllm-ci-test-repo) with the vLLM benchmark suite.
For a CPU environment, please use the image with the "-cpu" suffix.
Here is an example `docker run` command for CPU.
```bash
docker run -it --entrypoint /bin/bash -v /data/huggingface:/root/.cache/huggingface -e HF_TOKEN='' --shm-size=16g --name vllm-cpu-ci public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:1da94e673c257373280026f75ceb4effac80e892-cpu
```
Then, run the command below inside the Docker container.
```bash
bash .buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh
```
When run, the benchmark script generates results under the **benchmark/results** folder, along with benchmark_results.md and benchmark_results.json.
### Runtime environment variables
- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0.
- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file).
- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file).
- `THROUGHPUT_JSON`: JSON file to use for the throughput tests. Default value is empty string (use default file).
- `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string.
- `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string.
For more ways to visualize the results, check [visualizing the results](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md#visualizing-the-results).
More information on the performance benchmarks and their parameters can be found in [Benchmark README](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md) and [performance benchmark description](../../.buildkite/performance-benchmarks/performance-benchmarks-descriptions.md).
## Continuous Benchmarking
The continuous benchmarking provides automated performance monitoring for vLLM across different models and GPU devices. This helps track vLLM's performance characteristics over time and identify any performance regressions or improvements.
### How It Works
The continuous benchmarking is triggered via a [GitHub workflow CI](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) in the PyTorch infrastructure repository, which runs automatically every 4 hours. The workflow executes three types of performance tests:
- **Serving tests**: Measure request handling and API performance
- **Throughput tests**: Evaluate token generation rates
- **Latency tests**: Assess response time characteristics
### Benchmark Configuration
The benchmarking currently runs on a predefined set of models configured in the [vllm-benchmarks directory](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks). To add new models for benchmarking:
1. Navigate to the appropriate GPU directory in the benchmarks configuration
2. Add your model specifications to the corresponding configuration files
3. The new models will be included in the next scheduled benchmark run

docs/benchmarking/sweeps.md

@ -0,0 +1,178 @@
# Parameter Sweeps
## Online Benchmark
### Basic
`vllm bench sweep serve` automatically starts `vllm serve` and runs `vllm bench serve` to evaluate vLLM over multiple configurations.
Follow these steps to run the script:
1. Construct the base command to `vllm serve`, and pass it to the `--serve-cmd` option.
2. Construct the base command to `vllm bench serve`, and pass it to the `--bench-cmd` option.
3. (Optional) If you would like to vary the settings of `vllm serve`, create a new JSON file and populate it with the parameter combinations you want to test. Pass the file path to `--serve-params`.
- Example: Tuning `--max-num-seqs` and `--max-num-batched-tokens`:
```json
[
{
"max_num_seqs": 32,
"max_num_batched_tokens": 1024
},
{
"max_num_seqs": 64,
"max_num_batched_tokens": 1024
},
{
"max_num_seqs": 64,
"max_num_batched_tokens": 2048
},
{
"max_num_seqs": 128,
"max_num_batched_tokens": 2048
},
{
"max_num_seqs": 128,
"max_num_batched_tokens": 4096
},
{
"max_num_seqs": 256,
"max_num_batched_tokens": 4096
}
]
```
4. (Optional) If you would like to vary the settings of `vllm bench serve`, create a new JSON file and populate it with the parameter combinations you want to test. Pass the file path to `--bench-params`.
- Example: Using different input/output lengths for random dataset:
```json
[
{
"random_input_len": 128,
"random_output_len": 32
},
{
"random_input_len": 256,
"random_output_len": 64
},
{
"random_input_len": 512,
"random_output_len": 128
}
]
```
5. Determine where you want to save the results, and pass that to `--output-dir`.
Example command:
```bash
vllm bench sweep serve \
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
--serve-params benchmarks/serve_hparams.json \
--bench-params benchmarks/bench_hparams.json \
-o benchmarks/results
```
!!! important
If both `--serve-params` and `--bench-params` are passed, the script will iterate over the Cartesian product between them.
You can use `--dry-run` to preview the commands to be run.
We only start the server once for each `--serve-params`, and keep it running for multiple `--bench-params`.
Between each benchmark run, we call the `/reset_prefix_cache` and `/reset_mm_cache` endpoints to get a clean slate for the next run.
In case you are using a custom `--serve-cmd`, you can override the commands used for resetting the state by setting `--after-bench-cmd`.
!!! note
By default, each parameter combination is run 3 times to make the results more reliable. You can adjust the number of runs by setting `--num-runs`.
!!! tip
You can use the `--resume` option to continue the parameter sweep if one of the runs failed.
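
If you have many combinations to test, you can generate the parameter files programmatically rather than writing them by hand. A minimal sketch (the parameter names and value grids below are only examples; trim the product down to the combinations you actually want to run):

```python
# Sketch: build a --serve-params JSON file from the Cartesian product of a
# few value grids. Parameters and values are illustrative only.
import itertools
import json

grid = {
    "max_num_seqs": [32, 64, 128, 256],
    "max_num_batched_tokens": [1024, 2048, 4096],
}

combos = [dict(zip(grid, values)) for values in itertools.product(*grid.values())]

with open("benchmarks/serve_hparams.json", "w") as f:
    json.dump(combos, f, indent=2)
```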
### SLA auto-tuner
`vllm bench sweep serve_sla` is a wrapper over `vllm bench sweep serve` that tunes either the request rate or concurrency (choose using `--sla-variable`) in order to satisfy the SLA constraints given by `--sla-params`.
For example, to ensure E2E latency within different target values for 99% of requests:
```json
[
{
"p99_e2el_ms": "<=200"
},
{
"p99_e2el_ms": "<=500"
},
{
"p99_e2el_ms": "<=1000"
},
{
"p99_e2el_ms": "<=2000"
}
]
```
Example command:
```bash
vllm bench sweep serve_sla \
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
--serve-params benchmarks/serve_hparams.json \
--bench-params benchmarks/bench_hparams.json \
--sla-params benchmarks/sla_hparams.json \
--sla-variable max_concurrency \
-o benchmarks/results
```
The algorithm for adjusting the SLA variable is as follows:
1. Run the benchmark with infinite QPS, and use the corresponding metrics to determine the initial value of the variable.
- For example, the initial request rate is set to the concurrency under infinite QPS.
2. If the SLA is still satisfied, keep doubling the value until the SLA is no longer satisfied. This gives a relatively narrow window that contains the point where the SLA is barely satisfied.
3. Apply binary search over the window to find the maximum value that still satisfies the SLA.
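
In rough pseudocode (not the actual implementation), the search for a single SLA constraint looks like this, assuming a helper that runs the benchmark at a given value of the SLA variable and reports whether the constraint held:

```python
# Sketch of the SLA search described above: exponential ramp-up, then binary
# search. `run_and_check(value)` is a stand-in for launching `vllm bench serve`
# with the SLA variable (request rate or concurrency) set to `value` and
# checking the resulting metrics against the constraint.
def find_max_sla_value(initial: float, run_and_check) -> float:
    if not run_and_check(initial):
        return 0.0  # the starting point already violates the SLA

    # Step 2: keep doubling until the SLA breaks, bracketing the boundary.
    lo, hi = initial, initial * 2
    while run_and_check(hi):
        lo, hi = hi, hi * 2

    # Step 3: binary search inside (lo, hi) for the largest passing value.
    for _ in range(10):
        mid = (lo + hi) / 2
        if run_and_check(mid):
            lo = mid
        else:
            hi = mid
    return lo
```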
!!! important
SLA tuning is applied over each combination of `--serve-params`, `--bench-params`, and `--sla-params`.
For a given combination of `--serve-params` and `--bench-params`, we share the benchmark results across `--sla-params` to avoid rerunning benchmarks with the same SLA variable value.
## Visualization
### Basic
`vllm bench sweep plot` can be used to plot performance curves from parameter sweep results.
Example command:
```bash
vllm bench sweep plot benchmarks/results/<timestamp> \
--var-x max_concurrency \
--row-by random_input_len \
--col-by random_output_len \
--curve-by api_server_count,max_num_batched_tokens \
--filter-by 'max_concurrency<=1024'
```
!!! tip
You can use `--dry-run` to preview the figures to be plotted.
### Pareto chart
`vllm bench sweep plot_pareto` helps pick configurations that balance per-user and per-GPU throughput.
Higher concurrency or batch size can raise GPU efficiency (per-GPU throughput) but adds per-user latency; lower concurrency improves the per-user rate but underutilizes GPUs. The Pareto frontier shows the best achievable pairs across your runs.
- x-axis: tokens/s/user = `output_throughput` ÷ concurrency (`--user-count-var`, default `max_concurrency`, fallback `max_concurrent_requests`).
- y-axis: tokens/s/GPU = `output_throughput` ÷ GPU count (`--gpu-count-var` if set; otherwise `gpu_count` is TP×PP×DP).
- Output: a single figure at `OUTPUT_DIR/pareto/PARETO.png`.
- Labels: show the configuration used for each data point with `--label-by` (default: `max_concurrency,gpu_count`).
Example:
```bash
vllm bench sweep plot_pareto benchmarks/results/<timestamp> \
--label-by max_concurrency,tensor_parallel_size,pipeline_parallel_size
```
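
Conceptually, the frontier kept by the plot is just the set of runs that are not dominated on both axes. A small sketch of that selection (field names are assumptions based on the defaults above, not the exact result schema):

```python
# Sketch: derive per-user and per-GPU throughput for each run and keep the
# Pareto frontier. Field names are illustrative only.
def pareto_frontier(runs: list[dict]) -> list[dict]:
    points = []
    for run in runs:
        per_user = run["output_throughput"] / run["max_concurrency"]
        per_gpu = run["output_throughput"] / run["gpu_count"]
        points.append({**run, "per_user": per_user, "per_gpu": per_gpu})

    # Scan in order of decreasing per-user rate; a point stays on the frontier
    # only if its per-GPU rate beats every point already kept.
    points.sort(key=lambda p: p["per_user"], reverse=True)
    frontier, best_per_gpu = [], float("-inf")
    for p in points:
        if p["per_gpu"] > best_per_gpu:
            frontier.append(p)
            best_per_gpu = p["per_gpu"]
    return frontier
```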


@ -0,0 +1,9 @@
# vllm bench sweep plot_pareto
## JSON CLI Arguments
--8<-- "docs/cli/json_tip.inc.md"
## Arguments
--8<-- "docs/argparse/bench_sweep_plot_pareto.inc.md"


@ -113,8 +113,6 @@ See [this page](registration.md) for instructions on how to register your new mo
### How to support models with interleaving sliding windows? ### How to support models with interleaving sliding windows?
For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `mistralai/Ministral-8B-Instruct-2410`), the scheduler will treat the model as a full-attention model, i.e., kv-cache of all tokens will not be dropped. This is to make sure prefix caching works with these models. Sliding window only appears as a parameter to the attention kernel computation.
To support a model with interleaving sliding windows, we need to take care of the following details: To support a model with interleaving sliding windows, we need to take care of the following details:
- Make sure the model's `config.json` contains `layer_types`. - Make sure the model's `config.json` contains `layer_types`.


@ -11,6 +11,8 @@ We support tracing vLLM workers using the `torch.profiler` module. You can enabl
- `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default - `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default
- `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default - `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default
- `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default - `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default
- `VLLM_TORCH_PROFILER_USE_GZIP=0` to disable gzip-compressing profiling files, on by default
- `VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0` to disable dumping and printing the aggregated CUDA self time table, on by default
The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set. The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.


@ -8,9 +8,9 @@ TL;DR:
| Online Flag | Offline Flag | Result | | Online Flag | Offline Flag | Result |
|----------|----------|-------------| |----------|----------|-------------|
| --enforce-eager | enforce_eager=True | Turn off torch.compile and CUDAGraphs | | --enforce-eager | enforce_eager=True | Turn off torch.compile and CUDAGraphs |
| -O.mode=0 | mode=CompilationMode.NONE | Turn off torch.compile only | | -cc.mode=0 | mode=CompilationMode.NONE | Turn off torch.compile only |
| -O.cudagraph_mode=NONE | compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE) | Turn off CUDAGraphs only | | -cc.cudagraph_mode=NONE | compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE) | Turn off CUDAGraphs only |
| -O.backend=eager | compilation_config=CompilationConfig(backend='eager') | Turn off TorchInductor | | -cc.backend=eager | compilation_config=CompilationConfig(backend='eager') | Turn off TorchInductor |
## vLLM-torch.compile overview ## vLLM-torch.compile overview
@ -86,11 +86,11 @@ LLM(model, enforce_eager=True)
``` ```
To turn off just torch.compile, pass `mode = NONE` to the compilation config. To turn off just torch.compile, pass `mode = NONE` to the compilation config.
(`-O` is short for `--compilation_config`): (`-cc` is short for `--compilation_config`; `-O.*` dotted syntax is deprecated):
```sh ```sh
# Online # Online
vllm serve -O.mode=0 vllm serve -cc.mode=0
``` ```
```py ```py
@ -103,7 +103,7 @@ To turn off just CUDAGraphs, pass `cudagraph_mode = NONE`:
```sh ```sh
# Online # Online
vllm serve -O.cudagraph_mode=NONE vllm serve -cc.cudagraph_mode=NONE
``` ```
```py ```py
@ -183,10 +183,10 @@ help debug the issue:
```sh ```sh
# Online - using unbacked mode # Online - using unbacked mode
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked vllm serve meta-llama/Llama-3.2-1B -cc.dynamic_shapes_config.type=unbacked
# Online - using backed_size_oblivious mode # Online - using backed_size_oblivious mode
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=backed_size_oblivious vllm serve meta-llama/Llama-3.2-1B -cc.dynamic_shapes_config.type=backed_size_oblivious
``` ```
```py ```py
@ -233,7 +233,7 @@ to the compilation config:
```sh ```sh
# online # online
vllm serve -O.backend=eager vllm serve -cc.backend=eager
``` ```
```py ```py
@ -252,7 +252,7 @@ You can also use `TORCH_LOGS=output_code <command>` to print the Inductor output
### Editable TorchInductor code ### Editable TorchInductor code
You can edit the TorchInductor code that gets run by setting `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked` You can edit the TorchInductor code that gets run by setting `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`
or passing `-O.compile_cache_save_format=unpacked`. The default is `binary`, which means it is not editable. or passing `-cc.compile_cache_save_format=unpacked`. The default is `binary`, which means it is not editable.
This is a useful technique: you can put breakpoints (e.g. `torch.distributed.breakpoint()`) This is a useful technique: you can put breakpoints (e.g. `torch.distributed.breakpoint()`)
and print statements in the output code. and print statements in the output code.
@ -299,7 +299,7 @@ To turn off just CUDAGraphs, pass `cudagraph_mode = NONE`:
```sh ```sh
# Online # Online
vllm serve -O.cudagraph_mode=NONE vllm serve -cc.cudagraph_mode=NONE
``` ```
```py ```py


@ -21,7 +21,7 @@ Let's say we want to serve the popular Qwen model by running `vllm serve Qwen/Qw
Beyond that, there are two more things vLLM depends on Hugging Face for. Beyond that, there are two more things vLLM depends on Hugging Face for.
1. **Tokenizer**: vLLM uses the tokenizer from Hugging Face to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check Hugging Face's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [get_cached_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L24). 1. **Tokenizer**: vLLM uses the tokenizer from Hugging Face to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check Hugging Face's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [vllm.tokenizers.hf.get_cached_tokenizer][].
2. **Model weight**: vLLM downloads the model weight from the Hugging Face model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. We can also pass `--load-format dummy` to skip downloading the weights. 2. **Model weight**: vLLM downloads the model weight from the Hugging Face model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. We can also pass `--load-format dummy` to skip downloading the weights.
- It is recommended to use the safetensors format, as it is efficient for loading in distributed inference and also safe from arbitrary code execution. See the [documentation](https://huggingface.co/docs/safetensors/en/index) for more information on the safetensors format. This part of the logic can be found [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/model_executor/model_loader/loader.py#L385). Please note that: - It is recommended to use the safetensors format, as it is efficient for loading in distributed inference and also safe from arbitrary code execution. See the [documentation](https://huggingface.co/docs/safetensors/en/index) for more information on the safetensors format. This part of the logic can be found [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/model_executor/model_loader/loader.py#L385). Please note that:


@ -77,7 +77,7 @@ The `parse_request` method is used for validating the user prompt and converting
The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference. The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output. The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters. The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters.
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py). The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/pooling/pooling/serving.py).
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/pooling/prithvi_geospatial_mae.py](../../examples/online_serving/pooling/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py)) inference examples. An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/pooling/prithvi_geospatial_mae.py](../../examples/online_serving/pooling/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py)) inference examples.
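To make the method roles above concrete, here is a minimal, illustrative skeleton. The class name and signatures are assumptions (a real plugin subclasses vLLM's IO processor base class and works with the concrete request/response types), so treat it as a sketch rather than the actual API; see the linked terratorch plugin for a complete implementation.

```python
from typing import Any


class ExampleIOProcessorPlugin:
    """Skeleton only; a real plugin subclasses vLLM's IO processor base class."""

    def parse_request(self, request: Any) -> Any:
        # Validate the user prompt and convert it into the plugin's input type.
        return request

    def pre_process(self, plugin_input: Any) -> list[Any]:
        # Build regular vLLM model prompts from the validated plugin input.
        return [plugin_input]

    def post_process(self, model_outputs: list[Any]) -> Any:
        # Consume PoolingRequestOutput objects and produce the plugin output.
        return model_outputs

    def validate_or_generate_params(self, params: Any = None) -> Any:
        # Validate caller-supplied Sampling/PoolingParameters, or create defaults.
        return params if params is not None else {}

    def output_to_response(self, plugin_output: Any) -> Any:
        # Online serving only: wrap the plugin output as an IOProcessorResponse.
        return {"data": plugin_output}
```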

View File

@ -263,6 +263,29 @@ record:
- End-to-end latency - the interval between frontend `arrival_time` - End-to-end latency - the interval between frontend `arrival_time`
and the frontend receiving the final token. and the frontend receiving the final token.
### KV Cache Residency Metrics
We also emit a set of histograms that describe how long sampled KV cache
blocks stay resident and how often they are reused. Sampling
(`--kv-cache-metrics-sample`) keeps the overhead tiny; when a block is
chosen we record:
- `lifetime` allocation ⟶ eviction
- `idle before eviction` last touch ⟶ eviction
- `reuse gaps` the pauses between touches when the block gets reused
Those map directly to the Prometheus metrics:
- `vllm:kv_block_lifetime_seconds` how long each sampled block exists.
- `vllm:kv_block_idle_before_evict_seconds` idle tail after the final access.
- `vllm:kv_block_reuse_gap_seconds` time between consecutive touches.
The engine core only ships raw eviction events via `SchedulerStats`; the
frontend drains them, turns them into Prometheus observations, and also
exposes the same data through `LLM.get_metrics()` when logging is on.
Looking at lifetime and idle time on one chart makes it easy to spot
stranded cache or workloads that pin prompts for a long decode.
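As a quick way to inspect these histograms without a Prometheus scrape, the same data can be read in-process. The sketch below assumes the metric names listed above and that the objects returned by `LLM.get_metrics()` expose a `name` field; the exact fields may differ between versions.

```python
from vllm import LLM, SamplingParams

# Keep stats logging enabled so the frontend collects the sampled KV-block events.
llm = LLM(model="Qwen/Qwen3-0.6B", disable_log_stats=False)
llm.generate(["Hello, world"], SamplingParams(max_tokens=16))

wanted = {
    "vllm:kv_block_lifetime_seconds",
    "vllm:kv_block_idle_before_evict_seconds",
    "vllm:kv_block_reuse_gap_seconds",
}
for metric in llm.get_metrics():
    if metric.name in wanted:  # assumed field; histogram metrics carry bucket counts
        print(metric.name, metric)
```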
### Metrics Publishing - Logging ### Metrics Publishing - Logging
The `LoggingStatLogger` metrics publisher outputs a log `INFO` message The `LoggingStatLogger` metrics publisher outputs a log `INFO` message

View File

@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod] - [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod]
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod] - [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
- [`CompressedTensorsW4A4Nvfp4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoeMethod] - [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoEMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod] - [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod]
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod] - [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod] - [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]

View File

@ -86,7 +86,7 @@ Every plugin has three parts:
}, },
... ...
) )
``` ```
Please make sure `vllm_add_dummy_platform:register` is a callable function and returns the platform class's fully qualified name. For example: Please make sure `vllm_add_dummy_platform:register` is a callable function and returns the platform class's fully qualified name. For example:
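(The doc's own example is not shown in this hunk; the sketch below is a hypothetical illustration with made-up module and class names.)

```python
# Hypothetical sketch; the package, module, and class names are illustrative only.
def register() -> str:
    # Return the platform class's fully qualified name so vLLM can import it lazily.
    return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
```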

View File

@ -117,7 +117,7 @@ vllm serve meta-llama/Llama-3.2-1B \
# Alternative: Using dot notation (simpler for single values) # Alternative: Using dot notation (simpler for single values)
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked vllm serve meta-llama/Llama-3.2-1B -cc.dynamic_shapes_config.type=unbacked
``` ```
#### Choosing the Right Mode #### Choosing the Right Mode

View File

@ -216,14 +216,13 @@ You can add a new `ReasoningParser` similar to [vllm/reasoning/deepseek_r1_reaso
# import the required packages # import the required packages
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
DeltaMessage)
# define a reasoning parser and register it to vllm # define a reasoning parser and register it to vllm
# the name list in register_module can be used # the name list in register_module can be used
# in --reasoning-parser. # in --reasoning-parser.
class ExampleParser(ReasoningParser): class ExampleParser(ReasoningParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
def extract_reasoning_streaming( def extract_reasoning_streaming(

View File

@ -371,7 +371,8 @@ Olmo 3 models output tool calls in a format that is very similar to the one expe
Supported models: Supported models:
* TODO (will be updated after Olmo 3 release) * `allenai/Olmo-3-7B-Instruct`
* `allenai/Olmo-3-32B-Think`
Flags: `--tool-call-parser olmo3` Flags: `--tool-call-parser olmo3`
@ -421,7 +422,7 @@ Here is a summary of a plugin file:
# in --tool-call-parser. you can define as many # in --tool-call-parser. you can define as many
# tool parsers as you want here. # tool parsers as you want here.
class ExampleToolParser(ToolParser): class ExampleToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# adjust request. e.g.: set skip special tokens # adjust request. e.g.: set skip special tokens

View File

@ -94,6 +94,9 @@ def auto_mock(module_name: str, attr: str, max_mocks: int = 100):
bench_latency = auto_mock("vllm.benchmarks", "latency") bench_latency = auto_mock("vllm.benchmarks", "latency")
bench_serve = auto_mock("vllm.benchmarks", "serve") bench_serve = auto_mock("vllm.benchmarks", "serve")
bench_sweep_plot = auto_mock("vllm.benchmarks.sweep.plot", "SweepPlotArgs") bench_sweep_plot = auto_mock("vllm.benchmarks.sweep.plot", "SweepPlotArgs")
bench_sweep_plot_pareto = auto_mock(
"vllm.benchmarks.sweep.plot_pareto", "SweepPlotParetoArgs"
)
bench_sweep_serve = auto_mock("vllm.benchmarks.sweep.serve", "SweepServeArgs") bench_sweep_serve = auto_mock("vllm.benchmarks.sweep.serve", "SweepServeArgs")
bench_sweep_serve_sla = auto_mock( bench_sweep_serve_sla = auto_mock(
"vllm.benchmarks.sweep.serve_sla", "SweepServeSLAArgs" "vllm.benchmarks.sweep.serve_sla", "SweepServeSLAArgs"
@ -221,6 +224,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
"bench_latency": create_parser(bench_latency.add_cli_args), "bench_latency": create_parser(bench_latency.add_cli_args),
"bench_serve": create_parser(bench_serve.add_cli_args), "bench_serve": create_parser(bench_serve.add_cli_args),
"bench_sweep_plot": create_parser(bench_sweep_plot.add_cli_args), "bench_sweep_plot": create_parser(bench_sweep_plot.add_cli_args),
"bench_sweep_plot_pareto": create_parser(bench_sweep_plot_pareto.add_cli_args),
"bench_sweep_serve": create_parser(bench_sweep_serve.add_cli_args), "bench_sweep_serve": create_parser(bench_sweep_serve.add_cli_args),
"bench_sweep_serve_sla": create_parser(bench_sweep_serve_sla.add_cli_args), "bench_sweep_serve_sla": create_parser(bench_sweep_serve_sla.add_cli_args),
"bench_throughput": create_parser(bench_throughput.add_cli_args), "bench_throughput": create_parser(bench_throughput.add_cli_args),

View File

@ -479,6 +479,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|-------------------|----------------------|---------------------------| |--------------|--------|-------------------|----------------------|---------------------------|
| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | | `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | |
| `BertSpladeSparseEmbeddingModel` | SPLADE | `naver/splade-v3` | | |
| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | | `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ |
| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ | | `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ |
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
@ -725,6 +726,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | | `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ |
| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | | `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ |
| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | | `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ |
| `UltravoxModel` | Ultravox | T + A<sup>E+</sup> | `fixie-ai/ultravox-v0_5-llama-3_2-1b` | ✅︎ | ✅︎ |
Some models are supported only via the [Transformers modeling backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers modeling backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it! Some models are supported only via the [Transformers modeling backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers modeling backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!

View File

@ -351,7 +351,7 @@ The following extra parameters are supported by default:
??? code ??? code
```python ```python
--8<-- "vllm/entrypoints/openai/protocol.py:embedding-extra-params" --8<-- "vllm/entrypoints/pooling/embed/protocol.py:embedding-extra-params"
``` ```
For chat-like input (i.e. if `messages` is passed), these extra parameters are supported instead: For chat-like input (i.e. if `messages` is passed), these extra parameters are supported instead:
@ -359,7 +359,7 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
??? code ??? code
```python ```python
--8<-- "vllm/entrypoints/openai/protocol.py:chat-embedding-extra-params" --8<-- "vllm/entrypoints/pooling/embed/protocol.py:chat-embedding-extra-params"
``` ```
### Transcriptions API ### Transcriptions API
@ -456,6 +456,7 @@ For `verbose_json` response format:
] ]
} }
``` ```
Currently, the `verbose_json` response format doesn't support `avg_logprob`, `compression_ratio`, or `no_speech_prob`.
#### Extra Parameters #### Extra Parameters
@ -629,7 +630,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.
The following extra parameters are supported: The following extra parameters are supported:
```python ```python
--8<-- "vllm/entrypoints/openai/protocol.py:classification-extra-params" --8<-- "vllm/entrypoints/pooling/classify/protocol.py:classification-extra-params"
``` ```
### Score API ### Score API
@ -834,7 +835,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.
The following extra parameters are supported: The following extra parameters are supported:
```python ```python
--8<-- "vllm/entrypoints/openai/protocol.py:score-extra-params" --8<-- "vllm/entrypoints/pooling/score/protocol.py:score-extra-params"
``` ```
### Re-rank API ### Re-rank API
@ -915,7 +916,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.
The following extra parameters are supported: The following extra parameters are supported:
```python ```python
--8<-- "vllm/entrypoints/openai/protocol.py:rerank-extra-params" --8<-- "vllm/entrypoints/pooling/score/protocol.py:rerank-extra-params"
``` ```
## Ray Serve LLM ## Ray Serve LLM

View File

@ -46,7 +46,6 @@ def create_test_prompts(
logprobs=1, logprobs=1,
prompt_logprobs=1, prompt_logprobs=1,
max_tokens=128, max_tokens=128,
stop_token_ids=[32003],
), ),
LoRARequest("sql-lora", 1, lora_path), LoRARequest("sql-lora", 1, lora_path),
), ),
@ -57,7 +56,6 @@ def create_test_prompts(
logprobs=1, logprobs=1,
prompt_logprobs=1, prompt_logprobs=1,
max_tokens=128, max_tokens=128,
stop_token_ids=[32003],
), ),
LoRARequest("sql-lora2", 2, lora_path), LoRARequest("sql-lora2", 2, lora_path),
), ),
@ -98,7 +96,7 @@ def initialize_engine() -> LLMEngine:
# use the same rank, it is recommended to set this as low as possible. # use the same rank, it is recommended to set this as low as possible.
# max_cpu_loras: controls the size of the CPU LoRA cache. # max_cpu_loras: controls the size of the CPU LoRA cache.
engine_args = EngineArgs( engine_args = EngineArgs(
model="meta-llama/Llama-2-7b-hf", model="meta-llama/Llama-3.2-3B-Instruct",
enable_lora=True, enable_lora=True,
max_loras=1, max_loras=1,
max_lora_rank=8, max_lora_rank=8,
@ -111,7 +109,7 @@ def initialize_engine() -> LLMEngine:
def main(): def main():
"""Main function that sets up and runs the prompt processing.""" """Main function that sets up and runs the prompt processing."""
engine = initialize_engine() engine = initialize_engine()
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") lora_path = snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")
test_prompts = create_test_prompts(lora_path) test_prompts = create_test_prompts(lora_path)
process_requests(engine, test_prompts) process_requests(engine, test_prompts)

View File

@ -0,0 +1,184 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example demonstrating MCP (Model Context Protocol) tools with the Responses API.
This example shows how to use MCP tools with different allowed_tools configurations:
1. No filter (allows all tools from the MCP server)
2. Wildcard "*" (explicitly allows all tools)
3. Specific tool names (filters to only those tools)
Set up this example by starting a vLLM OpenAI-compatible server with MCP tools enabled.
For example:
vllm serve openai/gpt-oss-20b --enforce-eager --tool-server demo
Environment variables:
- VLLM_ENABLE_RESPONSES_API_STORE=1
- VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS=code_interpreter,container
- VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS=1
"""
from openai import OpenAI
from utils import get_first_model
def example_no_filter():
"""Example with no allowed_tools filter - allows all tools."""
print("=" * 60)
print("Example 1: No allowed_tools filter (allows all tools)")
print("=" * 60)
base_url = "http://0.0.0.0:8000/v1"
client = OpenAI(base_url=base_url, api_key="empty")
model = get_first_model(client)
response = client.responses.create(
model=model,
input="Execute this code: print('Hello from Python!')",
instructions="Use the Python tool to execute code.",
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
"server_url": "http://localhost:8888",
# No allowed_tools specified - all tools are available
}
],
)
print(f"Status: {response.status}")
print(f"Output: {response.output_text}")
print()
def example_wildcard():
"""Example with allowed_tools=['*'] - explicitly allows all tools."""
print("=" * 60)
print("Example 2: allowed_tools=['*'] (select all tools)")
print("=" * 60)
base_url = "http://0.0.0.0:8000/v1"
client = OpenAI(base_url=base_url, api_key="empty")
model = get_first_model(client)
response = client.responses.create(
model=model,
input="Execute this code: print('Hello from Python with wildcard!')",
instructions="Use the Python tool to execute code.",
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
"server_url": "http://localhost:8888",
# Using "*" to explicitly allow all tools from this MCP server
# This is equivalent to not specifying allowed_tools
"allowed_tools": ["*"],
}
],
)
print(f"Status: {response.status}")
print(f"Output: {response.output_text}")
print()
def example_specific_tools():
"""Example with specific allowed_tools list - filters available tools.
Note: This example uses 'web_search_preview' (browser) which has multiple
sub-tools: 'search', 'open', 'find'. The code_interpreter (python) doesn't
have sub-tools, so filtering doesn't apply there.
"""
print("=" * 60)
print("Example 3: allowed_tools=['search'] (filter browser to specific tools)")
print("=" * 60)
base_url = "http://0.0.0.0:8000/v1"
client = OpenAI(base_url=base_url, api_key="empty")
model = get_first_model(client)
response = client.responses.create(
model=model,
input="Search for 'Python programming tutorials'",
instructions="Use the browser tool to search.",
tools=[
{
"type": "mcp",
"server_label": "web_search_preview",
"server_url": "http://localhost:8888",
# Browser has tools: 'search', 'open', 'find'
# Only allow 'search' - blocks 'open' and 'find'
"allowed_tools": ["search"],
}
],
)
print(f"Status: {response.status}")
print(f"Output: {response.output_text}")
print()
def example_object_format():
"""Example using object format for allowed_tools with browser tools."""
print("=" * 60)
print("Example 4: allowed_tools with object format")
print("=" * 60)
base_url = "http://0.0.0.0:8000/v1"
client = OpenAI(base_url=base_url, api_key="empty")
model = get_first_model(client)
response = client.responses.create(
model=model,
input="Search for 'machine learning' and open the first result",
instructions="Use the browser tool.",
tools=[
{
"type": "mcp",
"server_label": "web_search_preview",
"server_url": "http://localhost:8888",
# Object format with tool_names field
# Can also include read_only and other fields
# Browser has tools: 'search', 'open', 'find'
"allowed_tools": {
"tool_names": [
"search",
"open",
], # Allow search and open, block find
"read_only": False,
},
}
],
)
print(f"Status: {response.status}")
print(f"Output: {response.output_text}")
print()
def main():
"""Run all examples."""
print("\n" + "=" * 60)
print("MCP Tools with allowed_tools Examples")
print("=" * 60 + "\n")
# Run all examples
example_no_filter()
example_wildcard()
example_specific_tools()
example_object_format()
print("=" * 60)
print("Summary:")
print(" - No filter or '*' → All tools available from server")
print(" - Specific list → Only those sub-tools available")
print(" - Object format → More control with tool_names field")
print("")
print("Note: allowed_tools filters SUB-TOOLS within an MCP server:")
print(" - code_interpreter (python): No sub-tools to filter")
print(" - web_search_preview (browser): Has 'search', 'open', 'find'")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@ -49,4 +49,4 @@ cbor2 # Required for cross-language serialization of hashable objects
setproctitle # Used to set process names for better debugging and monitoring setproctitle # Used to set process names for better debugging and monitoring
openai-harmony >= 0.0.3 # Required for gpt-oss openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic == 0.71.0 anthropic == 0.71.0
model-hosting-container-standards < 1.0.0 model-hosting-container-standards >= 0.1.9, < 1.0.0

View File

@ -4,9 +4,8 @@ packaging>=24.2
setuptools>=77.0.3,<81.0.0 setuptools>=77.0.3,<81.0.0
setuptools-scm>=8 setuptools-scm>=8
--extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://download.pytorch.org/whl/cpu
torch==2.9.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x" torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
torch==2.9.0; platform_system == "Darwin" torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64"
torch==2.9.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
scons; platform_machine == "aarch64" # needed to build Arm Compute Library (ACL) scons; platform_machine == "aarch64" # needed to build Arm Compute Library (ACL)
wheel wheel
jinja2>=3.1.6 jinja2>=3.1.6

View File

@ -4,25 +4,18 @@
numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding
# Dependencies for CPUs # Dependencies for CPUs
packaging>=24.2
setuptools>=77.0.3,<81.0.0
--extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://download.pytorch.org/whl/cpu
torch==2.9.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x" torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
torch==2.9.0; platform_system == "Darwin" torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64"
torch==2.9.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
# required for the image processor of minicpm-o-2_6, this must be updated alongside torch # required for the image processor of minicpm-o-2_6, this must be updated alongside torch
torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x" torchaudio; platform_machine != "s390x"
torchaudio==2.9.0; platform_machine == "ppc64le"
# required for the image processor of phi3v, this must be updated alongside torch # required for the image processor of phi3v, this must be updated alongside torch
torchvision; platform_machine != "ppc64le" and platform_machine != "s390x" torchvision; platform_machine != "s390x"
torchvision==0.24.0; platform_machine == "ppc64le"
datasets # for benchmark scripts
# Intel Extension for PyTorch, only for x86_64 CPUs # Intel Extension for PyTorch, only for x86_64 CPUs
intel-openmp==2024.2.1; platform_machine == "x86_64" intel-openmp==2024.2.1; platform_machine == "x86_64"
triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile.
# Use this to gather CPU info and optimize based on ARM Neoverse cores # Use this to gather CPU info and optimize based on ARM Neoverse cores
py-cpuinfo; platform_machine == "aarch64" py-cpuinfo; platform_machine == "aarch64"

View File

@ -1,51 +1,85 @@
# Common dependencies # Common dependencies
-r common.txt -r common.txt
tblib==3.1.0
bm25s==0.2.13
pystemmer==3.0.0
# Entrypoints test # Test infrastructure
# librosa==0.10.2.post1 # required by audio tests in entrypoints/openai tblib==3.1.0
pytest==8.3.5
pytest-asyncio==0.24.0
pytest-timeout==2.3.1
pytest-cov==6.3.0
pytest-forked==1.6.0
pytest-rerunfailures==14.0
pytest-shard==0.1.2
# Async/HTTP dependencies
anyio==4.6.2.post1
# via httpx, starlette
aiohttp==3.13.0
# via gpt-oss
httpx==0.27.2
# HTTP testing
# Audio processing dependencies
audioread==3.0.1 audioread==3.0.1
# via librosa
cffi==1.17.1 cffi==1.17.1
# via soundfile
decorator==5.2.1 decorator==5.2.1
# via librosa
lazy-loader==0.4 lazy-loader==0.4
# via librosa
platformdirs==4.3.6 platformdirs==4.3.6
# via pooch
pooch==1.8.2 pooch==1.8.2
#pycparse==2.22 # via librosa
soundfile==0.13.1 soundfile==0.13.1
# via librosa
soxr==0.5.0.post1 soxr==0.5.0.post1
# via librosa
librosa==0.10.2.post1 librosa==0.10.2.post1
# Entrypoints test # Retrieval and search
#vllm[video] # required by entrypoints/openai/test_video.py bm25s==0.2.13
decord==0.6.0 # via mteb
pystemmer==3.0.0
# via mteb
# Entrypoints test # Multi-modal processing
#sentence-transformers # required by entrypoints/openai/test_score.py
sentence-transformers==3.4.1
# Basic Models Test
matplotlib==3.10.3
# Multi-Modal Models Test (Extended) 3
blobfile==3.0.0 blobfile==3.0.0
# Multi-Modal Models Test
decord==0.6.0
# video processing, required by entrypoints/openai/test_video.py
# Required for openai schema test. # OpenAI compatibility and testing
gpt-oss==0.0.8
# OpenAI compatibility tests
schemathesis==3.39.15 schemathesis==3.39.15
# OpenAI schema test
# Required for mteb test # Evaluation and benchmarking
mteb[bm25s]>=1.38.11, <2
# Required for eval tests
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d
# Required for multiprocessed tests that use spawn method # Required for multiprocessed tests that use spawn method, Datasets and Evaluate Test
multiprocess==0.70.16 multiprocess==0.70.16
# Plugins test # Plugins test
terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
torchgeo==0.7.0 torchgeo==0.7.0
# via terratorch
# MTEB Benchmark Test
mteb==2.1.2
# Data processing
xgrammar @ git+https://github.com/mlc-ai/xgrammar.git@eafd4db51b78acc64b3f0764ef27dfd206c28628
# Test async scheduling
# Utilities
num2words==0.5.14
# via lm-eval
pqdm==0.2.0
# via lm-eval
# Required for suffix decoding test # Required for suffix decoding test
arctic-inference == 0.1.1 arctic-inference == 0.1.1
# Required for Nemotron test
open-clip-torch==2.32.0

View File

@ -10,6 +10,7 @@ import re
import shutil import shutil
import subprocess import subprocess
import sys import sys
import sysconfig
from pathlib import Path from pathlib import Path
from shutil import which from shutil import which
@ -74,9 +75,13 @@ def is_ninja_available() -> bool:
return which("ninja") is not None return which("ninja") is not None
def is_freethreaded():
return bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
class CMakeExtension(Extension): class CMakeExtension(Extension):
def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None: def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
super().__init__(name, sources=[], py_limited_api=True, **kwa) super().__init__(name, sources=[], py_limited_api=not is_freethreaded(), **kwa)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)

View File

@ -11,7 +11,7 @@ from vllm.device_allocator.cumem import CuMemAllocator
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from ..utils import create_new_process_for_each_test from ..utils import create_new_process_for_each_test, requires_fp8
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn") @create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
@ -243,3 +243,34 @@ def test_deep_sleep_async():
assert output.outputs[0].text == output2.outputs[0].text assert output.outputs[0].text == output2.outputs[0].text
asyncio.run(test()) asyncio.run(test())
@requires_fp8
def test_deep_sleep_fp8_kvcache():
GiB_bytes = 1 << 30
model = "Qwen/Qwen2-0.5B"
used_bytes_baseline = current_platform.get_current_memory_usage()
llm = LLM(model, enable_sleep_mode=True, kv_cache_dtype="fp8")
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
output = llm.generate(prompt, sampling_params)
# Put the engine to deep sleep
llm.sleep(level=2)
used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
assert used_bytes < 3 * GiB_bytes
llm.wake_up(tags=["weights"])
llm.collective_rpc("reload_weights")
used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
assert used_bytes < 4 * GiB_bytes
# now allocate kv cache and cuda graph memory
llm.wake_up(tags=["kv_cache"])
output2 = llm.generate(prompt, sampling_params)
# cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text

View File

@ -115,7 +115,7 @@ def test_compile_correctness(
str(pp_size), str(pp_size),
"-tp", "-tp",
str(tp_size), str(tp_size),
"-O.cudagraph_mode=none", "-cc.cudagraph_mode=none",
] ]
all_args: list[list[str]] = [] all_args: list[list[str]] = []
@ -128,7 +128,7 @@ def test_compile_correctness(
]: ]:
for mode in [CompilationMode.NONE, comp_mode]: for mode in [CompilationMode.NONE, comp_mode]:
all_args.append( all_args.append(
final_args + [f"-O.mode={mode.name}", "-O.backend=inductor"] final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
) )
# inductor will change the output, so we only compare if the output # inductor will change the output, so we only compare if the output
@ -148,7 +148,7 @@ def test_compile_correctness(
CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE, CompilationMode.VLLM_COMPILE,
]: ]:
all_args.append(final_args + [f"-O.mode={mode.name}", "-O.backend=eager"]) all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
all_envs.append({}) all_envs.append({})
all_envs.append({}) all_envs.append({})

View File

@ -459,14 +459,17 @@ class HfRunner:
embeddings.append(embedding) embeddings.append(embedding)
return embeddings return embeddings
def classify(self, prompts: list[str]) -> list[str]: def classify(self, prompts: list[str]) -> list[list[float]]:
# output is final logits # output is final logits
all_inputs = self.get_inputs(prompts) all_inputs = self.get_inputs(prompts)
outputs = [] outputs: list[list[float]] = []
problem_type = getattr(self.config, "problem_type", "") problem_type = getattr(self.config, "problem_type", "")
for inputs in all_inputs: for inputs in all_inputs:
output = self.model(**self.wrap_device(inputs)) output = self.model(**self.wrap_device(inputs))
assert isinstance(output.logits, torch.Tensor)
if problem_type == "regression": if problem_type == "regression":
logits = output.logits[0].tolist() logits = output.logits[0].tolist()
elif problem_type == "multi_label_classification": elif problem_type == "multi_label_classification":
@ -1424,3 +1427,32 @@ def disable_deepgemm_ue8m0(monkeypatch):
# Clear cache so the next time it is used it is processed with the # Clear cache so the next time it is used it is processed with the
# default VLLM_USE_DEEP_GEMM_E8M0 setting. # default VLLM_USE_DEEP_GEMM_E8M0 setting.
is_deep_gemm_e8m0_used.cache_clear() is_deep_gemm_e8m0_used.cache_clear()
@pytest.fixture(autouse=True)
def clean_gpu_memory_between_tests():
if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1":
yield
return
# Wait for GPU memory to be cleared before starting the test
import gc
from tests.utils import wait_for_gpu_memory_to_clear
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
try:
wait_for_gpu_memory_to_clear(
devices=list(range(num_gpus)),
threshold_ratio=0.1,
)
except ValueError as e:
logger.info("Failed to clean GPU memory: %s", e)
yield
# Clean up GPU memory after the test
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()

View File

@ -248,15 +248,15 @@ def test_optimization_level(args, expected):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("args", "expected"), ("args", "expected"),
[ [
(["-O.mode=0"], 0), (["-cc.mode=0"], 0),
(["-O.mode=1"], 1), (["-cc.mode=1"], 1),
(["-O.mode=2"], 2), (["-cc.mode=2"], 2),
(["-O.mode=3"], 3), (["-cc.mode=3"], 3),
], ],
) )
def test_mode_parser(args, expected): def test_mode_parser(args, expected):
""" """
Test compilation config modes (-O.mode=int) map to compilation_config. Test compilation config modes (-cc.mode=int) map to compilation_config.
""" """
parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
parsed_args = parser.parse_args(args) parsed_args = parser.parse_args(args)
@ -273,7 +273,7 @@ def test_compilation_config():
# set to string form of a dict # set to string form of a dict
args = parser.parse_args( args = parser.parse_args(
[ [
"-O", "-cc",
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], "backend": "eager"}', '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], "backend": "eager"}',
] ]
) )

View File

@ -188,11 +188,11 @@ number: "1" | "2"
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def zephyr_lora_files(): def qwen3_lora_files():
"""Download zephyr LoRA files once per test session.""" """Download Qwen3 LoRA files once per test session."""
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora") return snapshot_download(repo_id="charent/self_cognition_Alice")
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -23,6 +23,16 @@ MODEL_CONFIGS = [
"max_num_seqs": 64, "max_num_seqs": 64,
"tensor_parallel_size": 1, "tensor_parallel_size": 1,
}, },
{
"model": "Qwen/Qwen3-0.6B",
"enforce_eager": True,
"gpu_memory_utilization": 0.50,
"max_model_len": 64,
"max_num_batched_tokens": 64,
"max_num_seqs": 64,
"tensor_parallel_size": 1,
"tokenizer": "Qwen/Qwen3-4B",
},
{ {
"model": "mistralai/Mistral-7B-Instruct-v0.1", "model": "mistralai/Mistral-7B-Instruct-v0.1",
"enforce_eager": True, "enforce_eager": True,

View File

@ -16,7 +16,7 @@ from vllm.version import __version__ as VLLM_VERSION
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -19,6 +19,14 @@ from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def zephyr_lora_files():
"""Download zephyr LoRA files once per test session."""
from huggingface_hub import snapshot_download
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(zephyr_lora_files): # noqa: F811 def server(zephyr_lora_files): # noqa: F811
args = [ args = [

View File

@ -8,7 +8,7 @@ import pytest_asyncio
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -20,7 +20,6 @@ def server():
"--max-model-len", "--max-model-len",
"8192", "8192",
"--enforce-eager", "--enforce-eager",
# lora config below
"--max-num-seqs", "--max-num-seqs",
"128", "128",
"--enable-chunked-prefill", "--enable-chunked-prefill",

View File

@ -13,9 +13,8 @@ import pytest_asyncio
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "Qwen/Qwen3-0.6B"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
BADREQUEST_CASES = [ BADREQUEST_CASES = [
( (
@ -33,11 +32,11 @@ BADREQUEST_CASES = [
@pytest.fixture(scope="module", params=[True]) @pytest.fixture(scope="module", params=[True])
def server_with_lora_modules_json(request, zephyr_lora_files): def server_with_lora_modules_json(request, qwen3_lora_files):
# Define the json format LoRA module configurations # Define the json format LoRA module configurations
lora_module_1 = { lora_module_1 = {
"name": "zephyr-lora", "name": "qwen3-lora",
"path": zephyr_lora_files, "path": qwen3_lora_files,
"base_model_name": MODEL_NAME, "base_model_name": MODEL_NAME,
} }
@ -74,7 +73,7 @@ async def client(server_with_lora_modules_json):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_static_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files): async def test_static_lora_lineage(client: openai.AsyncOpenAI, qwen3_lora_files):
models = await client.models.list() models = await client.models.list()
models = models.data models = models.data
served_model = models[0] served_model = models[0]
@ -82,17 +81,17 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files
assert served_model.id == MODEL_NAME assert served_model.id == MODEL_NAME
assert served_model.root == MODEL_NAME assert served_model.root == MODEL_NAME
assert served_model.parent is None assert served_model.parent is None
assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models) assert all(lora_model.root == qwen3_lora_files for lora_model in lora_models)
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora" assert lora_models[0].id == "qwen3-lora"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files): async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, qwen3_lora_files):
response = await client.post( response = await client.post(
"load_lora_adapter", "load_lora_adapter",
cast_to=str, cast_to=str,
body={"lora_name": "zephyr-lora-3", "lora_path": zephyr_lora_files}, body={"lora_name": "qwen3-lora-3", "lora_path": qwen3_lora_files},
) )
# Ensure adapter loads before querying /models # Ensure adapter loads before querying /models
assert "success" in response assert "success" in response
@ -100,9 +99,9 @@ async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_file
models = await client.models.list() models = await client.models.list()
models = models.data models = models.data
dynamic_lora_model = models[-1] dynamic_lora_model = models[-1]
assert dynamic_lora_model.root == zephyr_lora_files assert dynamic_lora_model.root == qwen3_lora_files
assert dynamic_lora_model.parent == MODEL_NAME assert dynamic_lora_model.parent == MODEL_NAME
assert dynamic_lora_model.id == "zephyr-lora-3" assert dynamic_lora_model.id == "qwen3-lora-3"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -134,7 +133,7 @@ async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, tmp_path):
async def test_dynamic_lora_badrequests( async def test_dynamic_lora_badrequests(
client: openai.AsyncOpenAI, client: openai.AsyncOpenAI,
tmp_path, tmp_path,
zephyr_lora_files, qwen3_lora_files,
test_name: str, test_name: str,
config_change: dict, config_change: dict,
expected_error: str, expected_error: str,
@ -143,7 +142,7 @@ async def test_dynamic_lora_badrequests(
test_dir = tmp_path / test_name test_dir = tmp_path / test_name
# Copy adapter files # Copy adapter files
shutil.copytree(zephyr_lora_files, test_dir) shutil.copytree(qwen3_lora_files, test_dir)
# Load and modify configuration # Load and modify configuration
config_path = test_dir / "adapter_config.json" config_path = test_dir / "adapter_config.json"
@ -167,7 +166,7 @@ async def test_dynamic_lora_badrequests(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_lora_adapters( async def test_multiple_lora_adapters(
client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files client: openai.AsyncOpenAI, tmp_path, qwen3_lora_files
): ):
"""Validate that many loras can be dynamically registered and inferenced """Validate that many loras can be dynamically registered and inferenced
with concurrently""" with concurrently"""
@ -178,7 +177,7 @@ async def test_multiple_lora_adapters(
await client.post( await client.post(
"load_lora_adapter", "load_lora_adapter",
cast_to=str, cast_to=str,
body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)}, body={"lora_name": adapter_name, "lora_path": str(qwen3_lora_files)},
) )
for _ in range(3): for _ in range(3):
await client.completions.create( await client.completions.create(
@ -199,7 +198,7 @@ async def test_multiple_lora_adapters(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_loading_invalid_adapters_does_not_break_others( async def test_loading_invalid_adapters_does_not_break_others(
client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files client: openai.AsyncOpenAI, tmp_path, qwen3_lora_files
): ):
invalid_files = tmp_path / "invalid_files" invalid_files = tmp_path / "invalid_files"
invalid_files.mkdir() invalid_files.mkdir()
@ -215,7 +214,7 @@ async def test_loading_invalid_adapters_does_not_break_others(
while not stop_good_requests_event.is_set(): while not stop_good_requests_event.is_set():
try: try:
batch = await client.completions.create( batch = await client.completions.create(
model="zephyr-lora", model="qwen3-lora",
prompt=["Hello there", "Foo bar bazz buzz"], prompt=["Hello there", "Foo bar bazz buzz"],
max_tokens=5, max_tokens=5,
) )
@ -254,7 +253,7 @@ async def test_loading_invalid_adapters_does_not_break_others(
await client.post( await client.post(
"load_lora_adapter", "load_lora_adapter",
cast_to=str, cast_to=str,
body={"lora_name": "valid", "lora_path": zephyr_lora_files}, body={"lora_name": "valid", "lora_path": qwen3_lora_files},
) )
await client.completions.create( await client.completions.create(
model="valid", model="valid",
@ -267,7 +266,7 @@ async def test_loading_invalid_adapters_does_not_break_others(
async def test_beam_search_with_lora_adapters( async def test_beam_search_with_lora_adapters(
client: openai.AsyncOpenAI, client: openai.AsyncOpenAI,
tmp_path, tmp_path,
zephyr_lora_files, qwen3_lora_files,
): ):
"""Validate that async beam search can be used with lora.""" """Validate that async beam search can be used with lora."""
@ -275,7 +274,7 @@ async def test_beam_search_with_lora_adapters(
await client.post( await client.post(
"load_lora_adapter", "load_lora_adapter",
cast_to=str, cast_to=str,
body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)}, body={"lora_name": adapter_name, "lora_path": str(qwen3_lora_files)},
) )
for _ in range(3): for _ in range(3):
await client.completions.create( await client.completions.create(

View File

@ -114,7 +114,7 @@ def mock_serving_setup():
mock_engine.add_lora.reset_mock() mock_engine.add_lora.reset_mock()
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
models = OpenAIServingModels( models = OpenAIServingModels(

View File

@ -8,13 +8,13 @@ import pytest_asyncio
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "Qwen/Qwen3-0.6B"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing # technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here # generation quality here
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(zephyr_lora_files): def server(qwen3_lora_files):
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
@ -25,7 +25,7 @@ def server(zephyr_lora_files):
# lora config below # lora config below
"--enable-lora", "--enable-lora",
"--lora-modules", "--lora-modules",
f"zephyr-lora={zephyr_lora_files}", f"qwen3-lora={qwen3_lora_files}",
"--max-lora-rank", "--max-lora-rank",
"64", "64",
"--max-cpu-loras", "--max-cpu-loras",
@ -45,12 +45,12 @@ async def client(server):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files): async def test_check_models(client: openai.AsyncOpenAI, qwen3_lora_files):
models = await client.models.list() models = await client.models.list()
models = models.data models = models.data
served_model = models[0] served_model = models[0]
lora_models = models[1:] lora_models = models[1:]
assert served_model.id == MODEL_NAME assert served_model.id == MODEL_NAME
assert served_model.root == MODEL_NAME assert served_model.root == MODEL_NAME
assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models) assert all(lora_model.root == qwen3_lora_files for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora" assert lora_models[0].id == "qwen3-lora"

View File

@ -8,7 +8,7 @@ import pytest_asyncio
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -110,8 +110,9 @@ async def test_single_completion(client: openai.AsyncOpenAI):
choice = completion.choices[0] choice = completion.choices[0]
assert len(choice.text) >= 5 assert len(choice.text) >= 5
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
# When using Qwen3-0.6B, prompt tokens=[9707, 11, 847, 829, 374]
assert completion.usage == openai.types.CompletionUsage( assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11 completion_tokens=5, prompt_tokens=5, total_tokens=10
) )
# test using token IDs # test using token IDs

View File

@ -4,6 +4,9 @@
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from openai import OpenAI from openai import OpenAI
from openai_harmony import ToolDescription, ToolNamespaceConfig
from vllm.entrypoints.tool_server import MCPToolServer
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
@ -111,6 +114,48 @@ async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name:
) )
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_with_allowed_tools_star(
mcp_enabled_client: OpenAI, model_name: str
):
"""Test MCP tool with allowed_tools=['*'] to select all available tools.
This E2E test verifies that the "*" wildcard works end-to-end.
See test_serving_responses.py for detailed unit tests of "*" normalization.
"""
response = await mcp_enabled_client.responses.create(
model=model_name,
input=(
"Execute the following code: "
"import random; print(random.randint(1, 1000000))"
),
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
"server_url": "http://localhost:8888",
# Using "*" to allow all tools from this MCP server
"allowed_tools": ["*"],
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify tool calls work with allowed_tools=["*"]
tool_call_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
break
assert tool_call_found, "Should have found at least one Python tool call with '*'"
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_name: str): async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_name: str):
@ -159,3 +204,58 @@ async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_nam
assert message.get("author").get("role") != "developer", ( assert message.get("author").get("role") != "developer", (
"No developer messages should be present without a valid tool" "No developer messages should be present without a valid tool"
) )
def test_get_tool_description():
"""Test MCPToolServer.get_tool_description filtering logic.
Note: The wildcard "*" is normalized to None by
_extract_allowed_tools_from_mcp_requests before reaching this layer,
so we only test None and specific tool filtering here.
See test_serving_responses.py for "*" normalization tests.
"""
pytest.importorskip("mcp")
server = MCPToolServer()
tool1 = ToolDescription.new(
name="tool1", description="First", parameters={"type": "object"}
)
tool2 = ToolDescription.new(
name="tool2", description="Second", parameters={"type": "object"}
)
tool3 = ToolDescription.new(
name="tool3", description="Third", parameters={"type": "object"}
)
server.harmony_tool_descriptions = {
"test_server": ToolNamespaceConfig(
name="test_server", description="test", tools=[tool1, tool2, tool3]
)
}
# Nonexistent server
assert server.get_tool_description("nonexistent") is None
# None (no filter) - returns all tools
result = server.get_tool_description("test_server", allowed_tools=None)
assert len(result.tools) == 3
# Filter to specific tools
result = server.get_tool_description(
"test_server", allowed_tools=["tool1", "tool3"]
)
assert len(result.tools) == 2
assert result.tools[0].name == "tool1"
assert result.tools[1].name == "tool3"
# Single tool
result = server.get_tool_description("test_server", allowed_tools=["tool2"])
assert len(result.tools) == 1
assert result.tools[0].name == "tool2"
# No matching tools - returns None
result = server.get_tool_description("test_server", allowed_tools=["nonexistent"])
assert result is None
# Empty list - returns None
assert server.get_tool_description("test_server", allowed_tools=[]) is None

View File

@ -11,11 +11,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files): def default_server_args(qwen3_lora_files):
return [ return [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
@ -28,7 +28,7 @@ def default_server_args(zephyr_lora_files):
# lora config # lora config
"--enable-lora", "--enable-lora",
"--lora-modules", "--lora-modules",
f"zephyr-lora={zephyr_lora_files}", f"qwen3-lora={qwen3_lora_files}",
"--max-lora-rank", "--max-lora-rank",
"64", "64",
"--max-cpu-loras", "--max-cpu-loras",

View File

@ -7,7 +7,7 @@ import tempfile
import pytest import pytest
from vllm.entrypoints.openai.protocol import BatchRequestOutput from vllm.entrypoints.openai.run_batch import BatchRequestOutput
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"

View File

@ -399,7 +399,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
@dataclass @dataclass
class MockEngine: class MockEngine:
model_config: MockModelConfig = field(default_factory=MockModelConfig) model_config: MockModelConfig = field(default_factory=MockModelConfig)
processor: MagicMock = field(default_factory=MagicMock) input_processor: MagicMock = field(default_factory=MagicMock)
     io_processor: MagicMock = field(default_factory=MagicMock)
@@ -429,7 +429,7 @@ async def test_serving_chat_returns_correct_model_name():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
-    mock_engine.processor = MagicMock()
+    mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()

     serving_chat = _build_serving_chat(mock_engine)
@@ -459,7 +459,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
-    mock_engine.processor = MagicMock()
+    mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()

     serving_chat = _build_serving_chat(mock_engine)
@@ -492,7 +492,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
-    mock_engine.processor = MagicMock()
+    mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()

     # Initialize the serving chat
@@ -537,7 +537,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
-    mock_engine.processor = MagicMock()
+    mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()

     # Initialize the serving chat
@@ -583,7 +583,7 @@ async def test_serving_chat_could_load_correct_generation_config():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
-    mock_engine.processor = MagicMock()
+    mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()

     # Initialize the serving chat
@@ -629,7 +629,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
-    mock_engine.processor = MagicMock()
+    mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()

     serving_chat = _build_serving_chat(mock_engine)
@@ -662,7 +662,7 @@ async def test_serving_chat_data_parallel_rank_extraction():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
-    mock_engine.processor = MagicMock()
+    mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()

     # Mock the generate method to return an async generator

View File

@@ -10,7 +10,7 @@ import pytest
 from vllm.config import ModelConfig
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
+from vllm.tokenizers import MistralTokenizer


 @pytest.fixture()
@@ -23,7 +23,7 @@ def serving() -> OpenAIServing:
     model_config.max_model_len = 32768
     models = Mock(spec=OpenAIServingModels)
     models.model_config = model_config
-    models.processor = Mock()
+    models.input_processor = Mock()
     models.io_processor = Mock()

     serving = OpenAIServing(

View File

@@ -30,7 +30,7 @@ async def _async_serving_models_init() -> OpenAIServingModels:
     mock_model_config = MagicMock(spec=ModelConfig)
     mock_model_config.max_model_len = 2048
     mock_engine_client.model_config = mock_model_config
-    mock_engine_client.processor = MagicMock()
+    mock_engine_client.input_processor = MagicMock()
     mock_engine_client.io_processor = MagicMock()

     serving_models = OpenAIServingModels(

View File

@@ -17,6 +17,7 @@ from vllm.entrypoints.context import ConversationContext
 from vllm.entrypoints.openai.protocol import ErrorResponse, ResponsesRequest
 from vllm.entrypoints.openai.serving_responses import (
     OpenAIServingResponses,
+    _extract_allowed_tools_from_mcp_requests,
     extract_tool_types,
 )
 from vllm.entrypoints.tool_server import ToolServer
@@ -127,7 +128,7 @@ class TestInitializeToolSessions:
         model_config.get_diff_sampling_param.return_value = {}

         engine_client.model_config = model_config
-        engine_client.processor = MagicMock()
+        engine_client.input_processor = MagicMock()
         engine_client.io_processor = MagicMock()

         models = MagicMock()
@@ -213,7 +214,7 @@ class TestValidateGeneratorInput:
         model_config.get_diff_sampling_param.return_value = {}

         engine_client.model_config = model_config
-        engine_client.processor = MagicMock()
+        engine_client.input_processor = MagicMock()
         engine_client.io_processor = MagicMock()

         models = MagicMock()
@@ -254,3 +255,98 @@ class TestValidateGeneratorInput:
         # Should return an ErrorResponse
         assert result is not None
         assert isinstance(result, ErrorResponse)
+
+
+class TestExtractAllowedToolsFromMcpRequests:
+    """Test class for _extract_allowed_tools_from_mcp_requests function"""
+
+    def test_extract_allowed_tools_basic_formats(self):
+        """Test extraction with list format, object format, and None."""
+        from openai.types.responses.tool import McpAllowedToolsMcpToolFilter
+
+        tools = [
+            # List format
+            Mcp(
+                type="mcp",
+                server_label="server1",
+                allowed_tools=["tool1", "tool2"],
+            ),
+            # Object format
+            Mcp(
+                type="mcp",
+                server_label="server2",
+                allowed_tools=McpAllowedToolsMcpToolFilter(
+                    tool_names=["tool3", "tool4"]
+                ),
+            ),
+            # None (no filter)
+            Mcp(
+                type="mcp",
+                server_label="server3",
+                allowed_tools=None,
+            ),
+        ]
+
+        result = _extract_allowed_tools_from_mcp_requests(tools)
+        assert result == {
+            "server1": ["tool1", "tool2"],
+            "server2": ["tool3", "tool4"],
+            "server3": None,
+        }
+
+    def test_extract_allowed_tools_star_normalization(self):
+        """Test that '*' wildcard is normalized to None (select all tools).
+
+        This is the key test requested by reviewers to explicitly demonstrate
+        that the "*" select-all scenario is handled correctly.
+        """
+        from openai.types.responses.tool import McpAllowedToolsMcpToolFilter
+
+        tools = [
+            # Star in list format
+            Mcp(
+                type="mcp",
+                server_label="server1",
+                allowed_tools=["*"],
+            ),
+            # Star mixed with other tools in list
+            Mcp(
+                type="mcp",
+                server_label="server2",
+                allowed_tools=["tool1", "*"],
+            ),
+            # Star in object format
+            Mcp(
+                type="mcp",
+                server_label="server3",
+                allowed_tools=McpAllowedToolsMcpToolFilter(tool_names=["*"]),
+            ),
+        ]
+
+        result = _extract_allowed_tools_from_mcp_requests(tools)
+        # All should be normalized to None (allows all tools)
+        assert result == {
+            "server1": None,
+            "server2": None,
+            "server3": None,
+        }
+
+    def test_extract_allowed_tools_filters_non_mcp(self):
+        """Test that non-MCP tools are ignored during extraction."""
+        tools = [
+            Mcp(
+                type="mcp",
+                server_label="server1",
+                allowed_tools=["tool1"],
+            ),
+            LocalShell(type="local_shell"),  # Non-MCP tool should be ignored
+            Mcp(
+                type="mcp",
+                server_label="server2",
+                allowed_tools=["tool2"],
+            ),
+        ]
+
+        result = _extract_allowed_tools_from_mcp_requests(tools)
+        # Non-MCP tools should be ignored
+        assert result == {
+            "server1": ["tool1"],
+            "server2": ["tool2"],
+        }
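For reference, the behaviour these tests pin down can be sketched roughly as follows. This is a hypothetical reimplementation based only on the assertions above, not the actual helper in serving_responses.py; in particular, the handling of the object form via its tool_names attribute is an assumption taken from the test fixtures:

from openai.types.responses.tool import Mcp, Tool


def _extract_allowed_tools_sketch(
    tools: list[Tool],
) -> dict[str, list[str] | None]:
    """Map each MCP server label to its allowed tool names (None = all tools)."""
    allowed: dict[str, list[str] | None] = {}
    for tool in tools:
        if not isinstance(tool, Mcp):
            continue  # non-MCP tools carry no per-server filter
        names = tool.allowed_tools
        if names is not None and not isinstance(names, list):
            # Object form (McpAllowedToolsMcpToolFilter) wraps the names
            names = names.tool_names
        if names is None or "*" in names:
            allowed[tool.server_label] = None  # "*" wildcard selects every tool
        else:
            allowed[tool.server_label] = list(names)
    return allowed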

View File

@@ -53,7 +53,7 @@ async def test_tokenize_completions(
     model_name: str,
     tokenizer_name: str,
 ):
-    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name)

     for add_special in [False, True]:
         prompt = "vllm1 This is a test prompt."
@@ -87,7 +87,7 @@ async def test_tokenize_chat(
     model_name: str,
     tokenizer_name: str,
 ):
-    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name)

     for add_generation in [False, True]:
         for add_special in [False, True]:
@@ -140,7 +140,7 @@ async def test_tokenize_chat_with_tools(
     model_name: str,
     tokenizer_name: str,
 ):
-    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name)

     for add_generation in [False, True]:
         for add_special in [False, True]:
@@ -210,7 +210,7 @@ async def test_tokenize_with_return_token_strs(
     model_name: str,
     tokenizer_name: str,
 ):
-    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name)

     prompt = "This is a token_strs test prompt! vllm1"
     response = requests.post(
@@ -240,7 +240,7 @@ async def test_detokenize(
     model_name: str,
     tokenizer_name: str,
 ):
-    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name)

     prompt = "This is a test prompt. vllm1"
     tokens = tokenizer.encode(prompt, add_special_tokens=False)

View File

@@ -235,3 +235,16 @@ async def test_audio_prompt(mary_had_lamb, whisper_client):
     )
     out_prompt = json.loads(transcription_wprompt)["text"]
     assert prefix in out_prompt
+
+
+@pytest.mark.asyncio
+async def test_audio_with_timestamp(mary_had_lamb, whisper_client):
+    transcription = await whisper_client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        response_format="verbose_json",
+        temperature=0.0,
+    )
+    assert transcription.segments is not None
+    assert len(transcription.segments) > 0
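As a rough usage illustration (not part of the test above), the segments returned for response_format="verbose_json" can be consumed like this; the start/end/text fields follow the OpenAI transcription segment schema and are assumed here rather than asserted by the test:

for segment in transcription.segments:
    # Each segment carries approximate timestamps for one chunk of the audio.
    print(f"[{segment.start:6.2f}s - {segment.end:6.2f}s] {segment.text}")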

View File

@@ -10,7 +10,7 @@ from vllm.version import __version__ as VLLM_VERSION
 from ...utils import RemoteOpenAIServer

-MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+MODEL_NAME = "Qwen/Qwen3-0.6B"


 @pytest.fixture(scope="module")

View File

@@ -4,9 +4,9 @@
 import pytest
 from transformers import AutoTokenizer

-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.tokenizers import TokenizerLike


 @pytest.fixture(scope="function")
-def default_tokenizer() -> AnyTokenizer:
+def default_tokenizer() -> TokenizerLike:
     return AutoTokenizer.from_pretrained("gpt2")

View File

@@ -7,7 +7,7 @@ import pytest
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.tokenizers import TokenizerLike

 from ....utils import RemoteOpenAIServer
@@ -270,14 +270,14 @@ async def test_streaming_product_tool_call():
 @pytest.fixture
-def qwen_tokenizer() -> AnyTokenizer:
+def qwen_tokenizer() -> TokenizerLike:
     from vllm.transformers_utils.tokenizer import get_tokenizer

     return get_tokenizer("Qwen/Qwen3-32B")


 @pytest.fixture
-def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser:
+def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
     return Hermes2ProToolParser(qwen_tokenizer)
@@ -291,7 +291,7 @@ def any_chat_request() -> ChatCompletionRequest:
 def test_hermes_parser_streaming_just_forward_text(
-    qwen_tokenizer: AnyTokenizer,
+    qwen_tokenizer: TokenizerLike,
     hermes_parser: Hermes2ProToolParser,
     any_chat_request: ChatCompletionRequest,
 ) -> None:
@@ -323,7 +323,7 @@ def test_hermes_parser_streaming_just_forward_text(
 def test_hermes_parser_streaming_failure_case_bug_19056(
-    qwen_tokenizer: AnyTokenizer,
+    qwen_tokenizer: TokenizerLike,
     hermes_parser: Hermes2ProToolParser,
     any_chat_request: ChatCompletionRequest,
 ) -> None:
@@ -357,7 +357,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
 def test_hermes_parser_streaming(
-    qwen_tokenizer: AnyTokenizer,
+    qwen_tokenizer: TokenizerLike,
     hermes_parser: Hermes2ProToolParser,
     any_chat_request: ChatCompletionRequest,
 ) -> None:

View File

@@ -7,11 +7,11 @@ import pytest
 from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
 from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.tokenizers import TokenizerLike


 @pytest.fixture
-def parser(default_tokenizer: AnyTokenizer):
+def parser(default_tokenizer: TokenizerLike):
     return Llama3JsonToolParser(default_tokenizer)

View File

@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
 )
 from vllm.entrypoints.openai.protocol import FunctionCall
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.tokenizers import TokenizerLike

 # Test cases similar to pythonic parser but with Llama4 specific format
 SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
@@ -64,7 +64,7 @@ PYTHON_TAG_FUNCTION_OUTPUT = (
 @pytest.mark.parametrize("streaming", [True, False])
-def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
+def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
         default_tokenizer
     )
@@ -208,7 +208,7 @@ def test_tool_call(
     streaming: bool,
     model_output: str,
     expected_tool_calls: list[FunctionCall],
-    default_tokenizer: AnyTokenizer,
+    default_tokenizer: TokenizerLike,
 ):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
         default_tokenizer
@@ -224,7 +224,7 @@ def test_tool_call(
         assert actual.function == expected


-def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
+def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
         default_tokenizer
     )
@@ -246,7 +246,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
 @pytest.mark.parametrize("streaming", [False])
-def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
+def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
     """test regex timeout is handled gracefully"""
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
         default_tokenizer

View File

@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
 )
 from vllm.entrypoints.openai.protocol import FunctionCall
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.tokenizers import TokenizerLike

 # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
 SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
@@ -69,7 +69,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
 @pytest.mark.parametrize("streaming", [True, False])
-def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
+def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
         default_tokenizer
     )
@@ -188,7 +188,7 @@ def test_tool_call(
     streaming: bool,
     model_output: str,
     expected_tool_calls: list[FunctionCall],
-    default_tokenizer: AnyTokenizer,
+    default_tokenizer: TokenizerLike,
 ):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
         default_tokenizer
@@ -205,7 +205,7 @@ def test_tool_call(
         assert actual.function == expected


-def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
+def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
         default_tokenizer
     )
@@ -228,7 +228,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
 @pytest.mark.parametrize("streaming", [False])
-def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
+def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
     """test regex timeout is handled gracefully"""
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
         default_tokenizer

View File

@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
 )
 from vllm.entrypoints.openai.protocol import FunctionCall
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.tokenizers import TokenizerLike

 # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
 SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
@@ -61,7 +61,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
 @pytest.mark.parametrize("streaming", [True, False])
-def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
+def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
         default_tokenizer
     )
@@ -168,7 +168,7 @@ def test_tool_call(
     streaming: bool,
     model_output: str,
     expected_tool_calls: list[FunctionCall],
-    default_tokenizer: AnyTokenizer,
+    default_tokenizer: TokenizerLike,
 ):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
         default_tokenizer
@@ -185,7 +185,7 @@ def test_tool_call(
         assert actual.function == expected


-def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
+def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
         default_tokenizer
     )
@@ -208,7 +208,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
 @pytest.mark.parametrize("streaming", [False])
-def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
+def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
     """test regex timeout is handled gracefully"""
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
         default_tokenizer

View File

@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
     ToolCall,
 )
 from vllm.entrypoints.openai.tool_parsers import ToolParser
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.tokenizers import TokenizerLike


 class StreamingToolReconstructor:
@@ -111,7 +111,7 @@ def run_tool_extraction_nonstreaming(
     return tool_parser.extract_tool_calls(model_output, request)


-def split_string_into_token_deltas(tokenizer: AnyTokenizer, text: str) -> list[str]:
+def split_string_into_token_deltas(tokenizer: TokenizerLike, text: str) -> list[str]:
     # Split a string into a series of deltas using the provided tokenizer. Each
     # delta will be the string equivalent of a single token.
     token_ids = tokenizer.encode(text, add_special_tokens=False)

View File

@@ -7,7 +7,8 @@ import torch
 import torch.nn.functional as F

 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import ClassificationResponse, PoolingResponse
+from vllm.entrypoints.pooling.classify.protocol import ClassificationResponse
+from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse

 MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
 DTYPE = "float32"  # Use float32 to avoid NaN issue

View File

@@ -7,7 +7,7 @@ import pytest
 import requests

 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import ClassificationResponse
+from vllm.entrypoints.pooling.classify.protocol import ClassificationResponse

 VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"
 MAXIMUM_VIDEOS = 1

View File

@@ -15,10 +15,8 @@ import torch.nn.functional as F
 from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
 from tests.models.utils import check_embeddings_close
 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import (
-    EmbeddingResponse,
-    PoolingResponse,
-)
+from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
+from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
 from vllm.platforms import current_platform
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils.serial_utils import (
@@ -199,7 +197,7 @@ async def test_conversation_embedding(
     chat_response.raise_for_status()
     chat_embeddings = EmbeddingResponse.model_validate(chat_response.json())

-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=model_name)
     prompt = tokenizer.apply_chat_template(
         messages,
         chat_template=DUMMY_CHAT_TEMPLATE,

View File

@@ -11,7 +11,7 @@ from tests.conftest import HfRunner
 from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
 from tests.models.utils import EmbedModelInfo
 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import EmbeddingResponse
+from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
 from vllm.platforms import current_platform

 if current_platform.is_rocm():

View File

@@ -15,7 +15,7 @@ import pytest
 import pytest_asyncio

 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import EmbeddingResponse
+from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
 from vllm.platforms import current_platform

 if current_platform.is_rocm():

View File

@@ -8,7 +8,7 @@ import requests
 from transformers import AutoProcessor

 from tests.utils import VLLM_PATH, RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import EmbeddingResponse
+from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
 from vllm.multimodal.utils import encode_image_base64, fetch_image

 MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"

View File

@@ -11,7 +11,7 @@ import torch
 from tests.models.utils import check_embeddings_close
 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import PoolingResponse
+from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils.serial_utils import (
     EMBED_DTYPE_TO_TORCH_DTYPE,
@@ -158,11 +158,7 @@ async def test_conversation_pooling(server: RemoteOpenAIServer, model_name: str)
     chat_response.raise_for_status()
     chat_poolings = PoolingResponse.model_validate(chat_response.json())

-    tokenizer = get_tokenizer(
-        tokenizer_name=model_name,
-        tokenizer_mode="fast",
-        trust_remote_code=True,
-    )
+    tokenizer = get_tokenizer(tokenizer_name=model_name, trust_remote_code=True)
     prompt = tokenizer.apply_chat_template(
         messages,
         chat_template=DUMMY_CHAT_TEMPLATE,
