mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-25 14:06:54 +08:00
Merge branch 'jeejeelee:mlm-full-lora-support' into mlm-full-lora-support
This commit is contained in:
commit
cd32aeadfa
@ -7,7 +7,7 @@ vLLM also maintains a continuous performance benchmark under [perf.vllm.ai](http
|
||||
|
||||
## Performance benchmark quick overview
|
||||
|
||||
**Benchmarking Coverage**: latency, throughput and fix-qps serving on B200, A100, H100, Intel® Xeon® Processors and Intel® Gaudi® 3 Accelerators with different models.
|
||||
**Benchmarking Coverage**: latency, throughput and fix-qps serving on B200, A100, H100, Intel® Xeon® Processors, Intel® Gaudi® 3 Accelerators and Arm® Neoverse™ with different models.
|
||||
|
||||
**Benchmarking Duration**: about 1hr.
|
||||
|
||||
@ -23,7 +23,7 @@ bash .buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh
|
||||
|
||||
Runtime environment variables:
|
||||
|
||||
- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0.
|
||||
- `ON_CPU`: set the value to '1' on Intel® Xeon® and Arm® Neoverse™ Processors. Default value is 0.
|
||||
- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file).
|
||||
- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file).
|
||||
- `THROUGHPUT_JSON`: JSON file to use for the throughout tests. Default value is empty string (use default file).
|
||||
@ -34,8 +34,9 @@ Runtime environment variables:
|
||||
|
||||
See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
|
||||
> NOTE: For Intel® Xeon® Processors, use `tests/latency-tests-cpu.json`, `tests/throughput-tests-cpu.json`, `tests/serving-tests-cpu.json` instead.
|
||||
For Intel® Gaudi® 3 Accelerators, use `tests/latency-tests-hpu.json`, `tests/throughput-tests-hpu.json`, `tests/serving-tests-hpu.json` instead.
|
||||
>
|
||||
> For Intel® Gaudi® 3 Accelerators, use `tests/latency-tests-hpu.json`, `tests/throughput-tests-hpu.json`, `tests/serving-tests-hpu.json` instead.
|
||||
> For Arm® Neoverse™, use `tests/latency-tests-arm64-cpu.json`, `tests/throughput-tests-arm64-cpu.json`, `tests/serving-tests-arm64-cpu.json` instead.
|
||||
|
||||
### Latency test
|
||||
|
||||
Here is an example of one test inside `latency-tests.json`:
|
||||
|
||||
24
.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh
Normal file → Executable file
24
.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh
Normal file → Executable file
@ -49,7 +49,11 @@ check_cpus() {
|
||||
echo "Need at least 1 NUMA to run benchmarking."
|
||||
exit 1
|
||||
fi
|
||||
declare -g gpu_type="cpu"
|
||||
if [[ "$(uname -m)" == "aarch64" ]] || [[ "$(uname -m)" == "arm64" ]]; then
|
||||
declare -g gpu_type="arm64-cpu"
|
||||
else
|
||||
declare -g gpu_type="cpu"
|
||||
fi
|
||||
echo "GPU type is $gpu_type"
|
||||
}
|
||||
|
||||
@ -207,8 +211,8 @@ run_latency_tests() {
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size')
|
||||
if [ "$ON_CPU" == "1" ]; then
|
||||
pp=$(echo "$latency_params" | jq -r '.pipeline_parallel_size')
|
||||
if [[ "$ON_CPU" == "1" ]]; then
|
||||
pp=$(echo "$latency_params" | jq -r '.pipeline_parallel_size // 1')
|
||||
world_size=$(($tp*$pp))
|
||||
if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then
|
||||
echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
@ -276,8 +280,8 @@ run_throughput_tests() {
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size')
|
||||
if [ "$ON_CPU" == "1" ]; then
|
||||
pp=$(echo "$throughput_params" | jq -r '.pipeline_parallel_size')
|
||||
if [[ "$ON_CPU" == "1" ]]; then
|
||||
pp=$(echo "$throughput_params" | jq -r '.pipeline_parallel_size // 1')
|
||||
world_size=$(($tp*$pp))
|
||||
if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then
|
||||
echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
@ -393,8 +397,8 @@ run_serving_tests() {
|
||||
|
||||
# check if there is enough resources to run the test
|
||||
tp=$(echo "$server_params" | jq -r '.tensor_parallel_size')
|
||||
if [ "$ON_CPU" == "1" ]; then
|
||||
pp=$(echo "$server_params" | jq -r '.pipeline_parallel_size')
|
||||
if [[ "$ON_CPU" == "1" ]]; then
|
||||
pp=$(echo "$server_params" | jq -r '.pipeline_parallel_size // 1')
|
||||
world_size=$(($tp*$pp))
|
||||
if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then
|
||||
echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
@ -496,9 +500,9 @@ run_serving_tests() {
|
||||
main() {
|
||||
local ARCH
|
||||
ARCH=''
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
check_cpus
|
||||
ARCH='-cpu'
|
||||
if [[ "$ON_CPU" == "1" ]]; then
|
||||
check_cpus
|
||||
ARCH="-$gpu_type"
|
||||
else
|
||||
check_gpus
|
||||
ARCH="$arch_suffix"
|
||||
|
||||
@ -0,0 +1,26 @@
|
||||
[
|
||||
{
|
||||
"test_name": "latency_llama8B_tp1",
|
||||
"environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"enforce_eager": "",
|
||||
"max_num_batched_tokens": 2048,
|
||||
"max_num_seqs": 256,
|
||||
"num_iters_warmup": 5,
|
||||
"num_iters": 15
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,130 @@
|
||||
{
|
||||
"defaults": {
|
||||
"qps_list": [
|
||||
"inf"
|
||||
],
|
||||
"max_concurrency_list": [
|
||||
12,
|
||||
16,
|
||||
24,
|
||||
32,
|
||||
64,
|
||||
128,
|
||||
200
|
||||
],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_SGL_KERNEL": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"enforce_eager": "",
|
||||
"max_num_batched_tokens": 2048,
|
||||
"max_num_seqs": 256,
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"ignore-eos": "",
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
"tests": [
|
||||
{
|
||||
"test_name": "serving_llama8B_tp1_sharegpt",
|
||||
"server_parameters": {
|
||||
"tensor_parallel_size": 1
|
||||
},
|
||||
"client_parameters": {
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json"
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2_sharegpt",
|
||||
"server_parameters": {
|
||||
"tensor_parallel_size": 2
|
||||
},
|
||||
"client_parameters": {
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json"
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp1_random_128_128",
|
||||
"server_parameters": {
|
||||
"tensor_parallel_size": 1
|
||||
},
|
||||
"client_parameters": {
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 128,
|
||||
"random-output-len": 128
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2_random_128_128",
|
||||
"server_parameters": {
|
||||
"tensor_parallel_size": 2
|
||||
},
|
||||
"client_parameters": {
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 128,
|
||||
"random-output-len": 128
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp1_random_128_2048",
|
||||
"server_parameters": {
|
||||
"tensor_parallel_size": 1
|
||||
},
|
||||
"client_parameters": {
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 128,
|
||||
"random-output-len": 2048
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2_random_128_2048",
|
||||
"server_parameters": {
|
||||
"tensor_parallel_size": 2
|
||||
},
|
||||
"client_parameters": {
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 128,
|
||||
"random-output-len": 2048
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp1_random_2048_128",
|
||||
"server_parameters": {
|
||||
"tensor_parallel_size": 1
|
||||
},
|
||||
"client_parameters": {
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 2048,
|
||||
"random-output-len": 128
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2_random_2048_128",
|
||||
"server_parameters": {
|
||||
"tensor_parallel_size": 2
|
||||
},
|
||||
"client_parameters": {
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 2048,
|
||||
"random-output-len": 128
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -0,0 +1,27 @@
|
||||
[
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp1",
|
||||
"environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"enforce_eager": "",
|
||||
"max_num_batched_tokens": 2048,
|
||||
"max_num_seqs": 256,
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"backend": "vllm"
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -141,7 +141,6 @@ if [[ $commands == *" entrypoints/openai "* ]]; then
|
||||
--ignore=entrypoints/openai/test_audio.py \
|
||||
--ignore=entrypoints/openai/test_shutdown.py \
|
||||
--ignore=entrypoints/openai/test_completion.py \
|
||||
--ignore=entrypoints/openai/test_sleep.py \
|
||||
--ignore=entrypoints/openai/test_models.py \
|
||||
--ignore=entrypoints/openai/test_lora_adapters.py \
|
||||
--ignore=entrypoints/openai/test_return_tokens_as_ids.py \
|
||||
|
||||
@ -50,6 +50,7 @@ function cpu_tests() {
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
|
||||
pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
|
||||
pytest -x -v -s tests/kernels/test_onednn.py"
|
||||
|
||||
# Run basic model test
|
||||
|
||||
@ -39,7 +39,7 @@ docker run \
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
|
||||
python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager
|
||||
VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
|
||||
cd tests
|
||||
pytest -v -s v1/core
|
||||
pytest -v -s v1/engine
|
||||
|
||||
@ -44,10 +44,10 @@ trap cleanup EXIT
|
||||
|
||||
for BACK in "${BACKENDS[@]}"; do
|
||||
VLLM_DEEP_GEMM_WARMUP=skip \
|
||||
VLLM_ALL2ALL_BACKEND=$BACK \
|
||||
vllm serve "$MODEL" \
|
||||
--enforce-eager \
|
||||
--enable-eplb \
|
||||
--all2all-backend $BACK \
|
||||
--eplb-config '{"window_size":10, "step_interval":100, "num_redundant_experts":0, "log_balancedness":true}' \
|
||||
--tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
|
||||
--data-parallel-size ${DATA_PARALLEL_SIZE} \
|
||||
|
||||
@ -43,12 +43,12 @@ trap cleanup EXIT
|
||||
|
||||
for BACK in "${BACKENDS[@]}"; do
|
||||
VLLM_DEEP_GEMM_WARMUP=skip \
|
||||
VLLM_ALL2ALL_BACKEND=$BACK \
|
||||
vllm serve "$MODEL" \
|
||||
--enforce-eager \
|
||||
--tensor-parallel-size 4 \
|
||||
--enable-expert-parallel \
|
||||
--enable-eplb \
|
||||
--all2all-backend $BACK \
|
||||
--eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
|
||||
--speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \
|
||||
--trust-remote-code \
|
||||
|
||||
@ -128,7 +128,7 @@ steps:
|
||||
- tests/entrypoints/
|
||||
commands:
|
||||
- pytest -v -s entrypoints/openai/tool_parsers
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/rpc --ignore=entrypoints/sleep --ignore=entrypoints/instrumentator --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
|
||||
- label: Entrypoints Integration Test (LLM) # 30min
|
||||
timeout_in_minutes: 40
|
||||
@ -148,7 +148,7 @@ steps:
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Entrypoints Integration Test (API Server) # 100min
|
||||
- label: Entrypoints Integration Test (API Server 1) # 100min
|
||||
timeout_in_minutes: 130
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
@ -162,10 +162,28 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: Entrypoints Integration Test (API Server 2)
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/sleep
|
||||
- tests/entrypoints/rpc
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/sleep
|
||||
- pytest -v -s tool_use
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
|
||||
|
||||
- label: Entrypoints Integration Test (Pooling)
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -722,7 +740,7 @@ steps:
|
||||
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
|
||||
# we can only upgrade after this is resolved
|
||||
# TODO(jerryzh168): resolve the above comment
|
||||
- uv pip install --system torchao==0.13.0
|
||||
- uv pip install --system torchao==0.14.1
|
||||
- uv pip install --system conch-triton-kernels
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
|
||||
|
||||
@ -736,7 +754,7 @@ steps:
|
||||
- vllm/model_executor/layers/quantization
|
||||
autorun_on_main: true
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
||||
|
||||
- label: OpenAI API correctness # 10min
|
||||
timeout_in_minutes: 15
|
||||
@ -751,17 +769,6 @@ steps:
|
||||
# Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442
|
||||
- pytest -s entrypoints/openai/correctness/
|
||||
|
||||
- label: OpenAI-Compatible Tool Use # 23 min
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
fast_check: false
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- pytest -v -s tool_use
|
||||
|
||||
##### models test #####
|
||||
|
||||
@ -957,7 +964,7 @@ steps:
|
||||
- pytest -v -s models/multimodal/processing
|
||||
|
||||
- label: Multi-Modal Models Test (Standard) # 60min
|
||||
timeout_in_minutes: 80
|
||||
timeout_in_minutes: 100
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
@ -966,13 +973,15 @@ steps:
|
||||
- vllm/
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- export MIOPEN_DEBUG_CONV_DIRECT=0
|
||||
- export MIOPEN_DEBUG_CONV_GEMM=0
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
|
||||
- cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
||||
|
||||
- label: Multi-Modal Accuracy Eval (Small Models) # 150min - 180min
|
||||
timeout_in_minutes: 180
|
||||
- label: Multi-Modal Accuracy Eval (Small Models) # 5min
|
||||
timeout_in_minutes: 10
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
@ -982,7 +991,9 @@ steps:
|
||||
- vllm/inputs/
|
||||
- vllm/v1/core/
|
||||
commands:
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
|
||||
- export MIOPEN_DEBUG_CONV_DIRECT=0
|
||||
- export MIOPEN_DEBUG_CONV_GEMM=0
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 1 # 60min
|
||||
timeout_in_minutes: 120
|
||||
@ -994,10 +1005,13 @@ steps:
|
||||
- vllm/
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- export MIOPEN_DEBUG_CONV_DIRECT=0
|
||||
- export MIOPEN_DEBUG_CONV_GEMM=0
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 2
|
||||
- label: Multi-Modal Models Test (Extended) 2 #60min
|
||||
timeout_in_minutes: 120
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
@ -1006,6 +1020,8 @@ steps:
|
||||
- vllm/
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- export MIOPEN_DEBUG_CONV_DIRECT=0
|
||||
- export MIOPEN_DEBUG_CONV_GEMM=0
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
|
||||
|
||||
@ -1019,6 +1035,8 @@ steps:
|
||||
- vllm/
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- export MIOPEN_DEBUG_CONV_DIRECT=0
|
||||
- export MIOPEN_DEBUG_CONV_GEMM=0
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
|
||||
|
||||
@ -1196,7 +1214,7 @@ steps:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
@ -1490,7 +1508,7 @@ steps:
|
||||
- "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- HIP_VISIBLE_DEVICES=0,1 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
||||
- HIP_VISIBLE_DEVICES=0,1 VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 --all2all-backend deepep_high_throughput
|
||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
||||
|
||||
##### B200 test #####
|
||||
@ -1514,7 +1532,7 @@ steps:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
||||
|
||||
- label: LM Eval Large Models (4 Card)
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
|
||||
@ -114,7 +114,7 @@ steps:
|
||||
- tests/entrypoints/
|
||||
commands:
|
||||
- pytest -v -s entrypoints/openai/tool_parsers
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/rpc --ignore=entrypoints/sleep --ignore=entrypoints/instrumentator --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
|
||||
- label: Entrypoints Integration Test (LLM) # 30min
|
||||
timeout_in_minutes: 40
|
||||
@ -132,7 +132,7 @@ steps:
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Entrypoints Integration Test (API Server) # 100min
|
||||
- label: Entrypoints Integration Test (API Server 1) # 100min
|
||||
timeout_in_minutes: 130
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
@ -144,10 +144,26 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: Entrypoints Integration Test (API Server 2)
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/sleep
|
||||
- tests/entrypoints/rpc
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/sleep
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
|
||||
- pytest -v -s tool_use
|
||||
|
||||
- label: Entrypoints Integration Test (Pooling)
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -303,7 +319,10 @@ steps:
|
||||
# TODO: accuracy does not match, whether setting
|
||||
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
|
||||
- pytest -v -s v1/e2e
|
||||
- pytest -v -s v1/engine
|
||||
# Run this test standalone for now;
|
||||
# need to untangle use (implicit) use of spawn/fork across the tests.
|
||||
- pytest -v -s v1/engine/test_preprocess_error_handling.py
|
||||
- pytest -v -s v1/engine --ignore v1/engine/test_preprocess_error_handling.py
|
||||
|
||||
- label: V1 Test entrypoints # 35min
|
||||
timeout_in_minutes: 50
|
||||
@ -642,7 +661,7 @@ steps:
|
||||
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
|
||||
# we can only upgrade after this is resolved
|
||||
# TODO(jerryzh168): resolve the above comment
|
||||
- uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129
|
||||
- uv pip install --system torchao==0.14.1 --index-url https://download.pytorch.org/whl/cu129
|
||||
- uv pip install --system conch-triton-kernels
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
|
||||
|
||||
@ -654,7 +673,7 @@ steps:
|
||||
- vllm/model_executor/layers/quantization
|
||||
autorun_on_main: true
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
||||
|
||||
- label: OpenAI API correctness # 22min
|
||||
timeout_in_minutes: 30
|
||||
@ -666,16 +685,6 @@ steps:
|
||||
commands: # LMEval+Transcription WER check
|
||||
- pytest -s entrypoints/openai/correctness/
|
||||
|
||||
- label: OpenAI-Compatible Tool Use # 23 min
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental]
|
||||
fast_check: false
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- pytest -v -s tool_use
|
||||
|
||||
##### models test #####
|
||||
|
||||
- label: Basic Models Tests (Initialization)
|
||||
@ -1064,7 +1073,7 @@ steps:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
@ -1325,7 +1334,7 @@ steps:
|
||||
- "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
||||
- CUDA_VISIBLE_DEVICES=1,2 VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 --all2all-backend deepep_high_throughput
|
||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
||||
|
||||
##### B200 test #####
|
||||
|
||||
@ -145,7 +145,7 @@ steps:
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
||||
- CUDA_VISIBLE_DEVICES=1,2 VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 --all2all-backend deepep_high_throughput
|
||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
||||
|
||||
- label: Distributed Tests (2 GPUs)(B200)
|
||||
|
||||
@ -32,6 +32,7 @@ steps:
|
||||
- label: Prime-RL Integration (2 GPUs)
|
||||
timeout_in_minutes: 30
|
||||
optional: true
|
||||
soft_fail: true
|
||||
num_gpus: 2
|
||||
working_dir: "/vllm-workspace"
|
||||
source_file_dependencies:
|
||||
@ -39,21 +40,3 @@ steps:
|
||||
- .buildkite/scripts/run-prime-rl-test.sh
|
||||
commands:
|
||||
- bash .buildkite/scripts/run-prime-rl-test.sh
|
||||
|
||||
- label: DeepSeek V2-Lite Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
|
||||
|
||||
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
||||
|
||||
@ -10,7 +10,7 @@ steps:
|
||||
- tests/entrypoints/
|
||||
commands:
|
||||
- pytest -v -s entrypoints/openai/tool_parsers
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/rpc --ignore=entrypoints/sleep --ignore=entrypoints/instrumentator --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
|
||||
- label: Entrypoints Integration (LLM)
|
||||
timeout_in_minutes: 40
|
||||
@ -25,7 +25,7 @@ steps:
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Entrypoints Integration (API Server)
|
||||
- label: Entrypoints Integration (API Server 1)
|
||||
timeout_in_minutes: 130
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
source_file_dependencies:
|
||||
@ -34,11 +34,26 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
|
||||
- label: Entrypoints Integration (API Server 2)
|
||||
timeout_in_minutes: 130
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
- tests/entrypoints/sleep
|
||||
- tests/entrypoints/instrumentator
|
||||
- tests/entrypoints/rpc
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
|
||||
- pytest -v -s entrypoints/instrumentator
|
||||
- pytest -v -s entrypoints/sleep
|
||||
- pytest -v -s tool_use
|
||||
|
||||
- label: Entrypoints Integration (Pooling)
|
||||
timeout_in_minutes: 50
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
|
||||
@ -9,7 +9,7 @@ steps:
|
||||
- vllm/model_executor/layers/quantization
|
||||
autorun_on_main: true
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
||||
|
||||
- label: LM Eval Large Models (4 GPUs)(A100)
|
||||
gpu: a100
|
||||
@ -43,4 +43,4 @@ steps:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt
|
||||
|
||||
@ -22,6 +22,8 @@ steps:
|
||||
# FIXIT: find out which code initialize cuda before running the test
|
||||
# before the fix, we need to use spawn to test it
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
# Alot of these tests are on the edge of OOMing
|
||||
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
# There is some Tensor Parallelism related processing logic in LoRA that
|
||||
# requires multi-GPU testing for validation.
|
||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||
|
||||
@ -9,6 +9,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/test_initialization.py
|
||||
- tests/models/registry.py
|
||||
commands:
|
||||
# Run a subset of model initialization tests
|
||||
- pytest -v -s models/test_initialization.py::test_can_initialize_small_subset
|
||||
@ -20,6 +21,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/models/
|
||||
- tests/models/test_initialization.py
|
||||
- tests/models/registry.py
|
||||
commands:
|
||||
# Only when vLLM model source is modified - test initialization of a large
|
||||
# subset of supported models (the complement of the small subset in the above
|
||||
|
||||
@ -13,7 +13,9 @@ steps:
|
||||
# tests covered elsewhere.
|
||||
# Use `find` to launch multiple instances of pytest so that
|
||||
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965
|
||||
- "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\;"
|
||||
# However, find does not normally propagate error codes, so we combine it with xargs
|
||||
# (using -0 for proper path handling)
|
||||
- "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test
|
||||
timeout_in_minutes: 30
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
group: Tool use
|
||||
depends_on:
|
||||
- image-build
|
||||
steps:
|
||||
- label: OpenAI-Compatible Tool Use
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental]
|
||||
fast_check: false
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- pytest -v -s tool_use
|
||||
26
.github/mergify.yml
vendored
26
.github/mergify.yml
vendored
@ -235,6 +235,20 @@ pull_request_rules:
|
||||
add:
|
||||
- rocm
|
||||
|
||||
- name: label-cpu
|
||||
description: Automatically apply cpu label
|
||||
conditions:
|
||||
- label != stale
|
||||
- files~=^(?!.*kv_offload)(?!.*cpu_offload).*\bcpu.*
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- cpu
|
||||
assign:
|
||||
users:
|
||||
- "fadara01"
|
||||
- "aditew01"
|
||||
|
||||
- name: label-structured-output
|
||||
description: Automatically apply structured-output label
|
||||
conditions:
|
||||
@ -335,6 +349,18 @@ pull_request_rules:
|
||||
add:
|
||||
- tool-calling
|
||||
|
||||
- name: auto-rebase if approved, ready, and 40 commits behind main
|
||||
conditions:
|
||||
- base = main
|
||||
- label=ready
|
||||
- "#approved-reviews-by >= 1"
|
||||
- "#commits-behind >= 40"
|
||||
- -closed
|
||||
- -draft
|
||||
- -conflict
|
||||
actions:
|
||||
rebase: {}
|
||||
|
||||
- name: ping author on conflicts and add 'needs-rebase' label
|
||||
conditions:
|
||||
- label != stale
|
||||
|
||||
131
CMakeLists.txt
131
CMakeLists.txt
@ -56,8 +56,8 @@ endif()
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from docker/Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.9.0")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.9.0")
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.9.1")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.9.1")
|
||||
|
||||
#
|
||||
# Try to find python package with an executable that exactly matches
|
||||
@ -357,6 +357,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# marlin arches for fp16 output
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
|
||||
# marlin has limited support for turing
|
||||
cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
|
||||
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
|
||||
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
||||
# marlin arches for fp8 input
|
||||
@ -364,8 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
|
||||
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
|
||||
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
|
||||
# marlin arches for other files
|
||||
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
|
||||
|
||||
if (MARLIN_ARCHS)
|
||||
if (MARLIN_OTHER_ARCHS)
|
||||
|
||||
#
|
||||
# For the Marlin kernels we automatically generate sources for various
|
||||
@ -406,25 +410,39 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Marlin generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
if (MARLIN_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
|
||||
endif()
|
||||
|
||||
if (MARLIN_SM75_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_SM75_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_SM75_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_SM75_KERNEL_SRC})
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
|
||||
|
||||
if (MARLIN_FP8_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
|
||||
@ -446,14 +464,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_SRCS}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
CUDA_ARCHS "${MARLIN_OTHER_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
set_source_files_properties(${MARLIN_SRCS}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
|
||||
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_OTHER_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
@ -781,24 +799,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
@ -980,12 +980,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# note that we always set `use_atomic_add=False` for moe marlin now,
|
||||
# so we don't need 9.0 for bf16 atomicAdd PTX
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
|
||||
# moe marlin has limited support for turing
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
|
||||
# moe marlin arches for fp8 input
|
||||
# - sm80 doesn't support fp8 computation
|
||||
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
|
||||
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
# moe marlin arches for other files
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_OTHER_ARCHS)
|
||||
|
||||
#
|
||||
# For the Marlin MOE kernels we automatically generate sources for various
|
||||
@ -1026,16 +1030,29 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
|
||||
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
|
||||
endif()
|
||||
|
||||
if (MARLIN_MOE_SM75_ARCHS)
|
||||
file(GLOB MARLIN_MOE_SM75_SRC "csrc/moe/marlin_moe_wna16/sm75_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SM75_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_SM75_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SM75_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SM75_SRC})
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
|
||||
|
||||
if (MARLIN_MOE_FP8_ARCHS)
|
||||
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
|
||||
@ -1049,7 +1066,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
|
||||
set(MARLIN_MOE_OTHER_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_OTHER_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_OTHER_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_OTHER_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_OTHER_SRC}")
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_OTHER_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
|
||||
@ -13,8 +13,8 @@ from vllm.triton_utils import triton
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
batch_size_range = [1, 16, 32, 64, 128]
|
||||
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
batch_size_range = [1, 16, 128]
|
||||
seq_len_range = [1, 16, 64, 1024, 4096]
|
||||
intermediate_size = [3072, 9728, 12288]
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))
|
||||
|
||||
|
||||
@ -330,7 +330,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
|
||||
PUBLIC ${oneDNN_BINARY_DIR}/include
|
||||
PRIVATE ${oneDNN_SOURCE_DIR}/src
|
||||
)
|
||||
target_link_libraries(dnnl_ext dnnl)
|
||||
target_link_libraries(dnnl_ext dnnl torch)
|
||||
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
|
||||
list(APPEND LIBS dnnl_ext)
|
||||
set(USE_ONEDNN ON)
|
||||
@ -358,13 +358,13 @@ set(VLLM_EXT_SRC
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp"
|
||||
"csrc/cpu/cpu_attn.cpp"
|
||||
"csrc/cpu/scratchpad_manager.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp")
|
||||
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/shm.cpp"
|
||||
"csrc/cpu/cpu_wna16.cpp"
|
||||
"csrc/cpu/cpu_fused_moe.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
|
||||
set(VLLM_EXT_SRC
|
||||
|
||||
@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
|
||||
const scalar_t& y) {
|
||||
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
|
||||
}
|
||||
// Activation and gating kernel template.
|
||||
|
||||
// Check if all pointers are 16-byte aligned for int4 vectorized access
|
||||
__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
|
||||
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
|
||||
}
|
||||
|
||||
// Activation and gating kernel template.
|
||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
||||
bool act_first>
|
||||
__global__ void act_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const int d) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
|
||||
const scalar_t* x_ptr = input + token_idx * 2 * d;
|
||||
const scalar_t* y_ptr = x_ptr + d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access.
|
||||
// All three pointers must be 16-byte aligned for safe int4 operations.
|
||||
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
|
||||
is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
|
||||
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
|
||||
auto* xp = reinterpret_cast<scalar_t*>(&x);
|
||||
auto* yp = reinterpret_cast<scalar_t*>(&y);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
|
||||
VLLM_LDG(&y_ptr[i]));
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
|
||||
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
|
||||
out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
|
||||
__global__ void act_and_mul_kernel_with_param(
|
||||
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
|
||||
const float param) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x, param) * y;
|
||||
const scalar_t* x_ptr = input + token_idx * 2 * d;
|
||||
const scalar_t* y_ptr = x_ptr + d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access
|
||||
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
|
||||
is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
|
||||
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
|
||||
auto* xp = reinterpret_cast<scalar_t*>(&x);
|
||||
auto* yp = reinterpret_cast<scalar_t*>(&y);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = ACT_FN(xp[j], param) * yp[j];
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
|
||||
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
|
||||
out_ptr[idx] = ACT_FN(x, param) * y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
|
||||
float alpha, float limit) {
|
||||
// clamp gate: min=None, max=limit
|
||||
const float gate_f = (float)gate;
|
||||
const float clamped_gate = gate_f > limit ? limit : gate_f;
|
||||
|
||||
// clamp up: min=-limit, max=limit
|
||||
const float up_f = (float)up;
|
||||
const float clamped_up =
|
||||
up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
|
||||
|
||||
// glu = gate * sigmoid(gate * alpha)
|
||||
const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
|
||||
const float glu = clamped_gate * sigmoid_val;
|
||||
|
||||
// (up + 1) * glu
|
||||
return (T)((clamped_up + 1.0f) * glu);
|
||||
// Clamp gate to (-inf, limit] and up to [-limit, limit]
|
||||
const float g = fminf((float)gate, limit);
|
||||
const float u = fmaxf(fminf((float)up, limit), -limit);
|
||||
// glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
|
||||
return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha)));
|
||||
}
|
||||
|
||||
// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
|
||||
template <typename scalar_t,
|
||||
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
|
||||
const float)>
|
||||
__global__ void swigluoai_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const scalar_t* __restrict__ input, // [..., 2 * d] (interleaved)
|
||||
const int d, const float alpha, const float limit) {
|
||||
// For interleaved data: input has 2*d elements per token (gate/up pairs)
|
||||
// output has d elements per token
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
constexpr int PAIRS = VEC_SIZE / 2; // Number of gate/up pairs per int4 load
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
// TODO: Vectorize loads and stores.
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
// gate = x[..., ::2] (even indices)
|
||||
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
|
||||
// up = x[..., 1::2] (odd indices)
|
||||
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
|
||||
const scalar_t* in_ptr = input + token_idx * 2 * d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
|
||||
// Check alignment for 128-bit vectorized access on input.
|
||||
// For output we use int2 (64-bit) which has 8-byte alignment requirement.
|
||||
const bool in_aligned = is_16byte_aligned(in_ptr);
|
||||
const bool out_aligned =
|
||||
(reinterpret_cast<uintptr_t>(out_ptr) & 7) == 0; // 8-byte for int2
|
||||
|
||||
if (in_aligned && out_aligned && d >= PAIRS) {
|
||||
// Fast path: vectorized loop
|
||||
// Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs
|
||||
// Each int2 store writes PAIRS output elements
|
||||
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
|
||||
int2* out_vec = reinterpret_cast<int2*>(out_ptr);
|
||||
const int num_vecs = d / PAIRS;
|
||||
const int vec_end = num_vecs * PAIRS;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 v = VLLM_LDG(&in_vec[i]);
|
||||
int2 r;
|
||||
auto* vp = reinterpret_cast<scalar_t*>(&v);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < PAIRS; j++) {
|
||||
rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]),
|
||||
VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit);
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
// gate = x[..., ::2] (even indices)
|
||||
const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]);
|
||||
// up = x[..., 1::2] (odd indices)
|
||||
const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]);
|
||||
out_ptr[idx] = ACT_FN(gate, up, alpha, limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,10 +324,41 @@ __global__ void activation_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., d]
|
||||
const int d) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x);
|
||||
const scalar_t* in_ptr = input + token_idx * d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access
|
||||
const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 v = VLLM_LDG(&in_vec[i]), r;
|
||||
auto* vp = reinterpret_cast<scalar_t*>(&v);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = ACT_FN(vp[j]);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&in_ptr[idx]);
|
||||
out_ptr[idx] = ACT_FN(x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#ifndef CPU_ATTN_MACROS_H
|
||||
#define CPU_ATTN_MACROS_H
|
||||
#ifndef CPU_ARCH_MACROS_H
|
||||
#define CPU_ARCH_MACROS_H
|
||||
|
||||
// x86_64
|
||||
#ifdef __x86_64__
|
||||
@ -26,7 +26,7 @@
|
||||
_mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \
|
||||
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \
|
||||
const int n_mantissa_bits = 23; \
|
||||
auto fast_exp = [&](vec_op::FP32Vec16& vec) __attribute__(( \
|
||||
auto fast_exp = [&](const vec_op::FP32Vec16& vec) __attribute__(( \
|
||||
always_inline)) { \
|
||||
__m512 values = vec.reg; \
|
||||
auto less_ln_flt_min_mask = \
|
||||
@ -98,7 +98,7 @@
|
||||
poly = vbslq_f32(hi_mask, inf, poly); \
|
||||
return vbslq_f32(lo_mask, zero, poly); \
|
||||
}; \
|
||||
auto fast_exp = [&](vec_op::FP32Vec16& vec) \
|
||||
auto fast_exp = [&](const vec_op::FP32Vec16& vec) \
|
||||
__attribute__((always_inline)) { \
|
||||
float32x4x4_t result; \
|
||||
result.val[0] = neon_expf(vec.reg.val[0]); \
|
||||
@ -110,4 +110,4 @@
|
||||
|
||||
#endif // __aarch64__
|
||||
|
||||
#endif
|
||||
#endif
|
||||
@ -8,10 +8,8 @@
|
||||
#include <sys/sysctl.h>
|
||||
#endif
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
#include "scratchpad_manager.h"
|
||||
#include "cpu_attn_macros.h"
|
||||
#include "utils.hpp"
|
||||
#include "cpu/cpu_arch_macros.h"
|
||||
#include "cpu/utils.hpp"
|
||||
|
||||
namespace cpu_attention {
|
||||
enum class ISA { AMX, VEC, VEC16, NEON };
|
||||
@ -378,12 +376,13 @@ class AttentionScheduler {
|
||||
|
||||
static constexpr int32_t MaxQTileIterNum = 128;
|
||||
|
||||
AttentionScheduler() : available_cache_size_(get_available_l2_size()) {}
|
||||
AttentionScheduler()
|
||||
: available_cache_size_(cpu_utils::get_available_l2_size()) {}
|
||||
|
||||
torch::Tensor schedule(const ScheduleInput& input) const {
|
||||
const bool casual = input.casual;
|
||||
const int32_t thread_num = omp_get_max_threads();
|
||||
const int64_t cache_size = get_available_l2_size();
|
||||
const int64_t cache_size = cpu_utils::get_available_l2_size();
|
||||
const int32_t max_num_q_per_iter = input.max_num_q_per_iter;
|
||||
const int32_t kv_len_alignment = input.kv_block_alignment;
|
||||
int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv;
|
||||
@ -659,7 +658,7 @@ class AttentionScheduler {
|
||||
metadata_ptr->thread_num +
|
||||
metadata_ptr->reduction_scratchpad_size_per_kv_head *
|
||||
(use_gqa ? input.num_heads_kv : input.num_heads_q);
|
||||
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(
|
||||
cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(
|
||||
scratchpad_size);
|
||||
|
||||
// metadata_ptr->print();
|
||||
@ -667,7 +666,7 @@ class AttentionScheduler {
|
||||
// test out of boundary access
|
||||
// {
|
||||
// float* cache_ptr =
|
||||
// DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<float>();
|
||||
// cpu_utils::ScratchPadManager::getl_scratchpad_manager()->get_data<float>();
|
||||
// for (int64_t i = 0; i < scratchpad_size / sizeof(float); ++i) {
|
||||
// cache_ptr[i] = std::numeric_limits<float>::quiet_NaN();
|
||||
// }
|
||||
@ -749,27 +748,6 @@ class AttentionScheduler {
|
||||
return std::max(rounded_tile_size, round_size);
|
||||
}
|
||||
|
||||
static int64_t get_available_l2_size() {
|
||||
static int64_t size = []() {
|
||||
#if defined(__APPLE__)
|
||||
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
|
||||
int64_t l2_cache_size = 0;
|
||||
size_t len = sizeof(l2_cache_size);
|
||||
if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
|
||||
l2_cache_size > 0) {
|
||||
return l2_cache_size >> 1; // use 50% of L2 cache
|
||||
}
|
||||
// Fallback if sysctlbyname fails
|
||||
return 128LL * 1024 >> 1; // use 50% of 128KB
|
||||
#else
|
||||
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
|
||||
TORCH_CHECK_NE(l2_cache_size, -1);
|
||||
return l2_cache_size >> 1; // use 50% of L2 cache
|
||||
#endif
|
||||
}();
|
||||
return size;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t available_cache_size_;
|
||||
};
|
||||
@ -1402,7 +1380,7 @@ class AttentionMainLoop {
|
||||
|
||||
// init buffers
|
||||
void* scratchpad_ptr =
|
||||
DNNLScratchPadManager::get_dnnl_scratchpad_manager()
|
||||
cpu_utils::ScratchPadManager::get_scratchpad_manager()
|
||||
->get_data<void>();
|
||||
AttentionScratchPad buffer_manager(thread_id, metadata, scratchpad_ptr);
|
||||
|
||||
@ -1422,8 +1400,7 @@ class AttentionMainLoop {
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t available_cache_size =
|
||||
AttentionScheduler::get_available_l2_size();
|
||||
const int64_t available_cache_size = cpu_utils::get_available_l2_size();
|
||||
const int32_t default_tile_size =
|
||||
AttentionScheduler::calcu_default_tile_size(
|
||||
available_cache_size, head_dim, sizeof(kv_cache_t),
|
||||
|
||||
727
csrc/cpu/cpu_fused_moe.cpp
Normal file
727
csrc/cpu/cpu_fused_moe.cpp
Normal file
@ -0,0 +1,727 @@
|
||||
#include "cpu/cpu_types.hpp"
|
||||
#include "cpu/utils.hpp"
|
||||
#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp"
|
||||
#include "cpu/cpu_arch_macros.h"
|
||||
|
||||
#ifdef CPU_CAPABILITY_AMXBF16
|
||||
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
|
||||
#define AMX_DISPATCH(...) \
|
||||
case cpu_utils::ISA::AMX: { \
|
||||
using gemm_t = cpu_micro_gemm::MicroGemm<cpu_utils::ISA::AMX, scalar_t>; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#else
|
||||
#define AMX_DISPATCH(...) case cpu_utils::ISA::AMX:
|
||||
#endif
|
||||
|
||||
#define CPU_ISA_DISPATCH_IMPL(ISA_TYPE, ...) \
|
||||
[&] { \
|
||||
switch (ISA_TYPE) { \
|
||||
AMX_DISPATCH(__VA_ARGS__) \
|
||||
case cpu_utils::ISA::VEC: { \
|
||||
using gemm_t = \
|
||||
cpu_micro_gemm::MicroGemm<cpu_utils::ISA::VEC, scalar_t>; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: { \
|
||||
TORCH_CHECK(false, "Invalid CPU ISA type."); \
|
||||
} \
|
||||
} \
|
||||
}()
|
||||
|
||||
namespace {
|
||||
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul };
|
||||
|
||||
FusedMOEAct get_act_type(const std::string& act) {
|
||||
if (act == "silu") {
|
||||
return FusedMOEAct::SiluAndMul;
|
||||
} else if (act == "swigluoai") {
|
||||
return FusedMOEAct::SwigluOAIAndMul;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid act type: " + act);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void swigluoai_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
|
||||
const int32_t m_size, const int32_t n_size,
|
||||
const int32_t input_stride,
|
||||
const int32_t output_stride) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
// For GPT-OSS interleaved gate-up weights
|
||||
alignas(64) static int32_t index[16] = {0, 2, 4, 6, 8, 10, 12, 14,
|
||||
16, 18, 20, 22, 24, 26, 28, 30};
|
||||
vec_op::INT32Vec16 index_vec(index);
|
||||
vec_op::FP32Vec16 gate_up_max_vec(7.0);
|
||||
vec_op::FP32Vec16 up_min_vec(-7.0);
|
||||
vec_op::FP32Vec16 alpha_vec(1.702);
|
||||
vec_op::FP32Vec16 one_vec(1.0);
|
||||
|
||||
DEFINE_FAST_EXP
|
||||
|
||||
for (int32_t m = 0; m < m_size; ++m) {
|
||||
for (int32_t n = 0; n < n_size; n += 32) {
|
||||
vec_op::FP32Vec16 gate_vec(input + n, index_vec);
|
||||
vec_op::FP32Vec16 up_vec(input + n + 1, index_vec);
|
||||
gate_vec = gate_vec.min(gate_up_max_vec);
|
||||
up_vec = up_vec.clamp(up_min_vec, gate_up_max_vec);
|
||||
auto sigmoid_vec = one_vec / (one_vec + fast_exp(-gate_vec * alpha_vec));
|
||||
auto glu = gate_vec * sigmoid_vec;
|
||||
auto gated_output_fp32 = (one_vec + up_vec) * glu;
|
||||
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
|
||||
gated_output.save(output + n / 2);
|
||||
}
|
||||
input += input_stride;
|
||||
output += output_stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
|
||||
const int32_t m_size, const int32_t n_size,
|
||||
const int32_t input_stride, const int32_t output_stride) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
const int32_t dim = n_size / 2;
|
||||
float* __restrict__ gate = input;
|
||||
float* __restrict__ up = input + dim;
|
||||
vec_op::FP32Vec16 one_vec(1.0);
|
||||
|
||||
DEFINE_FAST_EXP
|
||||
|
||||
for (int32_t m = 0; m < m_size; ++m) {
|
||||
for (int32_t n = 0; n < dim; n += 16) {
|
||||
vec_op::FP32Vec16 gate_vec(gate + n);
|
||||
vec_op::FP32Vec16 up_vec(up + n);
|
||||
auto sigmoid_vec = one_vec / (one_vec + fast_exp(-gate_vec));
|
||||
auto silu = gate_vec * sigmoid_vec;
|
||||
auto gated_output_fp32 = up_vec * silu;
|
||||
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
|
||||
gated_output.save(output + n);
|
||||
}
|
||||
gate += input_stride;
|
||||
up += input_stride;
|
||||
output += output_stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
|
||||
float* __restrict__ input,
|
||||
scalar_t* __restrict__ output,
|
||||
const int32_t m, const int32_t n,
|
||||
const int32_t input_stride,
|
||||
const int32_t output_stride) {
|
||||
switch (act) {
|
||||
case FusedMOEAct::SwigluOAIAndMul:
|
||||
swigluoai_and_mul(input, output, m, n, input_stride, output_stride);
|
||||
return;
|
||||
case FusedMOEAct::SiluAndMul:
|
||||
silu_and_mul(input, output, m, n, input_stride, output_stride);
|
||||
return;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported act type.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename gemm_t>
|
||||
void prepack_moe_weight_impl(scalar_t* __restrict__ weight_ptr,
|
||||
scalar_t* __restrict__ packed_weight_ptr,
|
||||
const int32_t expert_num,
|
||||
const int32_t output_size,
|
||||
const int32_t input_size,
|
||||
const int64_t expert_stride) {
|
||||
#pragma omp parallel for
|
||||
for (int32_t e_idx = 0; e_idx < expert_num; ++e_idx) {
|
||||
gemm_t::pack_weight(weight_ptr + expert_stride * e_idx,
|
||||
packed_weight_ptr + expert_stride * e_idx, output_size,
|
||||
input_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename w_t, typename gemm_t>
|
||||
void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
|
||||
w_t* __restrict__ w13, w_t* __restrict__ w2,
|
||||
w_t* __restrict__ w13_bias, w_t* __restrict__ w2_bias,
|
||||
float* __restrict__ topk_weights,
|
||||
int32_t* __restrict__ topk_id, FusedMOEAct act_type,
|
||||
const int32_t token_num, const int32_t expert_num,
|
||||
const int32_t topk_num, const int32_t input_size_13,
|
||||
const int32_t output_size_13, const int32_t input_size_2,
|
||||
const int32_t output_size_2) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
|
||||
constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
|
||||
constexpr int32_t min_w13_n_tile_size = 2 * gemm_n_tile_size;
|
||||
static_assert(gemm_n_tile_size % 16 == 0);
|
||||
|
||||
TORCH_CHECK_EQ(output_size_13 % min_w13_n_tile_size, 0);
|
||||
TORCH_CHECK_EQ(output_size_2 % gemm_n_tile_size, 0);
|
||||
TORCH_CHECK_EQ(output_size_13 / 2, input_size_2);
|
||||
|
||||
const int32_t thread_num = omp_get_max_threads();
|
||||
|
||||
const int32_t w13_input_buffer_size = cpu_utils::round_up<64>(
|
||||
gemm_m_tile_size * input_size_13 * sizeof(scalar_t));
|
||||
|
||||
const int32_t w13_n_tile_size = [&]() {
|
||||
const int64_t cache_size = cpu_utils::get_available_l2_size();
|
||||
// input buffer + output buffer + weight
|
||||
const int32_t n_size_cache_limit =
|
||||
(cache_size - w13_input_buffer_size) /
|
||||
(gemm_m_tile_size * sizeof(float) + input_size_13 * sizeof(scalar_t));
|
||||
const int32_t n_size_thread_limit =
|
||||
output_size_13 / std::max(1, thread_num / topk_num);
|
||||
const int32_t n_size = cpu_utils::round_down<min_w13_n_tile_size>(
|
||||
std::min(n_size_cache_limit, n_size_thread_limit));
|
||||
return std::max(n_size, min_w13_n_tile_size);
|
||||
}();
|
||||
|
||||
const int32_t w2_input_tile_size = cpu_utils::round_up<64>(
|
||||
gemm_m_tile_size * input_size_2 * sizeof(scalar_t));
|
||||
|
||||
const int32_t w2_n_tile_size = [&]() {
|
||||
const int64_t cache_size = cpu_utils::get_available_l2_size();
|
||||
// input tile + weight
|
||||
const int32_t n_size_cache_limit =
|
||||
(cache_size - w2_input_tile_size) / (input_size_2 * sizeof(scalar_t));
|
||||
const int32_t n_size_thread_limit =
|
||||
output_size_2 / std::max(1, thread_num / topk_num);
|
||||
const int32_t n_size = cpu_utils::round_down<gemm_n_tile_size>(
|
||||
std::min(n_size_cache_limit, n_size_thread_limit));
|
||||
return std::max(n_size, gemm_n_tile_size);
|
||||
}();
|
||||
|
||||
// allocate buffers
|
||||
int32_t common_buffer_offset = 0;
|
||||
int32_t w13_thread_buffer_offset = 0;
|
||||
int32_t ws_thread_buffer_offset = 0;
|
||||
|
||||
// common buffers
|
||||
const int32_t token_num_per_group_buffer_size =
|
||||
cpu_utils::round_up<64>(expert_num * sizeof(int32_t));
|
||||
const int32_t token_num_per_group_buffer_offset = common_buffer_offset;
|
||||
common_buffer_offset += token_num_per_group_buffer_size;
|
||||
|
||||
const int32_t cu_token_num_per_group_buffer_size =
|
||||
cpu_utils::round_up<64>((expert_num + 1) * sizeof(int32_t));
|
||||
const int32_t cu_token_num_per_group_buffer_offset = common_buffer_offset;
|
||||
common_buffer_offset += cu_token_num_per_group_buffer_size;
|
||||
|
||||
const int32_t expand_token_id_buffer_size =
|
||||
cpu_utils::round_up<64>(token_num * topk_num * sizeof(int32_t));
|
||||
const int32_t expand_token_id_buffer_offset = common_buffer_offset;
|
||||
common_buffer_offset += expand_token_id_buffer_size;
|
||||
|
||||
const int32_t expand_token_id_index_buffer_size =
|
||||
cpu_utils::round_up<64>(token_num * topk_num * sizeof(int32_t));
|
||||
const int32_t expand_token_id_index_buffer_offset = common_buffer_offset;
|
||||
common_buffer_offset += expand_token_id_index_buffer_size;
|
||||
|
||||
const int32_t w13_gemm_output_buffer_size = cpu_utils::round_up<64>(
|
||||
token_num * topk_num * (output_size_13 / 2) * sizeof(scalar_t));
|
||||
const int32_t w13_gemm_output_buffer_offset = common_buffer_offset;
|
||||
common_buffer_offset += w13_gemm_output_buffer_size;
|
||||
|
||||
const int32_t w2_gemm_output_buffer_size = cpu_utils::round_up<64>(
|
||||
token_num * topk_num * output_size_2 * sizeof(float));
|
||||
const int32_t w2_gemm_output_buffer_offset = common_buffer_offset;
|
||||
common_buffer_offset += w2_gemm_output_buffer_size;
|
||||
|
||||
// w13 GEMM thread buffers
|
||||
const int32_t w13_input_buffer_offset = w13_thread_buffer_offset;
|
||||
w13_thread_buffer_offset += w13_input_buffer_size;
|
||||
|
||||
const int32_t w13_output_buffer_size = cpu_utils::round_up<64>(
|
||||
gemm_m_tile_size * w13_n_tile_size * sizeof(float));
|
||||
const int32_t w13_output_buffer_offset = w13_thread_buffer_offset;
|
||||
w13_thread_buffer_offset += w13_output_buffer_size;
|
||||
|
||||
// Weighted sum thread buffer
|
||||
const int32_t ws_output_buffer_size =
|
||||
cpu_utils::round_up<64>(output_size_2 * sizeof(float));
|
||||
const int32_t ws_output_buffer_offset = ws_thread_buffer_offset;
|
||||
ws_thread_buffer_offset += ws_output_buffer_size;
|
||||
|
||||
const int32_t buffer_size =
|
||||
common_buffer_offset +
|
||||
std::max(w13_thread_buffer_offset, ws_thread_buffer_offset) * thread_num;
|
||||
cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(buffer_size);
|
||||
uint8_t* common_buffer_start =
|
||||
cpu_utils::ScratchPadManager::get_scratchpad_manager()
|
||||
->get_data<uint8_t>();
|
||||
uint8_t* thread_buffer_start = common_buffer_start + common_buffer_offset;
|
||||
|
||||
int32_t* __restrict__ token_num_per_group_buffer = reinterpret_cast<int32_t*>(
|
||||
common_buffer_start + token_num_per_group_buffer_offset);
|
||||
int32_t* __restrict__ cu_token_num_per_group_buffer =
|
||||
reinterpret_cast<int32_t*>(common_buffer_start +
|
||||
cu_token_num_per_group_buffer_offset);
|
||||
int32_t* __restrict__ expand_token_id_buffer = reinterpret_cast<int32_t*>(
|
||||
common_buffer_start + expand_token_id_buffer_offset);
|
||||
int32_t* __restrict__ expand_token_id_index_buffer =
|
||||
reinterpret_cast<int32_t*>(common_buffer_start +
|
||||
expand_token_id_index_buffer_offset);
|
||||
|
||||
// prepare token-expert mappings
|
||||
{
|
||||
std::memset(token_num_per_group_buffer, 0, expert_num * sizeof(int32_t));
|
||||
for (int32_t i = 0; i < token_num * topk_num; ++i) {
|
||||
int32_t curr_expert_id = topk_id[i];
|
||||
++token_num_per_group_buffer[curr_expert_id];
|
||||
}
|
||||
|
||||
int32_t token_num_sum = 0;
|
||||
cu_token_num_per_group_buffer[0] = 0;
|
||||
int32_t* token_index_buffer = cu_token_num_per_group_buffer + 1;
|
||||
for (int32_t i = 0; i < expert_num; ++i) {
|
||||
token_index_buffer[i] = token_num_sum;
|
||||
token_num_sum += token_num_per_group_buffer[i];
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < token_num; ++i) {
|
||||
int32_t* curr_topk_id = topk_id + i * topk_num;
|
||||
int32_t* curr_index_buffer = expand_token_id_index_buffer + i * topk_num;
|
||||
for (int32_t j = 0; j < topk_num; ++j) {
|
||||
int32_t curr_expert_id = curr_topk_id[j];
|
||||
int32_t curr_index = token_index_buffer[curr_expert_id];
|
||||
++token_index_buffer[curr_expert_id];
|
||||
expand_token_id_buffer[curr_index] = i;
|
||||
curr_index_buffer[j] = curr_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// w13 GEMM + act
|
||||
{
|
||||
alignas(64) cpu_utils::Counter counter;
|
||||
cpu_utils::Counter* counter_ptr = &counter;
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
|
||||
const int32_t task_num_per_expert =
|
||||
(output_size_13 + w13_n_tile_size - 1) / w13_n_tile_size;
|
||||
const int32_t task_num = task_num_per_expert * expert_num;
|
||||
|
||||
uint8_t* __restrict__ thread_buffer =
|
||||
thread_buffer_start + thread_id * w13_thread_buffer_offset;
|
||||
scalar_t* __restrict__ w13_input_buffer =
|
||||
reinterpret_cast<scalar_t*>(thread_buffer + w13_input_buffer_offset);
|
||||
float* __restrict__ w13_output_buffer =
|
||||
reinterpret_cast<float*>(thread_buffer + w13_output_buffer_offset);
|
||||
scalar_t* __restrict__ w13_gemm_output_buffer =
|
||||
reinterpret_cast<scalar_t*>(common_buffer_start +
|
||||
w13_gemm_output_buffer_offset);
|
||||
|
||||
gemm_t gemm;
|
||||
|
||||
const int32_t input_size_13_bytes = input_size_13 * sizeof(scalar_t);
|
||||
const int32_t w13_n_group_stride = 16 * input_size_13;
|
||||
const int32_t w13_n_tile_stride = gemm_n_tile_size * input_size_13;
|
||||
|
||||
for (;;) {
|
||||
int32_t task_id = counter_ptr->acquire_counter();
|
||||
if (task_id >= task_num) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int32_t curr_expert_id = task_id / task_num_per_expert;
|
||||
const int32_t curr_output_group_id = task_id % task_num_per_expert;
|
||||
const int32_t curr_token_num =
|
||||
token_num_per_group_buffer[curr_expert_id];
|
||||
if (curr_token_num == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int32_t actual_n_tile_size =
|
||||
std::min(w13_n_tile_size,
|
||||
output_size_13 - curr_output_group_id * w13_n_tile_size);
|
||||
const int32_t* __restrict__ curr_expand_token_id_buffer =
|
||||
expand_token_id_buffer +
|
||||
cu_token_num_per_group_buffer[curr_expert_id];
|
||||
scalar_t* __restrict__ curr_w13_gemm_output_buffer =
|
||||
w13_gemm_output_buffer +
|
||||
cu_token_num_per_group_buffer[curr_expert_id] *
|
||||
(output_size_13 / 2) +
|
||||
curr_output_group_id * w13_n_tile_size / 2;
|
||||
|
||||
w_t* __restrict__ w13_weight_ptr_0 = nullptr;
|
||||
w_t* __restrict__ w13_weight_ptr_1 = nullptr;
|
||||
w_t* __restrict__ w13_bias_ptr_0 = nullptr;
|
||||
w_t* __restrict__ w13_bias_ptr_1 = nullptr;
|
||||
if (act_type == FusedMOEAct::SwigluOAIAndMul) {
|
||||
// For SwigluOAIAndMul, up and down weights are interleaved
|
||||
w13_weight_ptr_0 =
|
||||
w13 + curr_expert_id * input_size_13 * output_size_13 +
|
||||
curr_output_group_id * w13_n_tile_size * input_size_13;
|
||||
w13_weight_ptr_1 =
|
||||
w13_weight_ptr_0 + actual_n_tile_size / 2 * input_size_13;
|
||||
if (w13_bias != nullptr) {
|
||||
w13_bias_ptr_0 = w13_bias + curr_expert_id * output_size_13 +
|
||||
curr_output_group_id * w13_n_tile_size;
|
||||
w13_bias_ptr_1 = w13_bias_ptr_0 + actual_n_tile_size / 2;
|
||||
}
|
||||
} else {
|
||||
w13_weight_ptr_0 =
|
||||
w13 + curr_expert_id * input_size_13 * output_size_13 +
|
||||
curr_output_group_id * (w13_n_tile_size / 2) * input_size_13;
|
||||
w13_weight_ptr_1 =
|
||||
w13_weight_ptr_0 + output_size_13 / 2 * input_size_13;
|
||||
if (w13_bias != nullptr) {
|
||||
w13_bias_ptr_0 = w13_bias + curr_expert_id * output_size_13 +
|
||||
curr_output_group_id * (w13_n_tile_size / 2);
|
||||
w13_bias_ptr_1 = w13_bias_ptr_0 + output_size_13 / 2;
|
||||
}
|
||||
}
|
||||
|
||||
scalar_t* __restrict__ curr_w13_input_buffer = w13_input_buffer;
|
||||
for (int32_t token_idx = 0; token_idx < curr_token_num;
|
||||
token_idx += gemm_m_tile_size) {
|
||||
const int32_t actual_token_num =
|
||||
std::min(gemm_m_tile_size, curr_token_num - token_idx);
|
||||
// copy inputs
|
||||
{
|
||||
scalar_t* __restrict__ curr_w13_input_buffer_iter =
|
||||
curr_w13_input_buffer;
|
||||
for (int32_t i = 0; i < actual_token_num; ++i) {
|
||||
const int32_t curr_token_id = curr_expand_token_id_buffer[i];
|
||||
int8_t* __restrict__ curr_input_iter = reinterpret_cast<int8_t*>(
|
||||
input + curr_token_id * input_size_13);
|
||||
int8_t* __restrict__ curr_output_iter =
|
||||
reinterpret_cast<int8_t*>(curr_w13_input_buffer_iter);
|
||||
int32_t j = 0;
|
||||
for (; j < input_size_13_bytes - 64; j += 64) {
|
||||
vec_op::INT8Vec64 vec(curr_input_iter);
|
||||
vec.save(curr_output_iter);
|
||||
curr_input_iter += 64;
|
||||
curr_output_iter += 64;
|
||||
}
|
||||
vec_op::INT8Vec64 vec(curr_input_iter);
|
||||
vec.save(curr_output_iter, input_size_13_bytes - j);
|
||||
|
||||
// update
|
||||
curr_w13_input_buffer_iter += input_size_13;
|
||||
}
|
||||
// update
|
||||
curr_expand_token_id_buffer += actual_token_num;
|
||||
}
|
||||
|
||||
// gemm + act
|
||||
{
|
||||
scalar_t* __restrict__ w13_weight_ptr_0_iter = w13_weight_ptr_0;
|
||||
scalar_t* __restrict__ w13_weight_ptr_1_iter = w13_weight_ptr_1;
|
||||
scalar_t* __restrict__ w13_bias_ptr_0_iter = w13_bias_ptr_0;
|
||||
scalar_t* __restrict__ w13_bias_ptr_1_iter = w13_bias_ptr_1;
|
||||
scalar_t* __restrict__ curr_w13_input_buffer_iter =
|
||||
curr_w13_input_buffer;
|
||||
float* __restrict__ w13_output_buffer_0_iter = w13_output_buffer;
|
||||
float* __restrict__ w13_output_buffer_1_iter =
|
||||
w13_output_buffer + actual_n_tile_size / 2;
|
||||
for (int32_t i = 0; i < actual_n_tile_size;
|
||||
i += min_w13_n_tile_size) {
|
||||
gemm.gemm(curr_w13_input_buffer_iter, w13_weight_ptr_0_iter,
|
||||
w13_output_buffer_0_iter, actual_token_num,
|
||||
input_size_13, input_size_13, w13_n_group_stride,
|
||||
actual_n_tile_size, false);
|
||||
|
||||
if (w13_bias != nullptr) {
|
||||
cpu_micro_gemm::add_bias_epilogue<gemm_n_tile_size>(
|
||||
w13_output_buffer_0_iter, w13_output_buffer_0_iter,
|
||||
w13_bias_ptr_0_iter, actual_token_num, actual_n_tile_size,
|
||||
actual_n_tile_size);
|
||||
w13_bias_ptr_0_iter += gemm_n_tile_size;
|
||||
}
|
||||
|
||||
gemm.gemm(curr_w13_input_buffer_iter, w13_weight_ptr_1_iter,
|
||||
w13_output_buffer_1_iter, actual_token_num,
|
||||
input_size_13, input_size_13, w13_n_group_stride,
|
||||
actual_n_tile_size, false);
|
||||
|
||||
if (w13_bias != nullptr) {
|
||||
cpu_micro_gemm::add_bias_epilogue<gemm_n_tile_size>(
|
||||
w13_output_buffer_1_iter, w13_output_buffer_1_iter,
|
||||
w13_bias_ptr_1_iter, actual_token_num, actual_n_tile_size,
|
||||
actual_n_tile_size);
|
||||
w13_bias_ptr_1_iter += gemm_n_tile_size;
|
||||
}
|
||||
|
||||
// update
|
||||
w13_weight_ptr_0_iter += w13_n_tile_stride;
|
||||
w13_weight_ptr_1_iter += w13_n_tile_stride;
|
||||
w13_output_buffer_0_iter += gemm_n_tile_size;
|
||||
w13_output_buffer_1_iter += gemm_n_tile_size;
|
||||
}
|
||||
|
||||
apply_gated_act(act_type, w13_output_buffer,
|
||||
curr_w13_gemm_output_buffer, actual_token_num,
|
||||
actual_n_tile_size, actual_n_tile_size,
|
||||
output_size_13 / 2);
|
||||
|
||||
// update
|
||||
curr_w13_gemm_output_buffer +=
|
||||
gemm_m_tile_size * (output_size_13 / 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// w2 GEMM
|
||||
{
|
||||
alignas(64) cpu_utils::Counter counter;
|
||||
cpu_utils::Counter* counter_ptr = &counter;
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
|
||||
const int32_t task_num_per_expert =
|
||||
(output_size_2 + w2_n_tile_size - 1) / w2_n_tile_size;
|
||||
const int32_t task_num = task_num_per_expert * expert_num;
|
||||
scalar_t* __restrict__ w13_gemm_output_buffer =
|
||||
reinterpret_cast<scalar_t*>(common_buffer_start +
|
||||
w13_gemm_output_buffer_offset);
|
||||
float* __restrict__ w2_gemm_output_buffer = reinterpret_cast<float*>(
|
||||
common_buffer_start + w2_gemm_output_buffer_offset);
|
||||
|
||||
gemm_t gemm;
|
||||
|
||||
const int32_t w2_n_tile_stride = gemm_n_tile_size * input_size_2;
|
||||
const int32_t w2_n_group_stride = 16 * input_size_2;
|
||||
|
||||
for (;;) {
|
||||
int32_t task_id = counter_ptr->acquire_counter();
|
||||
if (task_id >= task_num) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int32_t curr_expert_id = task_id / task_num_per_expert;
|
||||
const int32_t curr_output_group_id = task_id % task_num_per_expert;
|
||||
const int32_t curr_token_num =
|
||||
token_num_per_group_buffer[curr_expert_id];
|
||||
if (curr_token_num == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int32_t actual_n_tile_size =
|
||||
std::min(w2_n_tile_size,
|
||||
output_size_2 - curr_output_group_id * w2_n_tile_size);
|
||||
scalar_t* __restrict__ curr_w13_gemm_output_buffer =
|
||||
w13_gemm_output_buffer +
|
||||
cu_token_num_per_group_buffer[curr_expert_id] * input_size_2;
|
||||
float* __restrict__ curr_w2_gemm_output_buffer =
|
||||
w2_gemm_output_buffer +
|
||||
cu_token_num_per_group_buffer[curr_expert_id] * output_size_2 +
|
||||
curr_output_group_id * w2_n_tile_size;
|
||||
scalar_t* __restrict__ w2_weight_ptr =
|
||||
w2 + curr_expert_id * output_size_2 * input_size_2 +
|
||||
curr_output_group_id * w2_n_tile_size * input_size_2;
|
||||
scalar_t* __restrict__ w2_bias_ptr = nullptr;
|
||||
if (w2_bias != nullptr) {
|
||||
w2_bias_ptr = w2_bias + curr_expert_id * output_size_2 +
|
||||
curr_output_group_id * w2_n_tile_size;
|
||||
}
|
||||
|
||||
for (int32_t token_idx = 0; token_idx < curr_token_num;
|
||||
token_idx += gemm_m_tile_size) {
|
||||
const int32_t actual_token_num =
|
||||
std::min(gemm_m_tile_size, curr_token_num - token_idx);
|
||||
|
||||
scalar_t* __restrict__ w2_weight_ptr_iter = w2_weight_ptr;
|
||||
scalar_t* __restrict__ w2_bias_ptr_iter = w2_bias_ptr;
|
||||
float* __restrict__ curr_w2_gemm_output_buffer_iter =
|
||||
curr_w2_gemm_output_buffer;
|
||||
for (int32_t i = 0; i < actual_n_tile_size; i += gemm_n_tile_size) {
|
||||
gemm.gemm(curr_w13_gemm_output_buffer, w2_weight_ptr_iter,
|
||||
curr_w2_gemm_output_buffer_iter, actual_token_num,
|
||||
input_size_2, input_size_2, w2_n_group_stride,
|
||||
output_size_2, false);
|
||||
|
||||
if (w2_bias != nullptr) {
|
||||
cpu_micro_gemm::add_bias_epilogue<gemm_n_tile_size>(
|
||||
curr_w2_gemm_output_buffer_iter,
|
||||
curr_w2_gemm_output_buffer_iter, w2_bias_ptr_iter,
|
||||
actual_token_num, output_size_2, output_size_2);
|
||||
w2_bias_ptr_iter += gemm_n_tile_size;
|
||||
}
|
||||
|
||||
w2_weight_ptr_iter += w2_n_tile_stride;
|
||||
curr_w2_gemm_output_buffer_iter += gemm_n_tile_size;
|
||||
}
|
||||
|
||||
// update
|
||||
curr_w13_gemm_output_buffer += gemm_m_tile_size * input_size_2;
|
||||
curr_w2_gemm_output_buffer += gemm_m_tile_size * output_size_2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// weighted sum
|
||||
{
|
||||
alignas(64) cpu_utils::Counter counter;
|
||||
cpu_utils::Counter* counter_ptr = &counter;
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
|
||||
const int32_t task_num = token_num;
|
||||
uint8_t* __restrict__ thread_buffer =
|
||||
thread_buffer_start + thread_id * ws_thread_buffer_offset;
|
||||
float* __restrict__ ws_output_buffer =
|
||||
reinterpret_cast<float*>(thread_buffer + ws_output_buffer_offset);
|
||||
float* __restrict__ w2_gemm_output_buffer = reinterpret_cast<float*>(
|
||||
common_buffer_start + w2_gemm_output_buffer_offset);
|
||||
|
||||
for (;;) {
|
||||
int32_t task_id = counter_ptr->acquire_counter();
|
||||
if (task_id >= task_num) {
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t token_id = task_id;
|
||||
int32_t* __restrict__ curr_expand_token_id_index_buffer =
|
||||
expand_token_id_index_buffer + token_id * topk_num;
|
||||
float* __restrict__ curr_weight = topk_weights + token_id * topk_num;
|
||||
scalar_t* __restrict__ curr_output_buffer =
|
||||
output + token_id * output_size_2;
|
||||
|
||||
if (topk_num > 1) {
|
||||
{
|
||||
int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
|
||||
float* __restrict__ w2_output_iter =
|
||||
w2_gemm_output_buffer + w2_output_idx * output_size_2;
|
||||
float* __restrict__ ws_output_buffer_iter = ws_output_buffer;
|
||||
vec_op::FP32Vec16 weight_vec(curr_weight[0]);
|
||||
for (int32_t i = 0; i < output_size_2; i += 16) {
|
||||
vec_op::FP32Vec16 vec(w2_output_iter);
|
||||
vec = vec * weight_vec;
|
||||
vec.save(ws_output_buffer_iter);
|
||||
|
||||
// update
|
||||
w2_output_iter += 16;
|
||||
ws_output_buffer_iter += 16;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
for (int32_t idx = 1; idx < topk_num - 1; ++idx) {
|
||||
int32_t w2_output_idx = curr_expand_token_id_index_buffer[idx];
|
||||
float* __restrict__ w2_output_iter =
|
||||
w2_gemm_output_buffer + w2_output_idx * output_size_2;
|
||||
float* __restrict__ ws_output_buffer_iter = ws_output_buffer;
|
||||
vec_op::FP32Vec16 weight_vec(curr_weight[idx]);
|
||||
for (int32_t i = 0; i < output_size_2; i += 16) {
|
||||
vec_op::FP32Vec16 vec(w2_output_iter);
|
||||
vec_op::FP32Vec16 sum(ws_output_buffer_iter);
|
||||
sum = sum + vec * weight_vec;
|
||||
sum.save(ws_output_buffer_iter);
|
||||
|
||||
// update
|
||||
w2_output_iter += 16;
|
||||
ws_output_buffer_iter += 16;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
int32_t idx = topk_num - 1;
|
||||
int32_t w2_output_idx = curr_expand_token_id_index_buffer[idx];
|
||||
float* __restrict__ w2_output_iter =
|
||||
w2_gemm_output_buffer + w2_output_idx * output_size_2;
|
||||
float* __restrict__ ws_output_buffer_iter = ws_output_buffer;
|
||||
scalar_t* __restrict__ curr_output_buffer_iter = curr_output_buffer;
|
||||
vec_op::FP32Vec16 weight_vec(curr_weight[idx]);
|
||||
for (int32_t i = 0; i < output_size_2; i += 16) {
|
||||
vec_op::FP32Vec16 vec(w2_output_iter);
|
||||
vec_op::FP32Vec16 sum(ws_output_buffer_iter);
|
||||
sum = sum + vec * weight_vec;
|
||||
scalar_vec_t out_vec(sum);
|
||||
out_vec.save(curr_output_buffer_iter);
|
||||
|
||||
// update
|
||||
w2_output_iter += 16;
|
||||
ws_output_buffer_iter += 16;
|
||||
curr_output_buffer_iter += 16;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
|
||||
float* __restrict__ w2_output_iter =
|
||||
w2_gemm_output_buffer + w2_output_idx * output_size_2;
|
||||
scalar_t* __restrict__ curr_output_buffer_iter = curr_output_buffer;
|
||||
vec_op::FP32Vec16 weight_vec(curr_weight[0]);
|
||||
for (int32_t i = 0; i < output_size_2; i += 16) {
|
||||
vec_op::FP32Vec16 vec(w2_output_iter);
|
||||
vec = vec * weight_vec;
|
||||
scalar_vec_t out_vec(vec);
|
||||
out_vec.save(curr_output_buffer_iter);
|
||||
|
||||
// update
|
||||
w2_output_iter += 16;
|
||||
curr_output_buffer_iter += 16;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void prepack_moe_weight(
|
||||
const torch::Tensor& weight, // [expert_num, output_size, input_size]
|
||||
torch::Tensor& packed_weight, const std::string& isa) {
|
||||
TORCH_CHECK(weight.is_contiguous());
|
||||
const int32_t expert_num = weight.size(0);
|
||||
const int32_t output_size = weight.size(1);
|
||||
const int32_t input_size = weight.size(2);
|
||||
TORCH_CHECK_EQ(output_size % 32, 0);
|
||||
const int64_t expert_stride = weight.stride(0);
|
||||
cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
weight.scalar_type(), "prepack_moe_weight", [&]() {
|
||||
CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
|
||||
scalar_t* weight_ptr = weight.data_ptr<scalar_t>();
|
||||
scalar_t* packed_weight_ptr = packed_weight.data_ptr<scalar_t>();
|
||||
prepack_moe_weight_impl<scalar_t, gemm_t>(
|
||||
weight_ptr, packed_weight_ptr, expert_num, output_size,
|
||||
input_size, expert_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void cpu_fused_moe(
|
||||
torch::Tensor& output, // [token_num, output_size_2]
|
||||
const torch::Tensor& input, // [token_num, input_size_13]
|
||||
const torch::Tensor&
|
||||
w13, // [expert_num, output_size_13, input_size_13], packed
|
||||
const torch::Tensor&
|
||||
w2, // [expert_num, output_size_2, input_size_2], packed
|
||||
const std::optional<torch::Tensor>&
|
||||
w13_bias, // [expert_num, output_size_13]
|
||||
const std::optional<torch::Tensor>& w2_bias, // [expert_num, output_size_2]
|
||||
const torch::Tensor& topk_weights, // [token_num, k], float32
|
||||
const torch::Tensor& topk_id, // [token_num, k], int32
|
||||
const std::string& act, const std::string& isa) {
|
||||
const int32_t token_num = input.size(0);
|
||||
const int32_t input_size_13 = input.size(1);
|
||||
const int64_t input_stride = input.stride(0);
|
||||
TORCH_CHECK_EQ(input_stride, input_size_13);
|
||||
const int32_t expert_num = w13.size(0);
|
||||
const int32_t output_size_13 = w13.size(1);
|
||||
const int32_t input_size_2 = w2.size(2);
|
||||
const int32_t output_size_2 = w2.size(1);
|
||||
const int32_t topk_num = topk_id.size(1);
|
||||
const FusedMOEAct act_type = get_act_type(act);
|
||||
cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() {
|
||||
CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
|
||||
fused_moe_impl<scalar_t, scalar_t, gemm_t>(
|
||||
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
w13.data_ptr<scalar_t>(), w2.data_ptr<scalar_t>(),
|
||||
w13_bias.has_value() ? w13_bias->data_ptr<scalar_t>() : nullptr,
|
||||
w2_bias.has_value() ? w2_bias->data_ptr<scalar_t>() : nullptr,
|
||||
topk_weights.data_ptr<float>(), topk_id.data_ptr<int32_t>(), act_type,
|
||||
token_num, expert_num, topk_num, input_size_13, output_size_13,
|
||||
input_size_2, output_size_2);
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -352,6 +352,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
explicit FP32Vec16(bool, void* ptr)
|
||||
: reg((__m512)_mm512_stream_load_si512(ptr)) {}
|
||||
|
||||
// strided load
|
||||
explicit FP32Vec16(const float* ptr, INT32Vec16 idx)
|
||||
: reg(_mm512_i32gather_ps(idx.reg, ptr, 4)) {}
|
||||
|
||||
explicit FP32Vec16(__m512 data) : reg(data) {}
|
||||
|
||||
// de-pack 4 bit values
|
||||
@ -408,6 +412,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-() const {
|
||||
return FP32Vec16(_mm512_xor_ps(reg, _mm512_set1_ps(-0.0f)));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_div_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
#include "cpu_types.hpp"
|
||||
#include "scratchpad_manager.h"
|
||||
#include "utils.hpp"
|
||||
#include "cpu/cpu_types.hpp"
|
||||
#include "cpu/utils.hpp"
|
||||
|
||||
#ifdef CPU_CAPABILITY_AMXBF16
|
||||
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
|
||||
@ -158,7 +157,7 @@ void cpu_gemm_wna16_impl(
|
||||
// a simple schedule policy, just to hold more B tiles in L2 and make sure
|
||||
// each thread has tasks
|
||||
const int32_t n_partition_size = [&]() {
|
||||
const int64_t cache_size = cpu_utils::get_l2_size();
|
||||
const int64_t cache_size = cpu_utils::get_available_l2_size();
|
||||
int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
|
||||
int64_t ps_thread_limit = n_size / thread_num;
|
||||
ps_cache_limit =
|
||||
@ -179,8 +178,8 @@ void cpu_gemm_wna16_impl(
|
||||
const int64_t b_buffer_offset = 0;
|
||||
const int64_t c_buffer_offset = b_buffer_size;
|
||||
const int64_t buffer_size = b_buffer_size + c_buffer_size;
|
||||
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size *
|
||||
thread_num);
|
||||
cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(buffer_size *
|
||||
thread_num);
|
||||
|
||||
alignas(64) cpu_utils::Counter counter;
|
||||
cpu_utils::Counter* counter_ptr = &counter;
|
||||
@ -190,9 +189,10 @@ void cpu_gemm_wna16_impl(
|
||||
scalar_t* __restrict__ b_buffer = nullptr;
|
||||
float* __restrict__ c_buffer = nullptr;
|
||||
{
|
||||
uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager()
|
||||
->get_data<uint8_t>() +
|
||||
thread_id * buffer_size;
|
||||
uint8_t* buffer_ptr =
|
||||
cpu_utils::ScratchPadManager::get_scratchpad_manager()
|
||||
->get_data<uint8_t>() +
|
||||
thread_id * buffer_size;
|
||||
b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset);
|
||||
c_buffer = reinterpret_cast<float*>(buffer_ptr + c_buffer_offset);
|
||||
}
|
||||
|
||||
@ -4,8 +4,8 @@
|
||||
#include "common/memory_desc.hpp"
|
||||
#include "common/memory.hpp"
|
||||
|
||||
#include "dnnl_helper.h"
|
||||
#include "scratchpad_manager.h"
|
||||
#include "cpu/utils.hpp"
|
||||
#include "cpu/dnnl_helper.h"
|
||||
|
||||
static dnnl::engine& default_engine() {
|
||||
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
|
||||
@ -274,7 +274,7 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
|
||||
|
||||
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5);
|
||||
scratchpad_storage->set_data_handle(
|
||||
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
|
||||
cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data<void>());
|
||||
|
||||
matmul.execute(default_stream(), memory_cache_);
|
||||
default_stream().wait();
|
||||
@ -294,7 +294,7 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
|
||||
|
||||
return m_size_cache_->get_or_create(key, [&]() {
|
||||
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
|
||||
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
|
||||
auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager();
|
||||
manager->realloc(desc.scratchpad_desc().get_size());
|
||||
return dnnl::matmul(desc);
|
||||
});
|
||||
@ -470,7 +470,7 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
|
||||
|
||||
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
|
||||
scratchpad_storage->set_data_handle(
|
||||
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
|
||||
cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data<void>());
|
||||
|
||||
matmul.execute(default_stream(), memory_cache_);
|
||||
default_stream().wait();
|
||||
@ -486,7 +486,7 @@ dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
|
||||
}
|
||||
return m_size_cache_->get_or_create(key, [&]() {
|
||||
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
|
||||
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
|
||||
auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager();
|
||||
manager->realloc(desc.scratchpad_desc().get_size());
|
||||
return dnnl::matmul(desc);
|
||||
});
|
||||
|
||||
@ -235,6 +235,39 @@ class MicroGemm<cpu_utils::ISA::AMX, scalar_t> {
|
||||
}
|
||||
}
|
||||
|
||||
static void pack_weight(const scalar_t* __restrict__ weight,
|
||||
scalar_t* __restrict__ packed_weight,
|
||||
const int32_t output_size, const int32_t input_size) {
|
||||
constexpr int32_t elem_num_per_group = 4 / sizeof(scalar_t);
|
||||
TORCH_CHECK_EQ(output_size % 16, 0);
|
||||
TORCH_CHECK_EQ(input_size % (16 * elem_num_per_group), 0);
|
||||
|
||||
const int32_t output_group_num = output_size / 16;
|
||||
const int32_t input_32b_num = input_size / elem_num_per_group;
|
||||
for (int32_t output_group_idx = 0; output_group_idx < output_group_num;
|
||||
++output_group_idx) {
|
||||
const int32_t* __restrict__ weight_32b =
|
||||
reinterpret_cast<const int32_t*>(weight);
|
||||
int32_t* __restrict__ packed_weight_32b =
|
||||
reinterpret_cast<int32_t*>(packed_weight);
|
||||
for (int32_t output_idx = 0; output_idx < 16; ++output_idx) {
|
||||
for (int32_t weight_offset = 0, packed_offset = 0;
|
||||
weight_offset < input_32b_num;
|
||||
++weight_offset, packed_offset += 16) {
|
||||
packed_weight_32b[packed_offset] = weight_32b[weight_offset];
|
||||
}
|
||||
|
||||
// update
|
||||
weight_32b += input_32b_num;
|
||||
packed_weight_32b += 1;
|
||||
}
|
||||
|
||||
// update
|
||||
weight += 16 * input_size;
|
||||
packed_weight += 16 * input_size;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
alignas(64) __tilecfg amx_tile_config_;
|
||||
int32_t curr_m_;
|
||||
|
||||
@ -13,6 +13,9 @@ namespace cpu_micro_gemm {
|
||||
#define CPU_MICRO_GEMM_PARAMS \
|
||||
a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
|
||||
|
||||
// Note: weights for MicroGemm should be packed as (output_size / 16) contiguous
|
||||
// blocks, means the logical shape of blocks is [16, input_size]. And the actual
|
||||
// layout of blocks can be ISA-specific.
|
||||
template <cpu_utils::ISA isa, typename scalar_t>
|
||||
class MicroGemm {
|
||||
public:
|
||||
@ -86,6 +89,41 @@ FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
|
||||
curr_d += ldd;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t n_size, typename scalar_t>
|
||||
FORCE_INLINE void add_bias_epilogue(float* c_ptr, float* d_ptr,
|
||||
scalar_t* __restrict__ bias_ptr,
|
||||
const int32_t m, const int64_t ldc,
|
||||
const int64_t ldd) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
static_assert(n_size % 16 == 0);
|
||||
constexpr int32_t n_group_num = n_size / 16;
|
||||
static_assert(n_group_num <= 16);
|
||||
|
||||
vec_op::FP32Vec16 bias_vecs[n_group_num];
|
||||
scalar_t* __restrict__ curr_bias = bias_ptr;
|
||||
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t i) {
|
||||
scalar_vec_t vec(curr_bias);
|
||||
bias_vecs[i] = vec_op::FP32Vec16(vec);
|
||||
curr_bias += 16;
|
||||
});
|
||||
|
||||
float* curr_c = c_ptr;
|
||||
float* curr_d = d_ptr;
|
||||
for (int32_t i = 0; i < m; ++i) {
|
||||
float* curr_c_iter = curr_c;
|
||||
float* curr_d_iter = curr_d;
|
||||
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t n_g_idx) {
|
||||
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
|
||||
c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx];
|
||||
c_vec_fp32.save(curr_d_iter);
|
||||
curr_c_iter += 16;
|
||||
curr_d_iter += 16;
|
||||
});
|
||||
curr_c += ldc;
|
||||
curr_d += ldd;
|
||||
}
|
||||
}
|
||||
} // namespace cpu_micro_gemm
|
||||
|
||||
#endif
|
||||
|
||||
@ -109,6 +109,25 @@ class MicroGemm<cpu_utils::ISA::VEC, scalar_t> {
|
||||
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
TileGemm82<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
|
||||
}
|
||||
|
||||
// Note: pack contiguous weight [output_size, input_size] as contiguous
|
||||
// packed weight [output_size / 16, input_size, 16]
|
||||
static void pack_weight(const scalar_t* __restrict__ weight,
|
||||
scalar_t* __restrict__ packed_weight,
|
||||
const int32_t output_size, const int32_t input_size) {
|
||||
TORCH_CHECK_EQ(output_size % 16, 0);
|
||||
for (int32_t o_idx = 0; o_idx < output_size; ++o_idx) {
|
||||
const scalar_t* __restrict__ curr_weight = weight + o_idx * input_size;
|
||||
scalar_t* __restrict__ curr_packed_weight =
|
||||
packed_weight + (o_idx / 16) * (16 * input_size) + o_idx % 16;
|
||||
for (int32_t i_idx = 0; i_idx < input_size; ++i_idx) {
|
||||
*curr_packed_weight = *curr_weight;
|
||||
|
||||
curr_packed_weight += 16;
|
||||
++curr_weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace cpu_micro_gemm
|
||||
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
#include <cstdlib>
|
||||
|
||||
#include "scratchpad_manager.h"
|
||||
|
||||
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
|
||||
this->realloc(allocation_unit * 128);
|
||||
}
|
||||
|
||||
void DNNLScratchPadManager::realloc(size_t new_size) {
|
||||
new_size = round(new_size);
|
||||
if (new_size > size_) {
|
||||
if (ptr_ != nullptr) {
|
||||
std::free(ptr_);
|
||||
}
|
||||
ptr_ = std::aligned_alloc(64, new_size);
|
||||
size_ = new_size;
|
||||
}
|
||||
}
|
||||
|
||||
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
|
||||
static DNNLScratchPadManager manager;
|
||||
return &manager;
|
||||
}
|
||||
@ -1,31 +0,0 @@
|
||||
#ifndef SCRATCHPAD_MANAGER_H
|
||||
#define SCRATCHPAD_MANAGER_H
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
|
||||
class DNNLScratchPadManager {
|
||||
public:
|
||||
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
|
||||
|
||||
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
|
||||
|
||||
DNNLScratchPadManager();
|
||||
|
||||
template <typename T>
|
||||
T* get_data() {
|
||||
return reinterpret_cast<T*>(ptr_);
|
||||
}
|
||||
|
||||
static size_t round(size_t size) {
|
||||
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
|
||||
}
|
||||
|
||||
void realloc(size_t new_size);
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
void* ptr_;
|
||||
};
|
||||
|
||||
#endif
|
||||
@ -110,6 +110,17 @@ void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
|
||||
const std::optional<torch::Tensor>& bias,
|
||||
const int64_t pack_factor, const std::string& isa_hint);
|
||||
|
||||
void prepack_moe_weight(const torch::Tensor& weight,
|
||||
torch::Tensor& packed_weight, const std::string& isa);
|
||||
|
||||
void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
|
||||
const torch::Tensor& w13, const torch::Tensor& w2,
|
||||
const std::optional<torch::Tensor>& w13_bias,
|
||||
const std::optional<torch::Tensor>& w2_bias,
|
||||
const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& topk_id, const std::string& act,
|
||||
const std::string& isa);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@ -296,6 +307,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"pack_factor, str isa_hint) -> ()");
|
||||
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
|
||||
#endif
|
||||
|
||||
// fused moe
|
||||
#if defined(__AVX512F__)
|
||||
ops.def(
|
||||
"prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
|
||||
"-> ()");
|
||||
ops.impl("prepack_moe_weight", torch::kCPU, &prepack_moe_weight);
|
||||
ops.def(
|
||||
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
|
||||
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
|
||||
"str act, str isa) -> ()");
|
||||
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
#define gettid() syscall(SYS_gettid)
|
||||
#endif
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
#include "cpu/utils.hpp"
|
||||
|
||||
#ifdef VLLM_NUMA_DISABLED
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
@ -138,4 +138,26 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
#endif
|
||||
#endif // VLLM_NUMA_DISABLED
|
||||
|
||||
namespace cpu_utils {
|
||||
ScratchPadManager::ScratchPadManager() : size_(0), ptr_(nullptr) {
|
||||
this->realloc(allocation_unit * 128);
|
||||
}
|
||||
|
||||
void ScratchPadManager::realloc(size_t new_size) {
|
||||
new_size = round(new_size);
|
||||
if (new_size > size_) {
|
||||
if (ptr_ != nullptr) {
|
||||
std::free(ptr_);
|
||||
}
|
||||
ptr_ = std::aligned_alloc(64, new_size);
|
||||
size_ = new_size;
|
||||
}
|
||||
}
|
||||
|
||||
ScratchPadManager* ScratchPadManager::get_scratchpad_manager() {
|
||||
static ScratchPadManager manager;
|
||||
return &manager;
|
||||
}
|
||||
} // namespace cpu_utils
|
||||
|
||||
@ -2,19 +2,24 @@
|
||||
#define UTILS_HPP
|
||||
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <unistd.h>
|
||||
#include <ATen/cpu/Utils.h>
|
||||
|
||||
#if defined(__APPLE__)
|
||||
#include <sys/sysctl.h>
|
||||
#endif
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
#include "cpu/cpu_types.hpp"
|
||||
|
||||
namespace cpu_utils {
|
||||
enum class ISA { AMX, VEC };
|
||||
|
||||
inline ISA get_isa(const std::string& isa) {
|
||||
if (isa == "amx") {
|
||||
return ISA::AMX;
|
||||
} else if (isa == "vec") {
|
||||
return ISA::VEC;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid isa type: " + isa);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct VecTypeTrait {
|
||||
using vec_t = void;
|
||||
@ -32,10 +37,12 @@ struct VecTypeTrait<c10::BFloat16> {
|
||||
};
|
||||
#endif
|
||||
|
||||
#if !defined(__powerpc__)
|
||||
template <>
|
||||
struct VecTypeTrait<c10::Half> {
|
||||
using vec_t = vec_op::FP16Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
struct Counter {
|
||||
std::atomic<int64_t> counter;
|
||||
@ -48,26 +55,66 @@ struct Counter {
|
||||
int64_t acquire_counter() { return counter++; }
|
||||
};
|
||||
|
||||
inline int64_t get_l2_size() {
|
||||
inline int64_t get_available_l2_size() {
|
||||
static int64_t size = []() {
|
||||
#if defined(__APPLE__)
|
||||
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
|
||||
int64_t l2_cache_size = 0;
|
||||
size_t len = sizeof(l2_cache_size);
|
||||
if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
|
||||
l2_cache_size > 0) {
|
||||
return l2_cache_size >> 1; // use 50% of L2 cache
|
||||
}
|
||||
// Fallback if sysctlbyname fails
|
||||
return 128LL * 1024 >> 1; // use 50% of 128KB
|
||||
#else
|
||||
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
|
||||
assert(l2_cache_size != -1);
|
||||
const uint32_t l2_cache_size = at::cpu::L2_cache_size();
|
||||
return l2_cache_size >> 1; // use 50% of L2 cache
|
||||
#endif
|
||||
}();
|
||||
return size;
|
||||
}
|
||||
|
||||
template <int32_t alignment_v, typename T>
|
||||
inline T round_up(T size) {
|
||||
T alignment = alignment_v;
|
||||
return (((size + alignment - 1) / alignment) * alignment);
|
||||
}
|
||||
|
||||
template <int32_t alignment_v, typename T>
|
||||
inline T round_down(T size) {
|
||||
T alignment = alignment_v;
|
||||
return (size / alignment) * alignment;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
|
||||
int32_t stride) {
|
||||
std::stringstream ss;
|
||||
ss << std::fixed << std::setprecision(5) << name << ": [\n";
|
||||
auto* curr_logits_buffer = ptr;
|
||||
for (int32_t m = 0; m < row; ++m) {
|
||||
for (int32_t n = 0; n < col; ++n) {
|
||||
ss << curr_logits_buffer[n] << ", ";
|
||||
}
|
||||
ss << "\n";
|
||||
curr_logits_buffer += stride;
|
||||
}
|
||||
ss << "]\n";
|
||||
std::printf("%s", ss.str().c_str());
|
||||
}
|
||||
|
||||
class ScratchPadManager {
|
||||
public:
|
||||
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
|
||||
|
||||
static ScratchPadManager* get_scratchpad_manager();
|
||||
|
||||
ScratchPadManager();
|
||||
|
||||
template <typename T>
|
||||
T* get_data() {
|
||||
return reinterpret_cast<T*>(ptr_);
|
||||
}
|
||||
|
||||
static size_t round(size_t size) {
|
||||
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
|
||||
}
|
||||
|
||||
void realloc(size_t new_size);
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
void* ptr_;
|
||||
};
|
||||
} // namespace cpu_utils
|
||||
|
||||
#endif
|
||||
|
||||
@ -107,6 +107,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
|
||||
prop.location.id = device;
|
||||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
|
||||
|
||||
#ifndef USE_ROCM
|
||||
int flag = 0;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED,
|
||||
device));
|
||||
if (flag) { // support GPUDirect RDMA if possible
|
||||
prop.allocFlags.gpuDirectRDMACapable = 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Allocate memory using cuMemCreate
|
||||
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
|
||||
|
||||
@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) {
|
||||
|
||||
template <ScoringFunc SF, typename T>
|
||||
__device__ inline T apply_scoring(T val) {
|
||||
if constexpr (SF == SCORING_SIGMOID) {
|
||||
if constexpr (SF == SCORING_NONE) {
|
||||
return val;
|
||||
} else if constexpr (SF == SCORING_SIGMOID) {
|
||||
return apply_sigmoid(val);
|
||||
} else {
|
||||
static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
|
||||
"Unsupported ScoringFunc in apply_scoring");
|
||||
return val;
|
||||
}
|
||||
}
|
||||
@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
if (if_proceed_next_topk) {
|
||||
float scale = routed_scaling_factor;
|
||||
if (renormalize) {
|
||||
scale /= topk_sum;
|
||||
}
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float base = cuda_cast<float, T>(s_topk_value[i]);
|
||||
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
|
||||
: (base * routed_scaling_factor);
|
||||
float value = base * scale;
|
||||
topk_indices[i] = s_topk_idx[i];
|
||||
topk_values[i] = value;
|
||||
}
|
||||
|
||||
1
csrc/moe/marlin_moe_wna16/.gitignore
vendored
1
csrc/moe/marlin_moe_wna16/.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
kernel_*.cu
|
||||
|
||||
@ -10,6 +10,8 @@ import jinja2
|
||||
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
SUPPORT_SM75 = False
|
||||
SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
if arch == 75:
|
||||
SUPPORT_SM75 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
@ -157,6 +163,7 @@ def remove_old_kernels():
|
||||
|
||||
def generate_new_kernels():
|
||||
result_dict = {}
|
||||
sm_75_result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
@ -174,6 +181,8 @@ def generate_new_kernels():
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
|
||||
sm_75_result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
@ -197,78 +206,89 @@ def generate_new_kernels():
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"stages": 4,
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if SUPPORT_SM80:
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
|
||||
config_sm75 = config.copy()
|
||||
config_sm75["stages"] = 2
|
||||
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
for result_dict_tmp in [result_dict, sm_75_result_dict]:
|
||||
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
|
||||
all_template_str_list = []
|
||||
if not config_list:
|
||||
continue
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"stages == {config['stages']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
filename = filename.lower()
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
elif result_dict_tmp is sm_75_result_dict:
|
||||
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
#include "quantization/gptq_marlin/marlin.cuh"
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "quantization/gptq_marlin/dequant.h"
|
||||
#include "quantization/gptq_marlin/marlin_mma.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@ -35,7 +36,7 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
|
||||
@ -84,146 +85,6 @@ __global__ void Marlin(
|
||||
|
||||
#else
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
template <int count, vllm::ScalarTypeId type_id>
|
||||
@ -439,9 +300,20 @@ __global__ void Marlin(
|
||||
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
// Turing TensorCore only supports fp16 and int8
|
||||
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
|
||||
return;
|
||||
#endif
|
||||
|
||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
using Adtype = MarlinScalarType<a_type_id>;
|
||||
using Cdtype = MarlinScalarType<c_type_id>;
|
||||
|
||||
@ -618,7 +490,22 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
|
||||
if constexpr (moe_block_size >= 16)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16);
|
||||
if constexpr (moe_block_size >= 8)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8);
|
||||
if constexpr (moe_block_size >= 4)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4);
|
||||
if constexpr (moe_block_size >= 2)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2);
|
||||
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1);
|
||||
block_num_valid_tokens = local_count;
|
||||
#else
|
||||
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
|
||||
#endif
|
||||
|
||||
if (lane_id == 0)
|
||||
reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
|
||||
@ -1018,10 +905,6 @@ __global__ void Marlin(
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
// shared memory reused by reduction should be smaller than
|
||||
// shared memory used by weight.
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
@ -1545,11 +1428,13 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
} else {
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
|
||||
frag_c[i][j][0]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
|
||||
frag_c[i][j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1583,10 +1468,12 @@ __global__ void Marlin(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
}
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
@ -2132,6 +2019,21 @@ __global__ void Marlin(
|
||||
// While this pattern may not be the most readable, other ways of writing
|
||||
// the loop seemed to noticeably worse performance after compilation.
|
||||
if (slice_iters == 0) {
|
||||
// convert fp16 accum to fp32 for reduction
|
||||
if constexpr (use_fp16_accum) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
|
||||
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
|
||||
scalar_t* frag_c_part_half =
|
||||
reinterpret_cast<scalar_t*>(frag_c_part_float);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_a_8bit) {
|
||||
float frag_a_s[2 * thread_m_blocks];
|
||||
|
||||
|
||||
@ -142,7 +142,7 @@ typedef struct {
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool has_act_order, bool is_k_full, int stages) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
@ -160,13 +160,13 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
return tb_scales * stages;
|
||||
}
|
||||
}
|
||||
|
||||
@ -174,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n,
|
||||
int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full, int has_zp,
|
||||
int is_zp_float, bool is_a_8bit) {
|
||||
int is_zp_float, bool is_a_8bit, int stages) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
@ -185,8 +185,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
|
||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
int sh_block_meta_size = tb_m * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
@ -195,8 +195,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
group_size, has_act_order, is_k_full, stages);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
@ -217,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
|
||||
int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, int has_zp, int is_zp_float,
|
||||
int max_shared_mem, bool is_a_8bit) {
|
||||
bool is_a_8bit, int stages, int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
@ -243,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int cache_size =
|
||||
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
@ -252,7 +252,7 @@ MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int threads, bool is_zp_float, int stages) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
|
||||
@ -266,8 +266,8 @@ exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
|
||||
bool is_a_8bit) {
|
||||
bool is_k_full, bool has_zp, bool is_zp_float, bool is_a_8bit, int stages,
|
||||
int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@ -284,15 +284,15 @@ exec_config_t determine_exec_config(
|
||||
|
||||
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
|
||||
is_a_8bit)) {
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
is_a_8bit);
|
||||
is_a_8bit, stages);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
@ -303,7 +303,7 @@ exec_config_t determine_exec_config(
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
th_config.num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
@ -433,8 +433,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
|
||||
"marlin kernel only support Turing or newer GPUs.");
|
||||
int stages = 4;
|
||||
if (major_capability == 7 && minor_capability == 5) {
|
||||
stages = 2;
|
||||
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
|
||||
"Turing only support FP16 or INT8 activation.");
|
||||
}
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
|
||||
"FP8 only support Ada Lovelace or newer GPUs.");
|
||||
@ -461,8 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
|
||||
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
|
||||
is_a_8bit);
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem, sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
}
|
||||
|
||||
@ -479,7 +485,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
|
||||
prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem, is_a_8bit),
|
||||
is_a_8bit, stages, max_shared_mem),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
@ -493,12 +499,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int sh_cache_size =
|
||||
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
|
||||
1
csrc/quantization/gptq_marlin/.gitignore
vendored
1
csrc/quantization/gptq_marlin/.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
kernel_*.cu
|
||||
|
||||
@ -67,7 +67,7 @@ where `scale_factor * multiplier` can be computed at weight loading.
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
|
||||
@ -10,6 +10,8 @@ import jinja2
|
||||
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
SUPPORT_SM75 = False
|
||||
SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
if arch == 75:
|
||||
SUPPORT_SM75 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
@ -166,6 +172,7 @@ def remove_old_kernels():
|
||||
|
||||
def generate_new_kernels():
|
||||
result_dict = {}
|
||||
sm_75_result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
@ -184,6 +191,8 @@ def generate_new_kernels():
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
|
||||
sm_75_result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
@ -207,78 +216,89 @@ def generate_new_kernels():
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"stages": 4,
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "true" if is_zp_float else "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if SUPPORT_SM80:
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
|
||||
config_sm75 = config.copy()
|
||||
config_sm75["stages"] = 2
|
||||
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
for result_dict_tmp in [result_dict, sm_75_result_dict]:
|
||||
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
|
||||
all_template_str_list = []
|
||||
if not config_list:
|
||||
continue
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"stages == {config['stages']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
filename = filename.lower()
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
elif result_dict_tmp is sm_75_result_dict:
|
||||
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
|
||||
@ -37,7 +37,7 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
|
||||
|
||||
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int const* __restrict__ perm_int_ptr,
|
||||
@ -148,7 +148,7 @@ typedef struct {
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool has_act_order, bool is_k_full, int stages) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
@ -166,28 +166,29 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
return tb_scales * stages;
|
||||
}
|
||||
}
|
||||
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float) {
|
||||
int has_zp, bool is_zp_float, bool is_a_8bit,
|
||||
int stages) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
@ -196,8 +197,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
group_size, has_act_order, is_k_full, stages);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
@ -217,7 +218,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float, int max_shared_mem) {
|
||||
int has_zp, bool is_zp_float, bool is_a_8bit, int stages,
|
||||
int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
@ -242,7 +244,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
// Check that pipeline fits into cache
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
@ -251,7 +253,7 @@ MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int threads, bool is_zp_float, int stages) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
|
||||
@ -265,7 +267,8 @@ exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
|
||||
int num_bits, int group_size, bool has_act_order, bool is_k_full,
|
||||
bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
|
||||
bool has_zp, bool is_zp_float, int is_a_8bit, int stages,
|
||||
int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@ -280,13 +283,15 @@ exec_config_t determine_exec_config(
|
||||
|
||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, max_shared_mem - 512)) {
|
||||
is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
int cache_size = get_kernel_cache_size(th_config, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, is_a_8bit, stages);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
@ -297,14 +302,10 @@ exec_config_t determine_exec_config(
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
th_config.num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
// int m_tiles = div_ceil(prob_m, thread_m_blocks * 16);
|
||||
// int n_tiles = prob_n / th_config.thread_n;
|
||||
// int k_tiles = prob_k / th_config.thread_k;
|
||||
|
||||
return {1, th_config};
|
||||
}
|
||||
|
||||
@ -321,6 +322,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k_init,
|
||||
int thread_n_init, int sms, bool use_atomic_add,
|
||||
bool use_fp32_reduce, bool is_zp_float) {
|
||||
bool is_a_8bit = a_type.size_bits() == 8;
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
@ -389,8 +391,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
|
||||
"marlin kernel only support Turing or newer GPUs.");
|
||||
int stages = 4;
|
||||
if (major_capability == 7 && minor_capability == 5) {
|
||||
stages = 2;
|
||||
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
|
||||
"Turing only support FP16 or INT8 activation.");
|
||||
}
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(
|
||||
major_capability * 10 + minor_capability == 89 ||
|
||||
@ -431,7 +439,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
|
||||
thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages, max_shared_mem,
|
||||
sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
if (thread_tfg.thread_n != -1) {
|
||||
if (prob_n / thread_tfg.thread_n *
|
||||
@ -440,7 +449,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem_new)) {
|
||||
is_a_8bit, stages, max_shared_mem_new)) {
|
||||
thread_tfg = {128, 64, 128};
|
||||
exec_cfg = {1, thread_tfg};
|
||||
}
|
||||
@ -466,7 +475,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
TORCH_CHECK(
|
||||
is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n,
|
||||
prob_k, num_bits, group_size, has_act_order, is_k_full,
|
||||
has_zp, is_zp_float, max_shared_mem_new),
|
||||
has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem_new),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
@ -475,12 +485,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
", prob_m_split = ", prob_m_split, ", group_size = ", group_size,
|
||||
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem_new = ", max_shared_mem_new);
|
||||
", stages = ", stages, ", max_shared_mem_new = ", max_shared_mem_new);
|
||||
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
|
||||
@ -1,17 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#ifndef _marlin_cuh
|
||||
#define _marlin_cuh
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
@ -51,9 +53,51 @@ using I4 = Vec<int, 4>;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int32_t*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int32_t*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int64_t*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int64_t*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
|
||||
__device__ inline void cp_async_fence() {}
|
||||
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {}
|
||||
|
||||
#else
|
||||
|
||||
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
@ -126,6 +170,8 @@ __device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
269
csrc/quantization/gptq_marlin/marlin_mma.h
Normal file
269
csrc/quantization/gptq_marlin/marlin_mma.h
Normal file
@ -0,0 +1,269 @@
|
||||
|
||||
#include "marlin_dtypes.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
|
||||
static_assert(!use_fp16_accum);
|
||||
}
|
||||
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
#else
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value &&
|
||||
use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
|
||||
#else
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[1]), "r"(b[0]), "r"(c[2]), "r"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[2]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[3]), "r"(b[1]), "r"(c[2]), "r"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
|
||||
static_assert(!use_fp16_accum);
|
||||
}
|
||||
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
#else
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value &&
|
||||
use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
|
||||
#else
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b2[1]), "r"(a[0]), "r"(c[2]), "r"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b2[1]), "r"(a[1]), "r"(c[2]), "r"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
@ -26,6 +26,7 @@
|
||||
#include "marlin.cuh"
|
||||
#include "marlin_dtypes.cuh"
|
||||
#include "dequant.h"
|
||||
#include "marlin_mma.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@ -35,7 +36,7 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
|
||||
@ -75,137 +76,6 @@ __global__ void Marlin(
|
||||
|
||||
#else
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
template <int count, vllm::ScalarTypeId type_id>
|
||||
@ -415,6 +285,17 @@ __global__ void Marlin(
|
||||
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
// Turing TensorCore only supports fp16 and int8
|
||||
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
|
||||
return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
using Adtype = MarlinScalarType<a_type_id>;
|
||||
using Cdtype = MarlinScalarType<c_type_id>;
|
||||
const int4* A = A0;
|
||||
@ -873,10 +754,6 @@ __global__ void Marlin(
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
// shared memory reused by reduction should be smaller than
|
||||
// shared memory used by weight.
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
@ -1395,11 +1272,13 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
} else {
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
|
||||
frag_c[i][j][0]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
|
||||
frag_c[i][j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1433,10 +1312,12 @@ __global__ void Marlin(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
}
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
@ -1956,6 +1837,21 @@ __global__ void Marlin(
|
||||
// While this pattern may not be the most readable, other ways of writing
|
||||
// the loop seemed to noticeably worse performance after compilation.
|
||||
if (slice_iters == 0) {
|
||||
// convert fp16 accum to fp32 for reduction
|
||||
if constexpr (use_fp16_accum) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
|
||||
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
|
||||
scalar_t* frag_c_part_half =
|
||||
reinterpret_cast<scalar_t*>(frag_c_part_float);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_a_8bit) {
|
||||
float frag_a_s[2 * thread_m_blocks];
|
||||
|
||||
|
||||
@ -1,373 +0,0 @@
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include <cassert>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename ElementAB, typename ElementC, typename ElementAccumulator,
|
||||
typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
__global__ void get_ggemm_starts(
|
||||
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
|
||||
ElementC** out_offsets, ElementAccumulator** a_scale_offsets,
|
||||
ElementAccumulator** b_scale_offsets, ElementAB* a_base_as_int,
|
||||
ElementAB* b_base_as_int, ElementC* out_base_as_int,
|
||||
ElementAccumulator* a_scale_base_as_int,
|
||||
ElementAccumulator* b_scale_base_as_int, LayoutSFA* layout_sfa_base_as_int,
|
||||
LayoutSFB* layout_sfb_base_as_int, int* problem_sizes) {
|
||||
int expert_id = threadIdx.x;
|
||||
|
||||
if (expert_id >= gridDim.x * blockDim.x) {
|
||||
return;
|
||||
}
|
||||
|
||||
int m = problem_sizes[expert_id * 3];
|
||||
int n = problem_sizes[expert_id * 3 + 1];
|
||||
int k = problem_sizes[expert_id * 3 + 2];
|
||||
|
||||
int32_t expert_offset = expert_offsets[expert_id];
|
||||
int a_stride = expert_offset * k;
|
||||
int b_stride = expert_id * k * n;
|
||||
int a_scale_stride = expert_offset * k / 128;
|
||||
int b_scale_stride = expert_id * k * n / 128 / 128;
|
||||
|
||||
a_offsets[expert_id] = a_base_as_int + a_stride;
|
||||
b_offsets[expert_id] = b_base_as_int + b_stride;
|
||||
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||
a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride;
|
||||
b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride;
|
||||
|
||||
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
|
||||
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
|
||||
|
||||
*layout_sfa_ptr =
|
||||
ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
|
||||
*layout_sfb_ptr =
|
||||
ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, \
|
||||
ScaleConfig) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
get_ggemm_starts<cutlass::float_e4m3_t, C_TYPE, float, LayoutSFA, \
|
||||
LayoutSFB, ScaleConfig><<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
|
||||
static_cast<float**>(a_scales_ptrs.data_ptr()), \
|
||||
static_cast<float**>(b_scales_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
|
||||
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
|
||||
static_cast<float*>(a_scales.data_ptr()), \
|
||||
static_cast<float*>(b_scales.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
|
||||
static_cast<int*>(problem_sizes.data_ptr())); \
|
||||
}
|
||||
|
||||
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
void run_get_ggemm_starts(
|
||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor out_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& layout_sfa,
|
||||
torch::Tensor const& layout_sfb, torch::Tensor const& problem_sizes) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0);
|
||||
TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0);
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
if (false) {
|
||||
}
|
||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA,
|
||||
LayoutSFB, ScaleConfig)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA,
|
||||
LayoutSFB, ScaleConfig)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported output tensor type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType, typename ScheduleConfig, typename LayoutD>
|
||||
void run_blockwise_scaled_group_mm(
|
||||
torch::Tensor& out_ptrs, const torch::Tensor& a_ptrs,
|
||||
const torch::Tensor& b_ptrs, const torch::Tensor& a_scales_ptrs,
|
||||
const torch::Tensor& b_scales_ptrs, const torch::Tensor& stride_a,
|
||||
const torch::Tensor& stride_b, const torch::Tensor& stride_c,
|
||||
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
||||
|
||||
// Types
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = OutType;
|
||||
using ElementD = ElementC;
|
||||
using ElementAccumulator = float;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = LayoutD;
|
||||
|
||||
// Alignments
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, typename ScheduleConfig::MmaTileShape,
|
||||
typename ScheduleConfig::ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementAccumulator, void, LayoutC*, AlignmentC, ElementD, LayoutC*,
|
||||
AlignmentC, typename ScheduleConfig::EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementA,
|
||||
cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
|
||||
AlignmentA, ElementB,
|
||||
cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
|
||||
AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape,
|
||||
typename ScheduleConfig::ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename ScheduleConfig::KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue, void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
// Mainloop Arguments
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
static_cast<const ElementA**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(stride_a.data_ptr()),
|
||||
static_cast<const ElementB**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(stride_b.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(
|
||||
layout_sfa.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
|
||||
layout_sfb.data_ptr())};
|
||||
|
||||
int device_id = a_ptrs.device().index();
|
||||
static const cutlass::KernelHardwareInfo hw_info{
|
||||
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
device_id)};
|
||||
|
||||
// Epilogue Arguments
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, // epilogue.thread
|
||||
nullptr,
|
||||
static_cast<StrideC*>(stride_c.data_ptr()),
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(stride_c.data_ptr())};
|
||||
|
||||
UnderlyingProblemShape* problem_sizes_as_shapes =
|
||||
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||
|
||||
// Gemm Arguments
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{num_experts, problem_sizes_as_shapes, nullptr},
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
hw_info};
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()};
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
|
||||
|
||||
auto can_implement_status = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM");
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||
|
||||
status = gemm_op.run(stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void blockwise_scaled_group_mm_dispatch_shape(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
struct MmaConfig {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
|
||||
1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
};
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
|
||||
auto a_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto b_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto out_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto a_scales_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto b_scales_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
|
||||
auto layout_sfa = torch::empty(
|
||||
{num_experts, 5},
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
|
||||
auto layout_sfb = torch::empty(
|
||||
{num_experts, 5},
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
|
||||
|
||||
auto stride_a = torch::full(
|
||||
{num_experts}, a.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto stride_b = torch::full(
|
||||
{num_experts}, a.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto stride_c = torch::full(
|
||||
{num_experts}, output.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
|
||||
torch::TensorOptions options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
|
||||
run_get_ggemm_starts<typename MmaConfig::LayoutSFA,
|
||||
typename MmaConfig::LayoutSFB,
|
||||
typename MmaConfig::ScaleConfig>(
|
||||
expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a,
|
||||
b, output, scales_a, scales_b, layout_sfa, layout_sfb, problem_sizes);
|
||||
|
||||
run_blockwise_scaled_group_mm<OutType, MmaConfig,
|
||||
typename MmaConfig::LayoutC>(
|
||||
out_ptrs, a_ptrs, b_ptrs, a_scales_ptrs, b_scales_ptrs, stride_a,
|
||||
stride_b, stride_c, layout_sfa, layout_sfb, problem_sizes,
|
||||
expert_offsets);
|
||||
}
|
||||
|
||||
void cutlass_blockwise_scaled_grouped_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn,
|
||||
"a must be kFloat8_e4m3fn");
|
||||
TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn,
|
||||
"b must be kFloat8_e4m3fn");
|
||||
TORCH_CHECK(output.scalar_type() == torch::kBFloat16 ||
|
||||
output.scalar_type() == torch::kHalf,
|
||||
"output must be bfloat16 or half");
|
||||
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32,
|
||||
"scales_a must be float32");
|
||||
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32,
|
||||
"scales_b must be float32");
|
||||
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32,
|
||||
"expert_offsets must be int32");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
|
||||
TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
|
||||
TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
|
||||
TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
|
||||
TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
|
||||
|
||||
#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100
|
||||
if (output.scalar_type() == torch::kBFloat16) {
|
||||
blockwise_scaled_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
||||
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
|
||||
} else if (output.scalar_type() == torch::kFloat16) {
|
||||
blockwise_scaled_group_mm_dispatch_shape<cutlass::half_t>(
|
||||
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output tensor type");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_blockwise_scaled_grouped_mm",
|
||||
&cutlass_blockwise_scaled_grouped_mm);
|
||||
}
|
||||
@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
|
||||
int rowEnd = rowEnds[rowIdx];
|
||||
|
||||
// Local pointers to this block
|
||||
outIndices += rowIdx * topK;
|
||||
logits += rowIdx * stride0;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||
|
||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
|
||||
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
|
||||
@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
||||
|
||||
// Local pointers to this block
|
||||
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
|
||||
outIndices += rowIdx * topK;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
} else if constexpr (multipleBlocksPerRow) {
|
||||
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
|
||||
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
|
||||
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
|
||||
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||
outIndices +=
|
||||
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||
outLogits +=
|
||||
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||
} else if constexpr (mergeBlocks) {
|
||||
rowEnd = numBlocksToMerge * topK;
|
||||
indices += rowIdx * numBlocksToMerge * topK;
|
||||
outIndices += rowIdx * topK;
|
||||
indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
}
|
||||
logits += rowIdx * stride0;
|
||||
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||
|
||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
|
||||
multipleBlocksPerRow, mergeBlocks>(
|
||||
|
||||
@ -416,13 +416,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor alpha) -> ()");
|
||||
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
||||
|
||||
// cutlass blockwise scaledgroup GEMM
|
||||
ops.def(
|
||||
"cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, "
|
||||
"Tensor scales_a, Tensor scales_b, "
|
||||
"Tensor problem_sizes, Tensor expert_offsets) -> ()");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// cutlass nvfp4 block scaled group GEMM
|
||||
ops.def(
|
||||
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
|
||||
@ -621,7 +621,7 @@ ENV UV_HTTP_TIMEOUT=500
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \
|
||||
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
|
||||
uv pip install --system -r /tmp/kv_connectors.txt; \
|
||||
uv pip install --system -r /tmp/kv_connectors.txt || true; \
|
||||
fi
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
# VLLM_CPU_DISABLE_AVX512=false (default)|true
|
||||
# VLLM_CPU_AVX512BF16=false (default)|true
|
||||
# VLLM_CPU_AVX512VNNI=false (default)|true
|
||||
# VLLM_CPU_AMXBF16=false (default)|true
|
||||
# VLLM_CPU_AMXBF16=false |true (default)
|
||||
#
|
||||
|
||||
######################### COMMON BASE IMAGE #########################
|
||||
@ -95,7 +95,7 @@ ENV VLLM_CPU_AVX512BF16=${VLLM_CPU_AVX512BF16}
|
||||
ARG VLLM_CPU_AVX512VNNI=0
|
||||
ENV VLLM_CPU_AVX512VNNI=${VLLM_CPU_AVX512VNNI}
|
||||
# Support for building with AMXBF16 ISA: docker build --build-arg VLLM_CPU_AMXBF16="true" ...
|
||||
ARG VLLM_CPU_AMXBF16=0
|
||||
ARG VLLM_CPU_AMXBF16=1
|
||||
ENV VLLM_CPU_AMXBF16=${VLLM_CPU_AMXBF16}
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
@ -147,7 +147,9 @@ WORKDIR /workspace/vllm
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt-get install -y --no-install-recommends vim numactl xz-utils
|
||||
apt-get install -y --no-install-recommends vim numactl xz-utils make clangd-14
|
||||
|
||||
RUN ln -s /usr/bin/clangd-14 /usr/bin/clangd
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
|
||||
@ -130,6 +130,7 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
&& uv pip install --system *.whl
|
||||
|
||||
ARG COMMON_WORKDIR
|
||||
ARG BASE_IMAGE
|
||||
|
||||
# Copy over the benchmark scripts as well
|
||||
COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks
|
||||
@ -144,4 +145,9 @@ ENV SAFETENSORS_FAST_GPU=1
|
||||
# Performance environment variable.
|
||||
ENV HIP_FORCE_DEV_KERNARG=1
|
||||
|
||||
# Workaround for ROCm profiler limits
|
||||
RUN echo "ROCTRACER_MAX_EVENTS=10000000" > ${COMMON_WORKDIR}/libkineto.conf
|
||||
ENV KINETO_CONFIG="${COMMON_WORKDIR}/libkineto.conf"
|
||||
RUN echo "VLLM_BASE_IMAGE=${BASE_IMAGE}" >> ${COMMON_WORKDIR}/versions.txt
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.1-complete
|
||||
ARG TRITON_BRANCH="57c693b6"
|
||||
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete
|
||||
ARG TRITON_BRANCH="a272dfa8"
|
||||
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
|
||||
ARG PYTORCH_BRANCH="1c57644d"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.23.0"
|
||||
ARG PYTORCH_BRANCH="89075173"
|
||||
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.24.1"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
|
||||
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
|
||||
ARG FA_BRANCH="0e60e394"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="59bd8ff2"
|
||||
ARG AITER_BRANCH="6af8b687"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
@ -8,12 +8,19 @@ The results are automatically published to the public [vLLM Performance Dashboar
|
||||
## Manually Trigger the benchmark
|
||||
|
||||
Use [vllm-ci-test-repo images](https://gallery.ecr.aws/q9t5s3a7/vllm-ci-test-repo) with vLLM benchmark suite.
|
||||
For CPU environment, please use the image with "-cpu" postfix.
|
||||
For x86 CPU environment, please use the image with "-cpu" postfix. For AArch64 CPU environment, please use the image with "-arm64-cpu" postfix.
|
||||
|
||||
Here is an example for docker run command for CPU.
|
||||
Here is an example for docker run command for CPU. For GPUs skip setting the `ON_CPU` env var.
|
||||
|
||||
```bash
|
||||
docker run -it --entrypoint /bin/bash -v /data/huggingface:/root/.cache/huggingface -e HF_TOKEN='' --shm-size=16g --name vllm-cpu-ci public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:1da94e673c257373280026f75ceb4effac80e892-cpu
|
||||
export VLLM_COMMIT=1da94e673c257373280026f75ceb4effac80e892 # use full commit hash from the main branch
|
||||
export HF_TOKEN=<valid Hugging Face token>
|
||||
if [[ "$(uname -m)" == aarch64 || "$(uname -m)" == arm64 ]]; then
|
||||
IMG_SUFFIX="arm64-cpu"
|
||||
else
|
||||
IMG_SUFFIX="cpu"
|
||||
fi
|
||||
docker run -it --entrypoint /bin/bash -v /data/huggingface:/root/.cache/huggingface -e HF_TOKEN=$HF_TOKEN -e ON_ARM64_CPU=1 --shm-size=16g --name vllm-cpu-ci public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:${VLLM_COMMIT}-${IMG_SUFFIX}
|
||||
```
|
||||
|
||||
Then, run below command inside the docker instance.
|
||||
@ -26,7 +33,7 @@ When run, benchmark script generates results under **benchmark/results** folder,
|
||||
|
||||
### Runtime environment variables
|
||||
|
||||
- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0.
|
||||
- `ON_CPU`: set the value to '1' on Intel® Xeon® and Arm® Neoverse™ Processors. Default value is 0.
|
||||
- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file).
|
||||
- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file).
|
||||
- `THROUGHPUT_JSON`: JSON file to use for the throughout tests. Default value is empty string (use default file).
|
||||
|
||||
@ -77,25 +77,20 @@ This complicates the process as we cannot use the out-of-the-box
|
||||
- `.buildkite/release-pipeline.yaml`
|
||||
- `.buildkite/scripts/upload-wheels.sh`
|
||||
|
||||
## Address long vLLM build time
|
||||
## Manually running vLLM builds on BuildKiteCI
|
||||
|
||||
When building vLLM with a new PyTorch/CUDA version, no cache will exist
|
||||
in the vLLM sccache S3 bucket, causing the build job on CI to potentially take more than 5 hours
|
||||
and timeout. Additionally, since vLLM's fastcheck pipeline runs in read-only mode,
|
||||
it doesn't populate the cache, so re-running it to warm up the cache
|
||||
is ineffective.
|
||||
When building vLLM with a new PyTorch/CUDA version, the vLLM sccache S3 bucket
|
||||
will not have any cached artifacts, which can cause CI build jobs to exceed 5 hours.
|
||||
Furthermore, vLLM's fastcheck pipeline operates in read-only mode and does not
|
||||
populate the cache, making it ineffective for cache warm-up purposes.
|
||||
|
||||
While ongoing efforts like <https://github.com/vllm-project/vllm/issues/17419>
|
||||
address the long build time at its source, the current workaround is to set `VLLM_CI_BRANCH`
|
||||
to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/long_build`)
|
||||
when manually triggering a build on Buildkite. This branch accomplishes two things:
|
||||
To address this, manually trigger a build on Buildkite to accomplish two objectives:
|
||||
|
||||
1. Increase the timeout limit to 10 hours so that the build doesn't time out.
|
||||
2. Allow the compiled artifacts to be written to the vLLM sccache S3 bucket
|
||||
to warm it up so that future builds are faster.
|
||||
1. Run the complete test suite against the PyTorch RC build by setting the environment variables: `RUN_ALL=1` and `NIGHTLY=1`
|
||||
2. Populate the vLLM sccache S3 bucket with compiled artifacts, enabling faster subsequent builds
|
||||
|
||||
<p align="center" width="100%">
|
||||
<img width="60%" alt="Buildkite new build popup" src="https://github.com/user-attachments/assets/a8ff0fcd-76e0-4e91-b72f-014e3fdb6b94">
|
||||
<img width="60%" alt="Buildkite new build popup" src="https://github.com/user-attachments/assets/3b07f71b-bb18-4ca3-aeaf-da0fe79d315f" />
|
||||
</p>
|
||||
|
||||
## Update all the different vLLM platforms
|
||||
|
||||
@ -16,7 +16,7 @@ Async backends support the use of DBO (Dual Batch Overlap) and shared expert ove
|
||||
|
||||
Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalize` subclass. For non-modular kernels, it is up to the experts function to deal with this flag.
|
||||
|
||||
Unless otherwise specified, backends are controlled via `VLLM_ALL2ALL_BACKEND`. All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP.
|
||||
Unless otherwise specified, backends are controlled via the `--all2all-backend` command-line argument (or the `all2all_backend` parameter in `ParallelConfig`). All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP.
|
||||
|
||||
<style>
|
||||
td {
|
||||
|
||||
@ -109,7 +109,7 @@ Every plugin has three parts:
|
||||
- `init_device`: This function is called to set up the device for the worker.
|
||||
- `initialize_cache`: This function is called to set cache config for the worker.
|
||||
- `load_model`: This function is called to load the model weights to device.
|
||||
- `get_kv_cache_spaces`: This function is called to generate the kv cache spaces for the model.
|
||||
- `get_kv_cache_spec`: This function is called to generate the kv cache spec for the model.
|
||||
- `determine_available_memory`: This function is called to profiles the peak memory usage of the model to determine how much memory can be used for KV cache without OOMs.
|
||||
- `initialize_from_config`: This function is called to allocate device KV cache with the specified kv_cache_config
|
||||
- `execute_model`: This function is called every step to inference the model.
|
||||
|
||||
@ -352,10 +352,17 @@ Supported models:
|
||||
* `zai-org/GLM-4.5`
|
||||
* `zai-org/GLM-4.5-Air`
|
||||
* `zai-org/GLM-4.6`
|
||||
* `zai-org/GLM-4.6-Air`
|
||||
|
||||
Flags: `--tool-call-parser glm45`
|
||||
|
||||
### GLM-4.7 Models (`glm47`)
|
||||
|
||||
Supported models:
|
||||
|
||||
* `zai-org/GLM-4.7`
|
||||
|
||||
Flags: `--tool-call-parser glm47`
|
||||
|
||||
### Qwen3-Coder Models (`qwen3_xml`)
|
||||
|
||||
Supported models:
|
||||
|
||||
@ -27,3 +27,4 @@ The backends below live **outside** the main `vllm` repository and follow the
|
||||
| IBM Spyre AIU | `vllm-spyre` | <https://github.com/vllm-project/vllm-spyre> |
|
||||
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
|
||||
| Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> |
|
||||
| Sophgo TPU | N/A, install from source | <https://github.com/sophgo/vllm-tpu> |
|
||||
|
||||
@ -19,12 +19,26 @@ Pre-built vLLM wheels for Arm are available since version 0.11.2. These wheels c
|
||||
|
||||
```bash
|
||||
export VLLM_VERSION=$(curl -s https://api.github.com/repos/vllm-project/vllm/releases/latest | jq -r .tag_name | sed 's/^v//')
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_VERSION}/cpu
|
||||
uv pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cpu-cp38-abi3-manylinux_2_35_aarch64.whl
|
||||
```
|
||||
|
||||
??? console "pip"
|
||||
```bash
|
||||
pip install vllm==${VLLM_VERSION}+cpu --extra-index-url https://wheels.vllm.ai/${VLLM_VERSION}/cpu
|
||||
pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cpu-cp38-abi3-manylinux_2_35_aarch64.whl
|
||||
```
|
||||
|
||||
!!! warning "set `LD_PRELOAD`"
|
||||
Before use vLLM CPU installed via wheels, make sure TCMalloc is installed and added to `LD_PRELOAD`:
|
||||
```bash
|
||||
# install TCMalloc
|
||||
sudo apt-get install -y --no-install-recommends libtcmalloc-minimal4
|
||||
|
||||
# manually find the path
|
||||
sudo find / -iname *libtcmalloc_minimal.so.4
|
||||
TC_PATH=...
|
||||
|
||||
# add them to LD_PRELOAD
|
||||
export LD_PRELOAD="$TC_PATH:$LD_PRELOAD"
|
||||
```
|
||||
|
||||
The `uv` approach works for vLLM `v0.6.6` and later. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version.
|
||||
@ -37,7 +51,7 @@ LLM inference is a fast-evolving field, and the latest code may contain bug fixe
|
||||
|
||||
To install from nightly index, run:
|
||||
```bash
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly/cpu
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly/cpu --index-strategy first-index
|
||||
```
|
||||
|
||||
??? console "pip (there's a caveat)"
|
||||
@ -56,7 +70,7 @@ If you want to access the wheels for previous commits (e.g. to bisect the behavi
|
||||
|
||||
```bash
|
||||
export VLLM_COMMIT=730bd35378bf2a5b56b6d3a45be28b3092d26519 # use full commit hash from the main branch
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT}/cpu
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT}/cpu --index-strategy first-index
|
||||
```
|
||||
|
||||
# --8<-- [end:pre-built-wheels]
|
||||
@ -105,6 +119,20 @@ VLLM_TARGET_DEVICE=cpu uv pip install -e . --no-build-isolation
|
||||
|
||||
Testing has been conducted on AWS Graviton3 instances for compatibility.
|
||||
|
||||
!!! warning "set `LD_PRELOAD`"
|
||||
Before use vLLM CPU installed via wheels, make sure TCMalloc is installed and added to `LD_PRELOAD`:
|
||||
```bash
|
||||
# install TCMalloc
|
||||
sudo apt-get install -y --no-install-recommends libtcmalloc-minimal4
|
||||
|
||||
# manually find the path
|
||||
sudo find / -iname *libtcmalloc_minimal.so.4
|
||||
TC_PATH=...
|
||||
|
||||
# add them to LD_PRELOAD
|
||||
export LD_PRELOAD="$TC_PATH:$LD_PRELOAD"
|
||||
```
|
||||
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:pre-built-images]
|
||||
|
||||
|
||||
@ -18,6 +18,12 @@ vLLM is a Python library that supports the following CPU variants. Select your C
|
||||
|
||||
--8<-- "docs/getting_started/installation/cpu.s390x.inc.md:installation"
|
||||
|
||||
## Technical Discussions
|
||||
|
||||
The main discussions happen in the `#sig-cpu` channel of [vLLM Slack](https://slack.vllm.ai/).
|
||||
|
||||
When open a Github issue about the CPU backend, please add `[CPU Backend]` in the title and it will be labeled with `cpu` for better awareness.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python: 3.10 -- 3.13
|
||||
@ -258,11 +264,6 @@ vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel
|
||||
- GPTQ (x86 only)
|
||||
- compressed-tensor INT8 W8A8 (x86, s390x)
|
||||
|
||||
### (x86 only) What is the purpose of `VLLM_CPU_SGL_KERNEL`?
|
||||
|
||||
- Both of them require `amx` CPU flag.
|
||||
- `VLLM_CPU_SGL_KERNEL` can provide better performance for MoE models and small-batch scenarios.
|
||||
|
||||
### Why do I see `get_mempolicy: Operation not permitted` when running in Docker?
|
||||
|
||||
In some container environments (like Docker), NUMA-related syscalls used by vLLM (e.g., `get_mempolicy`, `migrate_pages`) are blocked/denied in the runtime's default seccomp/capabilities settings. This may lead to warnings like `get_mempolicy: Operation not permitted`. Functionality is not affected, but NUMA memory binding/migration optimizations may not take effect and performance can be suboptimal.
|
||||
|
||||
@ -17,7 +17,51 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data
|
||||
# --8<-- [end:set-up-using-python]
|
||||
# --8<-- [start:pre-built-wheels]
|
||||
|
||||
Currently, there are no pre-built x86 CPU wheels.
|
||||
Pre-built vLLM wheels for x86 with AVX512 are available since version 0.13.0. To install release wheels:
|
||||
|
||||
```bash
|
||||
export VLLM_VERSION=$(curl -s https://api.github.com/repos/vllm-project/vllm/releases/latest | jq -r .tag_name | sed 's/^v//')
|
||||
|
||||
# use uv
|
||||
uv pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cpu-cp38-abi3-manylinux_2_35_x86_64.whl --torch-backend cpu
|
||||
```
|
||||
??? console "pip"
|
||||
```bash
|
||||
# use pip
|
||||
pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cpu-cp38-abi3-manylinux_2_35_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
!!! warning "set `LD_PRELOAD`"
|
||||
Before use vLLM CPU installed via wheels, make sure TCMalloc and Intel OpenMP are installed and added to `LD_PRELOAD`:
|
||||
```bash
|
||||
# install TCMalloc, Intel OpenMP is installed with vLLM CPU
|
||||
sudo apt-get install -y --no-install-recommends libtcmalloc-minimal4
|
||||
|
||||
# manually find the path
|
||||
sudo find / -iname *libtcmalloc_minimal.so.4
|
||||
sudo find / -iname *libiomp5.so
|
||||
TC_PATH=...
|
||||
IOMP_PATH=...
|
||||
|
||||
# add them to LD_PRELOAD
|
||||
export LD_PRELOAD="$TC_PATH:$IOMP_PATH:$LD_PRELOAD"
|
||||
```
|
||||
|
||||
**Install the latest code**
|
||||
|
||||
To install the wheel built from the latest main branch:
|
||||
|
||||
```bash
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly/cpu --index-strategy first-index --torch-backend cpu
|
||||
```
|
||||
|
||||
**Install specific revisions**
|
||||
|
||||
If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), you can specify the commit hash in the URL:
|
||||
|
||||
```bash
|
||||
export VLLM_COMMIT=730bd35378bf2a5b56b6d3a45be28b3092d26519 # use full commit hash from the main branch
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT}/cpu --index-strategy first-index --torch-backend cpu
|
||||
```
|
||||
|
||||
# --8<-- [end:pre-built-wheels]
|
||||
# --8<-- [start:build-wheel-from-source]
|
||||
@ -26,10 +70,12 @@ Install recommended compiler. We recommend to use `gcc/g++ >= 12.3.0` as the def
|
||||
|
||||
```bash
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y gcc-12 g++-12 libnuma-dev python3-dev
|
||||
sudo apt-get install -y gcc-12 g++-12 libnuma-dev
|
||||
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
||||
```
|
||||
|
||||
--8<-- "docs/getting_started/installation/python_env_setup.inc.md"
|
||||
|
||||
Clone the vLLM project:
|
||||
|
||||
```bash
|
||||
@ -82,6 +128,22 @@ uv pip install dist/*.whl
|
||||
pip install dist/*.whl
|
||||
```
|
||||
|
||||
!!! warning "set `LD_PRELOAD`"
|
||||
Before use vLLM CPU installed via wheels, make sure TCMalloc and Intel OpenMP are installed and added to `LD_PRELOAD`:
|
||||
```bash
|
||||
# install TCMalloc, Intel OpenMP is installed with vLLM CPU
|
||||
sudo apt-get install -y --no-install-recommends libtcmalloc-minimal4
|
||||
|
||||
# manually find the path
|
||||
sudo find / -iname *libtcmalloc_minimal.so.4
|
||||
sudo find / -iname *libiomp5.so
|
||||
TC_PATH=...
|
||||
IOMP_PATH=...
|
||||
|
||||
# add them to LD_PRELOAD
|
||||
export LD_PRELOAD="$TC_PATH:$IOMP_PATH:$LD_PRELOAD"
|
||||
```
|
||||
|
||||
!!! example "Troubleshooting"
|
||||
- **NumPy ≥2.0 error**: Downgrade using `pip install "numpy<2.0"`.
|
||||
- **CMake picks up CUDA**: Add `CMAKE_DISABLE_FIND_PACKAGE_CUDA=ON` to prevent CUDA detection during CPU builds, even if CUDA is installed.
|
||||
@ -95,7 +157,6 @@ uv pip install dist/*.whl
|
||||
"torch==X.Y.Z+cpu" # <-------
|
||||
]
|
||||
```
|
||||
- If you are building vLLM from source and not using the pre-built images, remember to set `LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD"` on x86 machines before running vLLM.
|
||||
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:pre-built-images]
|
||||
@ -112,6 +173,7 @@ uv pip install dist/*.whl
|
||||
docker build -f docker/Dockerfile.cpu \
|
||||
--build-arg VLLM_CPU_AVX512BF16=false (default)|true \
|
||||
--build-arg VLLM_CPU_AVX512VNNI=false (default)|true \
|
||||
--build-arg VLLM_CPU_AMXBF16=false|true (default) \
|
||||
--build-arg VLLM_CPU_DISABLE_AVX512=false (default)|true \
|
||||
--tag vllm-cpu-env \
|
||||
--target vllm-openai .
|
||||
@ -123,9 +185,8 @@ docker run --rm \
|
||||
--shm-size=4g \
|
||||
-p 8000:8000 \
|
||||
-e VLLM_CPU_KVCACHE_SPACE=<KV cache space> \
|
||||
-e VLLM_CPU_OMP_THREADS_BIND=<CPU cores for inference> \
|
||||
vllm-cpu-env \
|
||||
--model=meta-llama/Llama-3.2-1B-Instruct \
|
||||
meta-llama/Llama-3.2-1B-Instruct \
|
||||
--dtype=bfloat16 \
|
||||
other vLLM OpenAI server arguments
|
||||
```
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
On NVIDIA CUDA only, it's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment using the following commands:
|
||||
It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment using the following commands:
|
||||
|
||||
```bash
|
||||
uv venv --python 3.12 --seed
|
||||
|
||||
@ -181,3 +181,4 @@ If you have PRs touching the area, please feel free to ping the area owner for r
|
||||
|
||||
- Ascend NPU: [@wangxiyuan](https://github.com/wangxiyuan) and [see more details](https://vllm-ascend.readthedocs.io/en/latest/community/contributors.html#maintainers)
|
||||
- Intel Gaudi HPU [@xuechendi](https://github.com/xuechendi) and [@kzawora-intel](https://github.com/kzawora-intel)
|
||||
- Semantic Router: [@xunzhuo](https://github.com/xunzhuo), [@rootfs](https://github.com/rootfs) and [see more details](https://vllm-semantic-router.com/community/team)
|
||||
|
||||
@ -387,7 +387,7 @@ th {
|
||||
| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
|
||||
| `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ |
|
||||
| `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ |
|
||||
| `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ |
|
||||
| `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6, GLM-4.7 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ |
|
||||
| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ |
|
||||
| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ |
|
||||
| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ |
|
||||
@ -406,6 +406,7 @@ th {
|
||||
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ |
|
||||
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ |
|
||||
| `Jais2ForCausalLM` | Jais2 | `inceptionai/Jais-2-8B-Chat`, `inceptionai/Jais-2-70B-Chat`, etc. | | ✅︎ |
|
||||
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ |
|
||||
| `KimiLinearForCausalLM` | Kimi-Linear-48B-A3B-Base, Kimi-Linear-48B-A3B-Instruct | `moonshotai/Kimi-Linear-48B-A3B-Base`, `moonshotai/Kimi-Linear-48B-A3B-Instruct` | | ✅︎ |
|
||||
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ |
|
||||
@ -414,6 +415,7 @@ th {
|
||||
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ |
|
||||
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ |
|
||||
| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ |
|
||||
| `MiMoV2FlashForCausalLM` | MiMoV2Flash | `XiaomiMiMo/MiMo-V2-Flash`, etc. | ︎| ✅︎ |
|
||||
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ |
|
||||
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ |
|
||||
| `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ |
|
||||
|
||||
@ -47,6 +47,8 @@ We currently support the following OpenAI APIs:
|
||||
- [Completions API](#completions-api) (`/v1/completions`)
|
||||
- Only applicable to [text generation models](../models/generative_models.md).
|
||||
- *Note: `suffix` parameter is not supported.*
|
||||
- [Responses API](#responses-api) (`/v1/responses`)
|
||||
- Only applicable to [text generation models](../models/generative_models.md).
|
||||
- [Chat Completions API](#chat-api) (`/v1/chat/completions`)
|
||||
- Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template).
|
||||
- *Note: `user` parameter is ignored.*
|
||||
@ -229,6 +231,31 @@ The following extra parameters are supported:
|
||||
--8<-- "vllm/entrypoints/openai/protocol.py:chat-completion-extra-params"
|
||||
```
|
||||
|
||||
### Responses API
|
||||
|
||||
Our Responses API is compatible with [OpenAI's Responses API](https://platform.openai.com/docs/api-reference/responses);
|
||||
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
|
||||
|
||||
Code example: [examples/online_serving/openai_responses_client_with_tools.py](../../examples/online_serving/openai_responses_client_with_tools.py)
|
||||
|
||||
#### Extra parameters
|
||||
|
||||
The following extra parameters in the request object are supported:
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
--8<-- "vllm/entrypoints/openai/protocol.py:responses-extra-params"
|
||||
```
|
||||
|
||||
The following extra parameters in the response object are supported:
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
--8<-- "vllm/entrypoints/openai/protocol.py:responses-response-extra-params"
|
||||
```
|
||||
|
||||
### Embeddings API
|
||||
|
||||
Our Embeddings API is compatible with [OpenAI's Embeddings API](https://platform.openai.com/docs/api-reference/embeddings);
|
||||
|
||||
@ -55,7 +55,6 @@ done
|
||||
echo "Starting vLLM server for $MODEL_NAME with data parallel size: $DATA_PARALLEL_SIZE and redundant experts: $REDUNDANT_EXPERTS"
|
||||
|
||||
export RAY_DEDUP_LOGS=0
|
||||
export VLLM_ALL2ALL_BACKEND="pplx"
|
||||
export VLLM_USE_DEEP_GEMM=1
|
||||
|
||||
vllm serve $MODEL_NAME \
|
||||
@ -65,6 +64,7 @@ vllm serve $MODEL_NAME \
|
||||
--enforce-eager \
|
||||
--enable-expert-parallel \
|
||||
--enable-eplb \
|
||||
--all2all-backend pplx \
|
||||
--num-redundant-experts $REDUNDANT_EXPERTS \
|
||||
--trust-remote-code \
|
||||
--host $HOST \
|
||||
|
||||
@ -6,7 +6,7 @@ requires = [
|
||||
"packaging>=24.2",
|
||||
"setuptools>=77.0.3,<81.0.0",
|
||||
"setuptools-scm>=8.0",
|
||||
"torch == 2.9.0",
|
||||
"torch == 2.9.1",
|
||||
"wheel",
|
||||
"jinja2",
|
||||
]
|
||||
|
||||
@ -4,7 +4,7 @@ ninja
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<81.0.0
|
||||
setuptools-scm>=8
|
||||
torch==2.9.0
|
||||
torch==2.9.1
|
||||
wheel
|
||||
jinja2>=3.1.6
|
||||
regex
|
||||
|
||||
@ -37,7 +37,7 @@ pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||
einops # Required for Qwen2-VL.
|
||||
compressed-tensors == 0.12.2 # required for compressed-tensors
|
||||
compressed-tensors == 0.13.0 # required for compressed-tensors
|
||||
depyf==0.20.0 # required for profiling and debugging with compilation config
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
watchfiles # required for http server to monitor the updates of TLS files
|
||||
@ -50,5 +50,5 @@ ijson # Required for mistral streaming tool parser
|
||||
setproctitle # Used to set process names for better debugging and monitoring
|
||||
openai-harmony >= 0.0.3 # Required for gpt-oss
|
||||
anthropic == 0.71.0
|
||||
model-hosting-container-standards >= 0.1.9, < 1.0.0
|
||||
mcp
|
||||
model-hosting-container-standards >= 0.1.10, < 1.0.0
|
||||
mcp
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
cmake>=3.26.1
|
||||
ninja
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<81.0.0
|
||||
setuptools==77.0.3 # this version can reuse CMake build dir
|
||||
setuptools-scm>=8
|
||||
torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
|
||||
torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64"
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
# Common dependencies
|
||||
-r common.txt
|
||||
|
||||
setuptools==77.0.3 # this version can reuse CMake build dir
|
||||
|
||||
numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding
|
||||
|
||||
# Dependencies for CPUs
|
||||
|
||||
@ -5,9 +5,9 @@ numba == 0.61.2 # Required for N-gram speculative decoding
|
||||
|
||||
# Dependencies for NVIDIA GPUs
|
||||
ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1.
|
||||
torch==2.9.0
|
||||
torchaudio==2.9.0
|
||||
torch==2.9.1
|
||||
torchaudio==2.9.1
|
||||
# These must be updated alongside torch
|
||||
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
torchvision==0.24.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
flashinfer-python==0.5.3
|
||||
|
||||
@ -2,11 +2,11 @@
|
||||
-r common.txt
|
||||
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.4
|
||||
torch==2.9.0
|
||||
torchvision==0.24.0
|
||||
torchaudio==2.9.0
|
||||
torch==2.9.1
|
||||
torchvision==0.24.1
|
||||
torchaudio==2.9.1
|
||||
|
||||
triton==3.5.0
|
||||
triton==3.5.1
|
||||
cmake>=3.26.1,<4
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<80.0.0
|
||||
|
||||
@ -24,9 +24,9 @@ soundfile # required for audio tests
|
||||
jiwer # required for audio tests
|
||||
tblib # for pickling test exceptions
|
||||
timm >=1.0.17 # required for internvl and gemma3n-mm test
|
||||
torch==2.9.0
|
||||
torchaudio==2.9.0
|
||||
torchvision==0.24.0
|
||||
torch==2.9.1
|
||||
torchaudio==2.9.1
|
||||
torchvision==0.24.1
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[image,audio] >= 1.8.5 # required for voxtral test
|
||||
|
||||
@ -1123,7 +1123,7 @@ tomli==2.2.1
|
||||
# via schemathesis
|
||||
tomli-w==1.2.0
|
||||
# via schemathesis
|
||||
torch==2.9.0+cu129
|
||||
torch==2.9.1+cu129
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# accelerate
|
||||
@ -1152,7 +1152,7 @@ torch==2.9.0+cu129
|
||||
# torchvision
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
torchaudio==2.9.0+cu129
|
||||
torchaudio==2.9.1+cu129
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# encodec
|
||||
@ -1165,7 +1165,7 @@ torchmetrics==1.7.4
|
||||
# pytorch-lightning
|
||||
# terratorch
|
||||
# torchgeo
|
||||
torchvision==0.24.0+cu129
|
||||
torchvision==0.24.1+cu129
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# lightly
|
||||
@ -1206,7 +1206,7 @@ transformers==4.57.3
|
||||
# transformers-stream-generator
|
||||
transformers-stream-generator==0.0.5
|
||||
# via -r requirements/test.in
|
||||
triton==3.5.0
|
||||
triton==3.5.1
|
||||
# via torch
|
||||
tritonclient==2.51.0
|
||||
# via
|
||||
|
||||
@ -67,7 +67,6 @@ def _fix_prompt_embed_outputs(
|
||||
@pytest.mark.parametrize("model_executor", ["uni", "mp"])
|
||||
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
||||
def test_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
hf_runner,
|
||||
model: str,
|
||||
backend: str,
|
||||
@ -77,48 +76,46 @@ def test_models(
|
||||
model_executor: str,
|
||||
enable_prompt_embeds: bool,
|
||||
) -> None:
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
# 5042 tokens for gemma2
|
||||
# gemma2 has alternating sliding window size of 4096
|
||||
# we need a prompt with more than 4096 tokens to test the sliding window
|
||||
prompt = (
|
||||
"The following numbers of the sequence "
|
||||
+ ", ".join(str(i) for i in range(1024))
|
||||
+ " are:"
|
||||
)
|
||||
example_prompts = [prompt]
|
||||
|
||||
# 5042 tokens for gemma2
|
||||
# gemma2 has alternating sliding window size of 4096
|
||||
# we need a prompt with more than 4096 tokens to test the sliding window
|
||||
prompt = (
|
||||
"The following numbers of the sequence "
|
||||
+ ", ".join(str(i) for i in range(1024))
|
||||
+ " are:"
|
||||
)
|
||||
example_prompts = [prompt]
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
if enable_prompt_embeds:
|
||||
with torch.no_grad():
|
||||
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
if enable_prompt_embeds:
|
||||
with torch.no_grad():
|
||||
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=8192,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7,
|
||||
async_scheduling=async_scheduling,
|
||||
distributed_executor_backend=model_executor,
|
||||
attention_config={"backend": backend},
|
||||
) as vllm_model:
|
||||
if enable_prompt_embeds:
|
||||
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
|
||||
vllm_outputs = _fix_prompt_embed_outputs(
|
||||
vllm_outputs, hf_model, example_prompts
|
||||
)
|
||||
else:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=8192,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7,
|
||||
async_scheduling=async_scheduling,
|
||||
distributed_executor_backend=model_executor,
|
||||
) as vllm_model:
|
||||
if enable_prompt_embeds:
|
||||
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
|
||||
vllm_outputs = _fix_prompt_embed_outputs(
|
||||
vllm_outputs, hf_model, example_prompts
|
||||
)
|
||||
else:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@ -161,12 +158,6 @@ def test_models_distributed(
|
||||
): # noqa
|
||||
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
|
||||
|
||||
if attention_backend:
|
||||
monkeypatch_context.setenv(
|
||||
"VLLM_ATTENTION_BACKEND",
|
||||
attention_backend,
|
||||
)
|
||||
|
||||
for k, v in extra_env.items():
|
||||
monkeypatch_context.setenv(k, v)
|
||||
|
||||
@ -178,6 +169,7 @@ def test_models_distributed(
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method
|
||||
# (the default method).
|
||||
attention_config = {"backend": attention_backend} if attention_backend else None
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
@ -185,6 +177,7 @@ def test_models_distributed(
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7,
|
||||
attention_config=attention_config,
|
||||
) as vllm_model:
|
||||
if enable_prompt_embeds:
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
|
||||
@ -19,21 +19,18 @@ def server():
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_serve(server):
|
||||
# Test default model detection and input/output len
|
||||
command = [
|
||||
"vllm",
|
||||
"bench",
|
||||
"serve",
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
"--host",
|
||||
server.host,
|
||||
"--port",
|
||||
str(server.port),
|
||||
"--dataset-name",
|
||||
"random",
|
||||
"--random-input-len",
|
||||
"--input-len",
|
||||
"32",
|
||||
"--random-output-len",
|
||||
"--output-len",
|
||||
"4",
|
||||
"--num-prompts",
|
||||
"5",
|
||||
|
||||
@ -208,7 +208,8 @@ def test_attn_quant(
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
model_kwargs["attention_config"] = {"backend": backend.name}
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
@ -297,7 +298,8 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
model_kwargs["attention_config"] = {"backend": backend.name}
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
@ -409,7 +411,8 @@ def test_tp2_attn_quant_async_tp(
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
model_kwargs["attention_config"] = {"backend": backend.name}
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
@ -523,6 +526,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
|
||||
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
# TODO: remove skip after we fix the fusion thoroughly
|
||||
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
|
||||
def test_rms_group_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
@ -562,7 +567,9 @@ def test_rms_group_quant(
|
||||
splitting_ops=splitting_ops,
|
||||
# Common
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True),
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
|
||||
),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
)
|
||||
|
||||
@ -89,7 +89,6 @@ class TestSetting:
|
||||
],
|
||||
)
|
||||
def test_compile_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_setting: TestSetting,
|
||||
):
|
||||
# this test is run under multiple suits, with different GPUs.
|
||||
@ -107,49 +106,48 @@ def test_compile_correctness(
|
||||
f"{cuda_device_count_stateless()}"
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
final_args = [
|
||||
*model_args,
|
||||
"-pp",
|
||||
str(pp_size),
|
||||
"-tp",
|
||||
str(tp_size),
|
||||
"-cc.cudagraph_mode=none",
|
||||
]
|
||||
final_args = [
|
||||
*model_args,
|
||||
"-pp",
|
||||
str(pp_size),
|
||||
"-tp",
|
||||
str(tp_size),
|
||||
"-cc.cudagraph_mode=none",
|
||||
f"--attention-backend={attn_backend}",
|
||||
]
|
||||
|
||||
all_args: list[list[str]] = []
|
||||
all_envs: list[dict[str, str] | None] = []
|
||||
all_args: list[list[str]] = []
|
||||
all_envs: list[dict[str, str] | None] = []
|
||||
|
||||
for comp_mode in [
|
||||
CompilationMode.STOCK_TORCH_COMPILE,
|
||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
for mode in [CompilationMode.NONE, comp_mode]:
|
||||
all_args.append(
|
||||
final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
|
||||
)
|
||||
|
||||
# inductor will change the output, so we only compare if the output
|
||||
# is close, not exactly the same.
|
||||
compare_all_settings(
|
||||
model,
|
||||
all_args,
|
||||
all_envs,
|
||||
method=method if method != "generate" else "generate_close",
|
||||
for comp_mode in [
|
||||
CompilationMode.STOCK_TORCH_COMPILE,
|
||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
for mode in [CompilationMode.NONE, comp_mode]:
|
||||
all_args.append(
|
||||
final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
|
||||
)
|
||||
all_envs.clear()
|
||||
all_args.clear()
|
||||
|
||||
for mode in [
|
||||
CompilationMode.NONE,
|
||||
CompilationMode.STOCK_TORCH_COMPILE,
|
||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
|
||||
all_envs.append({})
|
||||
all_envs.append({})
|
||||
# inductor will change the output, so we only compare if the output
|
||||
# is close, not exactly the same.
|
||||
compare_all_settings(
|
||||
model,
|
||||
all_args,
|
||||
all_envs,
|
||||
method=method if method != "generate" else "generate_close",
|
||||
)
|
||||
all_envs.clear()
|
||||
all_args.clear()
|
||||
|
||||
compare_all_settings(model, all_args * 3, all_envs, method=method)
|
||||
for mode in [
|
||||
CompilationMode.NONE,
|
||||
CompilationMode.STOCK_TORCH_COMPILE,
|
||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
|
||||
all_envs.append({})
|
||||
all_envs.append({})
|
||||
|
||||
compare_all_settings(model, all_args * 3, all_envs, method=method)
|
||||
|
||||
@ -74,7 +74,6 @@ def llm_pair(request):
|
||||
# Force native sampler to avoid potential nondeterminism in FlashInfer
|
||||
# when per-request generators are not used in V1.
|
||||
"VLLM_USE_FLASHINFER_SAMPLER": "0",
|
||||
**backend_config.env_vars,
|
||||
}
|
||||
with temporary_environ(env_vars):
|
||||
full = LLM(
|
||||
@ -170,16 +169,10 @@ class TestFullCUDAGraph:
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
def test_full_cudagraph_with_invalid_backend():
|
||||
with (
|
||||
temporary_environ(
|
||||
{
|
||||
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
|
||||
# Flex_Attention is not supported with full cuda graph
|
||||
}
|
||||
),
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
# Flex_Attention is not supported with full cuda graph
|
||||
with pytest.raises(RuntimeError):
|
||||
LLM(
|
||||
model="Qwen/Qwen2-1.5B-Instruct",
|
||||
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
|
||||
attention_config={"backend": "FLEX_ATTENTION"},
|
||||
)
|
||||
|
||||
@ -197,20 +197,19 @@ def test_custom_compile_config(
|
||||
],
|
||||
)
|
||||
def test_fp8_kv_scale_compile(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
compilation_mode: int,
|
||||
model: str,
|
||||
backend: AttentionBackendEnum | None,
|
||||
):
|
||||
if backend:
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
model_kwargs = {
|
||||
"quantization": "fp8",
|
||||
"kv_cache_dtype": "fp8_e4m3",
|
||||
"calculate_kv_scales": True,
|
||||
"max_model_len": 512,
|
||||
}
|
||||
if backend:
|
||||
model_kwargs["attention_config"] = {"backend": backend.name}
|
||||
|
||||
run_model(compilation_mode, model, **model_kwargs)
|
||||
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from contextlib import contextmanager
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.activation
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
@ -16,9 +17,12 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.envs import disable_envs_cache
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
def reference_fn(x: torch.Tensor):
|
||||
assert x.shape[0] <= 42
|
||||
@ -66,6 +70,7 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
|
||||
torch.compiler.set_stance("fail_on_recompile"),
|
||||
):
|
||||
CompiledMod(vllm_config=vllm_config)(*args)
|
||||
disable_envs_cache()
|
||||
|
||||
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||
torch._dynamo.reset()
|
||||
@ -101,6 +106,7 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
|
||||
vllm_config = make_vllm_config()
|
||||
with use_vllm_config(vllm_config):
|
||||
expected = CompiledMod(vllm_config=vllm_config)(*args)
|
||||
disable_envs_cache()
|
||||
|
||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
vllm_config = make_vllm_config()
|
||||
@ -130,6 +136,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
|
||||
artifacts = compiled_mod.aot_compiled_fn._artifacts
|
||||
guards_string = artifacts.compiled_fn.shape_env.format_guards()
|
||||
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
||||
disable_envs_cache()
|
||||
|
||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
vllm_config = make_vllm_config()
|
||||
@ -144,7 +151,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
|
||||
@pytest.mark.skipif(
|
||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||
)
|
||||
@use_vllm_config(make_vllm_config())
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test that compiling gpt2 twice results in a cache hit and
|
||||
@ -186,6 +193,8 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
# Clean up first model
|
||||
del llm_model
|
||||
disable_envs_cache()
|
||||
vllm.model_executor.layers.activation._ACTIVATION_REGISTRY._dict.clear()
|
||||
|
||||
# Second compilation - should hit cache
|
||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
|
||||
@ -233,24 +233,6 @@ def test_splitting_ops_dynamic():
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
|
||||
|
||||
def test_moe_splitting_ops_deepep_ht_piecewise():
|
||||
# Non-inductor, non-attn-fusion case: DeepEP HT with dp>1
|
||||
# should add MoE ops to splitting_ops on top of attention ops.
|
||||
config = VllmConfig(
|
||||
parallel_config=ParallelConfig(
|
||||
all2all_backend="deepep_high_throughput",
|
||||
data_parallel_size=8,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
),
|
||||
)
|
||||
splitting_ops = config.compilation_config.splitting_ops
|
||||
assert splitting_ops is not None
|
||||
assert "vllm::moe_forward" in splitting_ops
|
||||
assert "vllm::moe_forward_shared" in splitting_ops
|
||||
|
||||
|
||||
def test_moe_splitting_ops_deepep_ht_inductor_partition():
|
||||
# Inductor partition case: user-provided splitting_ops should be
|
||||
# preserved and MoE ops should be appended for DeepEP HT with dp>1.
|
||||
@ -277,26 +259,6 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
|
||||
]
|
||||
|
||||
|
||||
def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
|
||||
# Pure attn-fusion case without inductor partition: even with
|
||||
# DeepEP HT and dp>1, we should not re-enable piecewise compilation
|
||||
# or add MoE ops into splitting_ops.
|
||||
config = VllmConfig(
|
||||
parallel_config=ParallelConfig(
|
||||
all2all_backend="deepep_high_throughput",
|
||||
data_parallel_size=8,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config={"fuse_attn_quant": True, "eliminate_noops": True},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
),
|
||||
)
|
||||
assert config.compilation_config.splitting_ops == []
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
|
||||
|
||||
|
||||
def test_should_split():
|
||||
import torch
|
||||
|
||||
|
||||
@ -219,14 +219,12 @@ def _test_cp_gsm8k(
|
||||
]
|
||||
)
|
||||
|
||||
server_env = {}
|
||||
if attn_backend:
|
||||
server_env["VLLM_ATTENTION_BACKEND"] = attn_backend
|
||||
server_args.append(f"--attention-backend={attn_backend}")
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_id,
|
||||
server_args,
|
||||
env_dict=server_env,
|
||||
max_wait_seconds=720,
|
||||
) as remote_server:
|
||||
host = f"http://{remote_server.host}"
|
||||
|
||||
@ -20,23 +20,21 @@ from ..utils import compare_two_settings, create_new_process_for_each_test
|
||||
)
|
||||
@create_new_process_for_each_test()
|
||||
def test_pp_cudagraph(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
PP_SIZE: int,
|
||||
MODEL_NAME: str,
|
||||
ATTN_BACKEND: LiteralString,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
cudagraph_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"float16",
|
||||
"--pipeline-parallel-size",
|
||||
str(PP_SIZE),
|
||||
"--distributed-executor-backend",
|
||||
"mp",
|
||||
]
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", ATTN_BACKEND)
|
||||
cudagraph_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"float16",
|
||||
"--pipeline-parallel-size",
|
||||
str(PP_SIZE),
|
||||
"--distributed-executor-backend",
|
||||
"mp",
|
||||
f"--attention-backend={ATTN_BACKEND}",
|
||||
]
|
||||
|
||||
eager_args = cudagraph_args + ["--enforce-eager"]
|
||||
eager_args = cudagraph_args + ["--enforce-eager"]
|
||||
|
||||
compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)
|
||||
compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import Annotated, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import CompilationConfig, config
|
||||
from vllm.config import AttentionConfig, CompilationConfig, config
|
||||
from vllm.engine.arg_utils import (
|
||||
EngineArgs,
|
||||
contains_type,
|
||||
@ -298,6 +298,139 @@ def test_compilation_config():
|
||||
)
|
||||
|
||||
|
||||
def test_attention_config():
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
|
||||
# default value
|
||||
args = parser.parse_args([])
|
||||
assert args is not None
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
assert engine_args.attention_config == AttentionConfig()
|
||||
|
||||
# set backend via dot notation
|
||||
args = parser.parse_args(["--attention-config.backend", "FLASH_ATTN"])
|
||||
assert args is not None
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
assert engine_args.attention_config.backend is not None
|
||||
assert engine_args.attention_config.backend.name == "FLASH_ATTN"
|
||||
|
||||
# set backend via --attention-backend shorthand
|
||||
args = parser.parse_args(["--attention-backend", "FLASHINFER"])
|
||||
assert args is not None
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
assert engine_args.attention_backend is not None
|
||||
assert engine_args.attention_backend == "FLASHINFER"
|
||||
|
||||
# set all fields via dot notation
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"--attention-config.backend",
|
||||
"FLASH_ATTN",
|
||||
"--attention-config.flash_attn_version",
|
||||
"3",
|
||||
"--attention-config.use_prefill_decode_attention",
|
||||
"true",
|
||||
"--attention-config.flash_attn_max_num_splits_for_cuda_graph",
|
||||
"16",
|
||||
"--attention-config.use_cudnn_prefill",
|
||||
"true",
|
||||
"--attention-config.use_trtllm_ragged_deepseek_prefill",
|
||||
"true",
|
||||
"--attention-config.use_trtllm_attention",
|
||||
"true",
|
||||
"--attention-config.disable_flashinfer_prefill",
|
||||
"true",
|
||||
"--attention-config.disable_flashinfer_q_quantization",
|
||||
"true",
|
||||
]
|
||||
)
|
||||
assert args is not None
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
assert engine_args.attention_config.backend is not None
|
||||
assert engine_args.attention_config.backend.name == "FLASH_ATTN"
|
||||
assert engine_args.attention_config.flash_attn_version == 3
|
||||
assert engine_args.attention_config.use_prefill_decode_attention is True
|
||||
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16
|
||||
assert engine_args.attention_config.use_cudnn_prefill is True
|
||||
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True
|
||||
assert engine_args.attention_config.use_trtllm_attention is True
|
||||
assert engine_args.attention_config.disable_flashinfer_prefill is True
|
||||
assert engine_args.attention_config.disable_flashinfer_q_quantization is True
|
||||
|
||||
# set to string form of a dict with all fields
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"--attention-config="
|
||||
'{"backend": "FLASHINFER", "flash_attn_version": 2, '
|
||||
'"use_prefill_decode_attention": false, '
|
||||
'"flash_attn_max_num_splits_for_cuda_graph": 8, '
|
||||
'"use_cudnn_prefill": false, '
|
||||
'"use_trtllm_ragged_deepseek_prefill": false, '
|
||||
'"use_trtllm_attention": false, '
|
||||
'"disable_flashinfer_prefill": false, '
|
||||
'"disable_flashinfer_q_quantization": false}',
|
||||
]
|
||||
)
|
||||
assert args is not None
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
assert engine_args.attention_config.backend is not None
|
||||
assert engine_args.attention_config.backend.name == "FLASHINFER"
|
||||
assert engine_args.attention_config.flash_attn_version == 2
|
||||
assert engine_args.attention_config.use_prefill_decode_attention is False
|
||||
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8
|
||||
assert engine_args.attention_config.use_cudnn_prefill is False
|
||||
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False
|
||||
assert engine_args.attention_config.use_trtllm_attention is False
|
||||
assert engine_args.attention_config.disable_flashinfer_prefill is False
|
||||
assert engine_args.attention_config.disable_flashinfer_q_quantization is False
|
||||
|
||||
# test --attention-backend flows into VllmConfig.attention_config
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"--model",
|
||||
"facebook/opt-125m",
|
||||
"--attention-backend",
|
||||
"FLASH_ATTN",
|
||||
]
|
||||
)
|
||||
assert args is not None
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
# test --attention-config.backend flows into VllmConfig.attention_config
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"--model",
|
||||
"facebook/opt-125m",
|
||||
"--attention-config.backend",
|
||||
"FLASHINFER",
|
||||
]
|
||||
)
|
||||
assert args is not None
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASHINFER
|
||||
|
||||
# test --attention-backend and --attention-config.backend are mutually exclusive
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"--model",
|
||||
"facebook/opt-125m",
|
||||
"--attention-backend",
|
||||
"FLASH_ATTN",
|
||||
"--attention-config.backend",
|
||||
"FLASHINFER",
|
||||
]
|
||||
)
|
||||
assert args is not None
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
with pytest.raises(ValueError, match="mutually exclusive"):
|
||||
engine_args.create_engine_config()
|
||||
|
||||
|
||||
def test_prefix_cache_default():
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
args = parser.parse_args([])
|
||||
|
||||
0
tests/entrypoints/instrumentator/__init__.py
Normal file
0
tests/entrypoints/instrumentator/__init__.py
Normal file
@ -14,11 +14,10 @@ import requests
|
||||
from prometheus_client.parser import text_string_to_metric_families
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.conftest import LocalAssetServer
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm import version
|
||||
|
||||
from ...conftest import LocalAssetServer
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODELS = {
|
||||
"text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct",
|
||||
@ -254,7 +254,9 @@ async def test_single_chat_session_input_audio(
|
||||
async def test_chat_streaming_audio(
|
||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||
):
|
||||
messages = dummy_messages_from_audio_url(audio_url)
|
||||
messages = dummy_messages_from_audio_url(
|
||||
audio_url, "What's a short title for this audio?"
|
||||
)
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
|
||||
@ -76,6 +76,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||
lora_request,
|
||||
trace_headers,
|
||||
priority,
|
||||
data_parallel_rank,
|
||||
):
|
||||
return dict(engine_prompt), {}
|
||||
|
||||
|
||||
@ -73,6 +73,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
|
||||
lora_request,
|
||||
trace_headers,
|
||||
priority,
|
||||
data_parallel_rank,
|
||||
):
|
||||
return dict(engine_prompt), {}
|
||||
|
||||
|
||||
223
tests/entrypoints/openai/test_embedding_shape_validation.py
Normal file
223
tests/entrypoints/openai/test_embedding_shape_validation.py
Normal file
@ -0,0 +1,223 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Embedding shape validation in multimodal APIs.
|
||||
|
||||
Tests verify that embeddings with correct ndim but incorrect hidden_size
|
||||
are rejected before they can cause crashes during model inference.
|
||||
|
||||
Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems
|
||||
classes, not by CompletionRenderer or MediaIO classes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.parse import (
|
||||
AudioEmbeddingItems,
|
||||
ImageEmbeddingItems,
|
||||
MultiModalDataParser,
|
||||
VideoEmbeddingItems,
|
||||
)
|
||||
|
||||
|
||||
class TestMultiModalParserShapeValidation:
|
||||
"""Test hidden_size validation in MultiModalDataParser."""
|
||||
|
||||
def test_image_embeddings_correct_hidden_size_accepted(self):
|
||||
"""Baseline: Image embeddings with correct hidden_size should work."""
|
||||
expected_hidden_size = 768
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
valid_embeds = torch.randn(2, 100, expected_hidden_size)
|
||||
|
||||
result = parser.parse_mm_data({"image": valid_embeds})
|
||||
|
||||
assert "image" in result
|
||||
assert isinstance(result["image"], ImageEmbeddingItems)
|
||||
assert result["image"].get_count() == 2
|
||||
|
||||
def test_image_embeddings_wrong_hidden_size_rejected(self):
|
||||
"""Security: Image embeddings with wrong hidden_size should be rejected."""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 4096
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parser.parse_mm_data({"image": invalid_embeds})
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "image" in error_msg
|
||||
assert "hidden dimension mismatch" in error_msg
|
||||
|
||||
def test_audio_embeddings_wrong_hidden_size_rejected(self):
|
||||
"""Security: Audio embeddings with wrong hidden_size should be rejected."""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 2048
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parser.parse_mm_data({"audio": invalid_embeds})
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "audio" in error_msg
|
||||
assert "hidden dimension mismatch" in error_msg
|
||||
|
||||
def test_video_embeddings_wrong_hidden_size_rejected(self):
|
||||
"""Security: Video embeddings with wrong hidden_size should be rejected."""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 512
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parser.parse_mm_data({"video": invalid_embeds})
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "video" in error_msg
|
||||
assert "hidden dimension mismatch" in error_msg
|
||||
|
||||
def test_list_of_embeddings_validates_each(self):
|
||||
"""Security: Each embedding in list should be validated."""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 1024
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
# List with second tensor having wrong hidden_size
|
||||
invalid_embeds = [
|
||||
torch.randn(100, expected_hidden_size),
|
||||
torch.randn(100, wrong_hidden_size),
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parser.parse_mm_data({"image": invalid_embeds})
|
||||
|
||||
# Should identify which embedding failed
|
||||
assert "[1]" in str(exc_info.value)
|
||||
|
||||
def test_validation_disabled_allows_any_size(self):
|
||||
"""When validation disabled (legacy), any hidden_size allowed."""
|
||||
parser = MultiModalDataParser(expected_hidden_size=None)
|
||||
|
||||
any_hidden_size = 12345
|
||||
embeds = torch.randn(2, 100, any_hidden_size)
|
||||
|
||||
# Should not raise
|
||||
result = parser.parse_mm_data({"image": embeds})
|
||||
assert "image" in result
|
||||
assert isinstance(result["image"], ImageEmbeddingItems)
|
||||
|
||||
|
||||
class TestEmbeddingItemsDirectValidation:
|
||||
"""Direct tests for EmbeddingItems hidden_size validation."""
|
||||
|
||||
def test_image_embedding_items_validates_batched_tensor(self):
|
||||
"""Test validation for batched (3D) image embeddings."""
|
||||
expected = 768
|
||||
wrong = 1024
|
||||
|
||||
# Valid
|
||||
valid = torch.randn(2, 100, expected)
|
||||
items = ImageEmbeddingItems(valid, expected_hidden_size=expected)
|
||||
assert items.get_count() == 2
|
||||
|
||||
# Invalid
|
||||
invalid = torch.randn(2, 100, wrong)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ImageEmbeddingItems(invalid, expected_hidden_size=expected)
|
||||
|
||||
assert str(wrong) in str(exc_info.value)
|
||||
assert str(expected) in str(exc_info.value)
|
||||
|
||||
def test_image_embedding_items_validates_list_of_tensors(self):
|
||||
"""Test validation for list of 2D image embeddings."""
|
||||
expected = 768
|
||||
wrong = 512
|
||||
|
||||
# Valid list
|
||||
valid_list = [torch.randn(100, expected), torch.randn(50, expected)]
|
||||
items = ImageEmbeddingItems(valid_list, expected_hidden_size=expected)
|
||||
assert items.get_count() == 2
|
||||
|
||||
# Invalid list
|
||||
invalid_list = [torch.randn(100, expected), torch.randn(50, wrong)]
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ImageEmbeddingItems(invalid_list, expected_hidden_size=expected)
|
||||
|
||||
assert "[1]" in str(exc_info.value)
|
||||
|
||||
def test_audio_embedding_items_validates(self):
|
||||
"""Test validation for audio embeddings."""
|
||||
expected = 768
|
||||
wrong = 256
|
||||
|
||||
invalid = torch.randn(2, 100, wrong)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AudioEmbeddingItems(invalid, expected_hidden_size=expected)
|
||||
|
||||
assert "audio" in str(exc_info.value).lower()
|
||||
|
||||
def test_video_embedding_items_validates(self):
|
||||
"""Test validation for video embeddings."""
|
||||
expected = 768
|
||||
wrong = 384
|
||||
|
||||
invalid = torch.randn(2, 100, wrong)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
VideoEmbeddingItems(invalid, expected_hidden_size=expected)
|
||||
|
||||
assert "video" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestShapeValidationIntegration:
|
||||
"""Integration tests verifying attack scenarios are blocked."""
|
||||
|
||||
def test_attack_scenario_multimodal_image(self):
|
||||
"""
|
||||
Simulate attack through Chat API with image embeddings.
|
||||
|
||||
Verifies validation occurs in multimodal parser path.
|
||||
"""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 4096
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
attack_tensor = torch.randn(1, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
parser.parse_mm_data({"image": attack_tensor})
|
||||
|
||||
def test_attack_scenario_multimodal_audio(self):
|
||||
"""
|
||||
Simulate attack through Chat API with audio embeddings.
|
||||
|
||||
Verifies validation occurs in multimodal parser path.
|
||||
"""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 2048
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
attack_tensor = torch.randn(1, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
parser.parse_mm_data({"audio": attack_tensor})
|
||||
|
||||
def test_attack_scenario_multimodal_video(self):
|
||||
"""
|
||||
Simulate attack through Chat API with video embeddings.
|
||||
|
||||
Verifies validation occurs in multimodal parser path.
|
||||
"""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 1024
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
attack_tensor = torch.randn(1, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
parser.parse_mm_data({"video": attack_tensor})
|
||||
@ -15,6 +15,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
@ -52,8 +53,19 @@ def with_tool_parser(request) -> bool:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="module",
|
||||
params=[True],
|
||||
ids=["exclude_tools_when_tool_choice_none"],
|
||||
)
|
||||
def exclude_tools_when_tool_choice_none(request) -> bool:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args(with_tool_parser: bool):
|
||||
def default_server_args(
|
||||
with_tool_parser: bool, exclude_tools_when_tool_choice_none: bool
|
||||
):
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--enforce-eager",
|
||||
@ -72,19 +84,16 @@ def default_server_args(with_tool_parser: bool):
|
||||
"--enable-auto-tool-choice",
|
||||
]
|
||||
)
|
||||
if exclude_tools_when_tool_choice_none:
|
||||
args.append("--exclude-tools-when-tool-choice-none")
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gptoss_server(
|
||||
monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
|
||||
):
|
||||
with monkeypatch_module.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
|
||||
with RemoteOpenAIServer(
|
||||
GPT_OSS_MODEL_NAME, default_server_args
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
def gptoss_server(default_server_args: list[str]):
|
||||
server_args = default_server_args + ["--attention-backend=TRITON_ATTN"]
|
||||
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@ -340,6 +349,69 @@ async def test_gpt_oss_tool_message_array_content(
|
||||
assert response_multi_array.choices[0].message is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_tool_choice_none(
|
||||
gptoss_client: OpenAI,
|
||||
with_tool_parser: bool,
|
||||
exclude_tools_when_tool_choice_none: bool,
|
||||
):
|
||||
if not (with_tool_parser and exclude_tools_when_tool_choice_none):
|
||||
pytest.skip(
|
||||
"skip tool_choice tests when non-tool or "
|
||||
"--exclude-tools-when-tool-choice-none not set"
|
||||
)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the temperature(in degrees Celsius) in Dallas?",
|
||||
},
|
||||
]
|
||||
|
||||
tool_choice_auto = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
temperature=0.0,
|
||||
)
|
||||
msg = tool_choice_auto.choices[0].message
|
||||
assert len(msg.tool_calls) == 1
|
||||
|
||||
tool_choice_none = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="none",
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
msg = tool_choice_none.choices[0].message
|
||||
assert len(msg.tool_calls) == 0
|
||||
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
MODEL_NAME_SHORT = "gpt2"
|
||||
CHAT_TEMPLATE = "Dummy chat template for testing {}"
|
||||
@ -401,6 +473,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||
lora_request,
|
||||
trace_headers,
|
||||
priority,
|
||||
data_parallel_rank,
|
||||
):
|
||||
return dict(engine_prompt), {}
|
||||
|
||||
@ -1372,3 +1445,69 @@ class TestServingChatWithHarmony:
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_choice_validation_without_parser():
|
||||
"""Test that tool_choice='required' or named tool without tool_parser
|
||||
returns an appropriate error message."""
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.input_processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
models = OpenAIServingModels(
|
||||
engine_client=mock_engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
)
|
||||
# Create serving_chat without tool_parser (enable_auto_tools=False)
|
||||
serving_chat = OpenAIServingChat(
|
||||
mock_engine,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
request_logger=None,
|
||||
enable_auto_tools=False, # No tool parser
|
||||
)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Test tool_choice="required" without tool_parser
|
||||
req_required = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{"role": "user", "content": "What's the weather?"}],
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
)
|
||||
response_required = await serving_chat.create_chat_completion(req_required)
|
||||
assert isinstance(response_required, ErrorResponse)
|
||||
assert "tool_choice" in response_required.error.message
|
||||
assert "--tool-call-parser" in response_required.error.message
|
||||
|
||||
# Test named tool_choice without tool_parser
|
||||
req_named = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{"role": "user", "content": "What's the weather?"}],
|
||||
tools=tools,
|
||||
tool_choice={"type": "function", "function": {"name": "get_weather"}},
|
||||
)
|
||||
response_named = await serving_chat.create_chat_completion(req_named)
|
||||
assert isinstance(response_named, ErrorResponse)
|
||||
assert "tool_choice" in response_named.error.message
|
||||
assert "--tool-call-parser" in response_named.error.message
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user