Merge branch 'main' into rename_file_info_to_pkg/file
Commit ff6ec3f0e4
@ -15,6 +15,21 @@ steps:
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- label: "Build arm64 wheel - CUDA 13.0"
|
||||
depends_on: ~
|
||||
id: build-wheel-arm64-cuda-13-0
|
||||
agents:
|
||||
queue: arm64_cpu_queue_postmerge
|
||||
commands:
|
||||
# NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here:
|
||||
# https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
# aarch64 build
|
||||
- label: "Build arm64 CPU wheel"
|
||||
depends_on: ~
|
||||
@ -25,7 +40,7 @@ steps:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_BUILD_ACL=ON --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
@ -39,7 +54,7 @@ steps:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_31"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
@ -52,7 +67,21 @@ steps:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
# x86 CPU wheel build
|
||||
- label: "Build x86 CPU wheel"
|
||||
depends_on: ~
|
||||
id: build-wheel-x86-cpu
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
|
||||
@ -372,6 +372,17 @@ if __name__ == "__main__":
|
||||
|
||||
print(f"Found {len(wheel_files)} wheel files for version {version}: {wheel_files}")
|
||||
|
||||
# keep only "official" files for a non-nightly version (specifed by cli args)
|
||||
PY_VERSION_RE = re.compile(r"^\d+\.\d+\.\d+([a-zA-Z0-9.+-]*)?$")
|
||||
if PY_VERSION_RE.match(version):
|
||||
# upload-wheels.sh ensures no "dev" is in args.version
|
||||
wheel_files = list(
|
||||
filter(lambda x: version in x and "dev" not in x, wheel_files)
|
||||
)
|
||||
print(f"Non-nightly version detected, wheel files used: {wheel_files}")
|
||||
else:
|
||||
print("Nightly version detected, keeping all wheel files.")
|
||||
|
||||
# Generate index and metadata, assuming wheels and indices are stored as:
|
||||
# s3://vllm-wheels/{version}/<wheel files>
|
||||
# s3://vllm-wheels/<anything>/<index files>
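For readability, a minimal standalone sketch of the version filter above; the wheel filenames and the "nightly" version string are illustrative, not taken from the actual bucket contents.
import re

PY_VERSION_RE = re.compile(r"^\d+\.\d+\.\d+([a-zA-Z0-9.+-]*)?$")

# Hypothetical wheel filenames for a release and a dev build.
wheel_files = [
    "vllm-0.8.1-cp38-abi3-manylinux_2_31_x86_64.whl",
    "vllm-0.8.1.dev42+g1234abc-cp38-abi3-manylinux_2_31_x86_64.whl",
]
for version in ["0.8.1", "nightly"]:
    if PY_VERSION_RE.match(version):
        # release-style version: drop anything with "dev" in the filename
        kept = [w for w in wheel_files if version in w and "dev" not in w]
        print(f"{version}: release -> {kept}")
    else:
        print(f"{version}: nightly -> keep all {len(wheel_files)} wheel files")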
|
||||
|
||||
@ -36,6 +36,11 @@ function cpu_tests() {
|
||||
set -e
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
|
||||
|
||||
# Run model tests
|
||||
docker exec cpu-test bash -c "
|
||||
set -e
|
||||
pytest -x -v -s tests/models/multimodal/generation/test_whisper.py -m cpu_model"
|
||||
|
||||
# Run kernel tests
|
||||
docker exec cpu-test bash -c "
|
||||
set -e
|
||||
|
||||
@ -1,73 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euxo pipefail
|
||||
|
||||
# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
|
||||
THRESHOLD=${1:-0.25}
|
||||
NUM_Q=${2:-1319}
|
||||
PORT=${3:-8030}
|
||||
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
|
||||
mkdir -p "${OUT_DIR}"
|
||||
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
timeout 600 bash -c '
|
||||
until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
|
||||
sleep 1
|
||||
done'
|
||||
}
|
||||
|
||||
MODEL="deepseek-ai/DeepSeek-V2-lite"
|
||||
|
||||
# Set BACKENDS based on platform
|
||||
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
|
||||
# ROCm platform
|
||||
BACKENDS=("allgather_reducescatter")
|
||||
# Disable MOE padding for ROCm since it is causing eplb to fail
|
||||
export VLLM_ROCM_MOE_PADDING=0
|
||||
else
|
||||
# Non-ROCm platform (CUDA/other)
|
||||
BACKENDS=("deepep_high_throughput" "deepep_low_latency")
|
||||
fi
|
||||
|
||||
cleanup() {
|
||||
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
|
||||
kill "${SERVER_PID}" 2>/dev/null || true
|
||||
for _ in {1..20}; do
|
||||
kill -0 "${SERVER_PID}" 2>/dev/null || break
|
||||
sleep 0.5
|
||||
done
|
||||
kill -9 "${SERVER_PID}" 2>/dev/null || true
|
||||
fi
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
for BACK in "${BACKENDS[@]}"; do
|
||||
VLLM_DEEP_GEMM_WARMUP=skip \
|
||||
VLLM_ALL2ALL_BACKEND=$BACK \
|
||||
vllm serve "$MODEL" \
|
||||
--enforce-eager \
|
||||
--tensor-parallel-size 2 \
|
||||
--data-parallel-size 2 \
|
||||
--enable-expert-parallel \
|
||||
--enable-eplb \
|
||||
--eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
|
||||
--trust-remote-code \
|
||||
--max-model-len 2048 \
|
||||
--port $PORT &
|
||||
SERVER_PID=$!
|
||||
wait_for_server $PORT
|
||||
|
||||
TAG=$(echo "$MODEL" | tr '/: \\n' '_____')
|
||||
OUT="${OUT_DIR}/${TAG}_${BACK}_async_eplb.json"
|
||||
python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT}
|
||||
python3 - <<PY
|
||||
import json; acc=json.load(open('${OUT}'))['accuracy']
|
||||
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
|
||||
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
|
||||
PY
|
||||
|
||||
cleanup
|
||||
SERVER_PID=
|
||||
sleep 1
|
||||
PORT=$((PORT+1))
|
||||
done
|
||||
@ -50,7 +50,6 @@ for BACK in "${BACKENDS[@]}"; do
|
||||
--data-parallel-size 2 \
|
||||
--enable-expert-parallel \
|
||||
--enable-eplb \
|
||||
--eplb-config '{"window_size":200,"step_interval":600}' \
|
||||
--trust-remote-code \
|
||||
--max-model-len 2048 \
|
||||
--port $PORT &
|
||||
|
||||
@ -34,9 +34,10 @@ if [[ ${#wheel_files[@]} -ne 1 ]]; then
|
||||
fi
|
||||
wheel="${wheel_files[0]}"
|
||||
|
||||
# current build image uses ubuntu 20.04, which corresponds to manylinux_2_31
|
||||
# default build image uses ubuntu 20.04, which corresponds to manylinux_2_31
|
||||
# we also accept params as manylinux tag
|
||||
# refer to https://github.com/mayeut/pep600_compliance?tab=readme-ov-file#acceptable-distros-to-build-wheels
|
||||
manylinux_version="manylinux_2_31"
|
||||
manylinux_version="${1:-manylinux_2_31}"
|
||||
|
||||
# Rename 'linux' to the appropriate manylinux version in the wheel filename
|
||||
if [[ "$wheel" != *"linux"* ]]; then
|
||||
@ -96,8 +97,11 @@ if [[ "$BUILDKITE_BRANCH" == "main" && "$BUILDKITE_PULL_REQUEST" == "false" ]];
|
||||
aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/nightly/"
|
||||
fi
|
||||
|
||||
# copy to /<pure_version>/ only if it does not have "dev" in the version
|
||||
# re-generate and copy to /<pure_version>/ only if it does not have "dev" in the version
|
||||
if [[ "$version" != *"dev"* ]]; then
|
||||
echo "Uploading indices to overwrite /$pure_version/"
|
||||
echo "Re-generating indices for /$pure_version/"
|
||||
rm -rf "$INDICES_OUTPUT_DIR/*"
|
||||
mkdir -p "$INDICES_OUTPUT_DIR"
|
||||
$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$pure_version" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" --comment "version $pure_version" $alias_arg
|
||||
aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/$pure_version/"
|
||||
fi
|
||||
|
||||
@ -61,8 +61,8 @@ steps:
|
||||
- pytest -v -s -m 'not cpu_test' multimodal
|
||||
- pytest -v -s utils_
|
||||
|
||||
- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
|
||||
timeout_in_minutes: 20
|
||||
- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 20min
|
||||
timeout_in_minutes: 30
|
||||
mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
|
||||
agent_pool: mi325_1
|
||||
grade: Blocking
|
||||
@ -73,6 +73,7 @@ steps:
|
||||
- tests/multimodal
|
||||
- tests/standalone_tests/lazy_imports.py
|
||||
- tests/tokenizers_
|
||||
- tests/tool_parsers
|
||||
- tests/transformers_utils
|
||||
- tests/config
|
||||
no_gpu: true
|
||||
@ -82,6 +83,7 @@ steps:
|
||||
- pytest -v -s test_outputs.py
|
||||
- pytest -v -s -m 'cpu_test' multimodal
|
||||
- pytest -v -s tokenizers_
|
||||
- pytest -v -s tool_parsers
|
||||
- pytest -v -s transformers_utils
|
||||
- pytest -v -s config
|
||||
|
||||
@ -326,10 +328,10 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
|
||||
|
||||
- label: V1 Test e2e + engine # 30min
|
||||
timeout_in_minutes: 45
|
||||
- label: V1 Test e2e + engine # 65min
|
||||
timeout_in_minutes: 90
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_4
|
||||
# grade: Blocking
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -435,7 +437,7 @@ steps:
|
||||
|
||||
- label: Examples Test # 30min
|
||||
timeout_in_minutes: 45
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
@ -455,7 +457,6 @@ steps:
|
||||
# for multi-modal models
|
||||
- python3 offline_inference/audio_language.py --seed 0
|
||||
- python3 offline_inference/vision_language.py --seed 0
|
||||
- python3 offline_inference/vision_language_pooling.py --seed 0
|
||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
||||
# for pooling models
|
||||
@ -760,19 +761,7 @@ steps:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- pytest -v -s -m 'not cpu_test' tool_use
|
||||
|
||||
- label: OpenAI-Compatible Tool Use (CPU) # 5 mins
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
timeout_in_minutes: 10
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
no_gpu: true
|
||||
commands:
|
||||
- pytest -v -s -m 'cpu_test' tool_use
|
||||
- pytest -v -s tool_use
|
||||
|
||||
##### models test #####
|
||||
|
||||
@ -1630,7 +1619,6 @@ steps:
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_4
|
||||
# grade: Blocking
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
|
||||
@ -57,8 +57,8 @@ steps:
|
||||
- pytest -v -s -m 'not cpu_test' multimodal
|
||||
- pytest -v -s utils_
|
||||
|
||||
- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
|
||||
timeout_in_minutes: 20
|
||||
- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 20min
|
||||
timeout_in_minutes: 30
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/test_inputs.py
|
||||
@ -66,6 +66,7 @@ steps:
|
||||
- tests/multimodal
|
||||
- tests/standalone_tests/lazy_imports.py
|
||||
- tests/tokenizers_
|
||||
- tests/tool_parsers
|
||||
- tests/transformers_utils
|
||||
- tests/config
|
||||
no_gpu: true
|
||||
@ -75,6 +76,7 @@ steps:
|
||||
- pytest -v -s test_outputs.py
|
||||
- pytest -v -s -m 'cpu_test' multimodal
|
||||
- pytest -v -s tokenizers_
|
||||
- pytest -v -s tool_parsers
|
||||
- pytest -v -s transformers_utils
|
||||
- pytest -v -s config
|
||||
|
||||
@ -652,7 +654,7 @@ steps:
|
||||
- vllm/model_executor/layers/quantization
|
||||
autorun_on_main: true
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
||||
|
||||
- label: OpenAI API correctness # 22min
|
||||
timeout_in_minutes: 30
|
||||
@ -672,16 +674,7 @@ steps:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- pytest -v -s -m 'not cpu_test' tool_use
|
||||
|
||||
- label: OpenAI-Compatible Tool Use (CPU) # 5 mins
|
||||
timeout_in_minutes: 10
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
no_gpu: true
|
||||
commands:
|
||||
- pytest -v -s -m 'cpu_test' tool_use
|
||||
- pytest -v -s tool_use
|
||||
|
||||
##### models test #####
|
||||
|
||||
@ -692,6 +685,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/test_initialization.py
|
||||
- tests/models/registry.py
|
||||
commands:
|
||||
# Run a subset of model initialization tests
|
||||
- pytest -v -s models/test_initialization.py::test_can_initialize_small_subset
|
||||
@ -704,6 +698,7 @@ steps:
|
||||
- vllm/model_executor/models/
|
||||
- vllm/transformers_utils/
|
||||
- tests/models/test_initialization.py
|
||||
- tests/models/registry.py
|
||||
commands:
|
||||
# Only when vLLM model source is modified - test initialization of a large
|
||||
# subset of supported models (the complement of the small subset in the above
|
||||
@ -836,7 +831,7 @@ steps:
|
||||
- tests/models/multimodal
|
||||
no_gpu: true
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- "pip install git+https://github.com/TIGER-AI-Lab/Mantis.git || echo 'Mantis installation skipped (decord not available on CPU-only environment)'"
|
||||
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
|
||||
|
||||
- label: Multi-Modal Processor Test
|
||||
@ -1069,7 +1064,7 @@ steps:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
@ -1228,6 +1223,8 @@ steps:
|
||||
# FIXIT: find out which code initialize cuda before running the test
|
||||
# before the fix, we need to use spawn to test it
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
# A lot of these tests are on the edge of OOMing
|
||||
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
# There is some Tensor Parallelism related processing logic in LoRA that
|
||||
# requires multi-GPU testing for validation.
|
||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||
@ -1346,6 +1343,7 @@ steps:
|
||||
- label: Prime-RL Integration Test # 15min
|
||||
timeout_in_minutes: 30
|
||||
optional: true
|
||||
soft_fail: true
|
||||
num_gpus: 2
|
||||
working_dir: "/vllm-workspace"
|
||||
source_file_dependencies:
|
||||
@ -1380,21 +1378,3 @@ steps:
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
||||
|
||||
- label: DeepSeek V2-Lite Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
|
||||
|
||||
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
||||
|
||||
@ -115,7 +115,7 @@ steps:
|
||||
|
||||
- label: Async Engine, Inputs, Utils, Worker, Config (CPU)
|
||||
depends_on: ~
|
||||
timeout_in_minutes: 20
|
||||
timeout_in_minutes: 30
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/test_inputs.py
|
||||
@ -123,6 +123,7 @@ steps:
|
||||
- tests/multimodal
|
||||
- tests/standalone_tests/lazy_imports.py
|
||||
- tests/tokenizers_
|
||||
- tests/tool_parsers
|
||||
- tests/transformers_utils
|
||||
- tests/config
|
||||
no_gpu: true
|
||||
@ -132,6 +133,7 @@ steps:
|
||||
- pytest -v -s test_outputs.py
|
||||
- pytest -v -s -m 'cpu_test' multimodal
|
||||
- pytest -v -s tokenizers_
|
||||
- pytest -v -s tool_parsers
|
||||
- pytest -v -s transformers_utils
|
||||
- pytest -v -s config
|
||||
|
||||
|
||||
@ -10,14 +10,4 @@ steps:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- pytest -v -s -m 'not cpu_test' tool_use
|
||||
|
||||
- label: OpenAI-Compatible Tool Use (CPU)
|
||||
depends_on: ~
|
||||
timeout_in_minutes: 10
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
no_gpu: true
|
||||
commands:
|
||||
- pytest -v -s -m 'cpu_test' tool_use
|
||||
- pytest -v -s tool_use
|
||||
|
||||
CMakeLists.txt (115 lines)
@ -357,6 +357,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# marlin arches for fp16 output
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
|
||||
# marlin has limited support for turing
|
||||
cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
|
||||
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
|
||||
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
||||
# marlin arches for fp8 input
|
||||
@ -364,8 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SASS instruction
|
||||
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
|
||||
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
|
||||
# marlin arches for other files
|
||||
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
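For orientation, a small sketch summarizing the Marlin arch buckets requested above; the lists simply restate the arguments passed to cuda_archs_loose_intersection, and the effective values still depend on the intersection with CUDA_ARCHS.
marlin_arch_buckets = {
    "MARLIN_ARCHS":       ["8.0+PTX"],             # fp16-output kernels
    "MARLIN_SM75_ARCHS":  ["7.5"],                 # limited Turing support
    "MARLIN_BF16_ARCHS":  ["8.0+PTX", "9.0+PTX"],  # bf16 output (9.0 for bf16 atomicAdd PTX)
    "MARLIN_FP8_ARCHS":   ["8.9", "12.0"],         # fp8 input (SM89 / RTX 40x0, 12.0 / RTX 50x0)
    "MARLIN_OTHER_ARCHS": ["7.5", "8.0+PTX"],      # repack and ops sources
}
for bucket, archs in marlin_arch_buckets.items():
    print(f"{bucket}: {archs}")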
|
||||
|
||||
if (MARLIN_ARCHS)
|
||||
if (MARLIN_OTHER_ARCHS)
|
||||
|
||||
#
|
||||
# For the Marlin kernels we automatically generate sources for various
|
||||
@ -384,7 +388,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=$PYTHONPATH
|
||||
PYTHONPATH=$ENV{PYTHONPATH}
|
||||
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
|
||||
RESULT_VARIABLE marlin_generation_result
|
||||
OUTPUT_VARIABLE marlin_generation_result
|
||||
@ -406,25 +410,39 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Marlin generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
if (MARLIN_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
|
||||
endif()
|
||||
|
||||
if (MARLIN_SM75_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_SM75_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_SM75_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_SM75_KERNEL_SRC})
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
|
||||
|
||||
if (MARLIN_FP8_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
|
||||
@ -446,14 +464,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_SRCS}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
CUDA_ARCHS "${MARLIN_OTHER_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
set_source_files_properties(${MARLIN_SRCS}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
|
||||
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_OTHER_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
@ -822,7 +840,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$ENV{PYTHONPATH}
|
||||
${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT}
|
||||
RESULT_VARIABLE machete_generation_result
|
||||
OUTPUT_VARIABLE machete_generation_output
|
||||
@ -980,12 +998,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# note that we always set `use_atomic_add=False` for moe marlin now,
|
||||
# so we don't need 9.0 for bf16 atomicAdd PTX
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
|
||||
# moe marlin has limited support for turing
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
|
||||
# moe marlin arches for fp8 input
|
||||
# - sm80 doesn't support fp8 computation
|
||||
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SASS instruction
|
||||
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
# moe marlin arches for other files
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_OTHER_ARCHS)
|
||||
|
||||
#
|
||||
# For the Marlin MOE kernels we automatically generate sources for various
|
||||
@ -1004,7 +1026,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=$PYTHONPATH
|
||||
PYTHONPATH=$ENV{PYTHONPATH}
|
||||
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
|
||||
RESULT_VARIABLE moe_marlin_generation_result
|
||||
OUTPUT_VARIABLE moe_marlin_generation_output
|
||||
@ -1026,16 +1048,29 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
|
||||
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
|
||||
endif()
|
||||
|
||||
if (MARLIN_MOE_SM75_ARCHS)
|
||||
file(GLOB MARLIN_MOE_SM75_SRC "csrc/moe/marlin_moe_wna16/sm75_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SM75_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_SM75_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SM75_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SM75_SRC})
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
|
||||
|
||||
if (MARLIN_MOE_FP8_ARCHS)
|
||||
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
|
||||
@ -1049,7 +1084,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
|
||||
set(MARLIN_MOE_OTHER_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_OTHER_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_OTHER_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_OTHER_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_OTHER_SRC}")
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_OTHER_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
|
||||
@ -143,11 +143,13 @@ Compute Resources:
|
||||
- Databricks
|
||||
- DeepInfra
|
||||
- Google Cloud
|
||||
- IBM
|
||||
- Intel
|
||||
- Lambda Lab
|
||||
- Nebius
|
||||
- Novita AI
|
||||
- NVIDIA
|
||||
- Red Hat
|
||||
- Replicate
|
||||
- Roblox
|
||||
- RunPod
|
||||
|
||||
@ -18,6 +18,11 @@ MIN_CACHE_HIT_PCT=${MIN_CACHE_HIT_PCT:-0}
|
||||
MAX_LATENCY_ALLOWED_MS=${MAX_LATENCY_ALLOWED_MS:-100000000000}
|
||||
NUM_SEQS_LIST=${NUM_SEQS_LIST:-"128 256"}
|
||||
NUM_BATCHED_TOKENS_LIST=${NUM_BATCHED_TOKENS_LIST:-"512 1024 2048 4096"}
|
||||
HOSTNAME=$(hostname)
|
||||
if [[ -z "$HOSTNAME" ]]; then
|
||||
echo "Error: Failed to determine hostname." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
LOG_FOLDER="$BASE/auto-benchmark/$TAG"
|
||||
RESULT="$LOG_FOLDER/result.txt"
|
||||
@ -82,6 +87,7 @@ start_server() {
|
||||
"$MODEL"
|
||||
"--disable-log-requests"
|
||||
"--port" "8004"
|
||||
"--host" "$HOSTNAME"
|
||||
"--gpu-memory-utilization" "$gpu_memory_utilization"
|
||||
"--max-num-seqs" "$max_num_seqs"
|
||||
"--max-num-batched-tokens" "$max_num_batched_tokens"
|
||||
@ -113,7 +119,7 @@ start_server() {
|
||||
# since we should always have permission to send a signal to the server process.
|
||||
kill -0 $server_pid 2> /dev/null || break
|
||||
|
||||
RESPONSE=$(curl -s -X GET "http://0.0.0.0:8004/health" -w "%{http_code}" -o /dev/stdout)
|
||||
RESPONSE=$(curl -s -X GET "http://${HOSTNAME}:8004/health" -w "%{http_code}" -o /dev/stdout)
|
||||
STATUS_CODE=$(echo "$RESPONSE" | tail -n 1)
|
||||
if [[ "$STATUS_CODE" -eq 200 ]]; then
|
||||
server_started=1
|
||||
@ -173,6 +179,7 @@ run_benchmark() {
|
||||
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||
--num-prompts 1000 \
|
||||
--random-prefix-len $prefix_len \
|
||||
--host "$HOSTNAME" \
|
||||
--port 8004 &> "$bm_log"
|
||||
throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
|
||||
@ -188,7 +195,7 @@ run_benchmark() {
|
||||
request_rate=$((${throughput%.*} + 1))
|
||||
while ((request_rate > 0)); do
|
||||
# clear prefix cache
|
||||
curl -X POST http://0.0.0.0:8004/reset_prefix_cache
|
||||
curl -X POST http://${HOSTNAME}:8004/reset_prefix_cache
|
||||
sleep 5
|
||||
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt"
|
||||
vllm bench serve \
|
||||
@ -204,6 +211,7 @@ run_benchmark() {
|
||||
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||
--num-prompts 100 \
|
||||
--random-prefix-len $prefix_len \
|
||||
--host "$HOSTNAME" \
|
||||
--port 8004 &> "$bm_log"
|
||||
throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
|
||||
@ -304,6 +312,7 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then
|
||||
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||
--num-prompts 100 \
|
||||
--random-prefix-len $prefix_len \
|
||||
--host "$HOSTNAME" \
|
||||
--port 8004 \
|
||||
--profile &> "$bm_log"
|
||||
else
|
||||
|
||||
@ -620,7 +620,7 @@ def get_tokenizer(
|
||||
kwargs["use_fast"] = False
|
||||
if tokenizer_mode == "mistral":
|
||||
try:
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"MistralTokenizer requires vllm package.\n"
|
||||
|
||||
@ -32,12 +32,11 @@ def benchmark_propose(args):
|
||||
|
||||
model_config = ModelConfig(
|
||||
model="facebook/opt-125m",
|
||||
task="generate",
|
||||
max_model_len=args.num_token + args.num_spec_token,
|
||||
tokenizer="facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
dtype="auto",
|
||||
seed=None,
|
||||
seed=0,
|
||||
trust_remote_code=False,
|
||||
)
|
||||
proposer = NgramProposer(
|
||||
|
||||
@ -574,7 +574,7 @@ async def benchmark(
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Total Token throughput (tok/s):", metrics.total_token_throughput
|
||||
"Total token throughput (tok/s):", metrics.total_token_throughput
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -13,8 +13,8 @@ from vllm.triton_utils import triton
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
batch_size_range = [1, 16, 32, 64, 128]
|
||||
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
batch_size_range = [1, 16, 128]
|
||||
seq_len_range = [1, 16, 64, 1024, 4096]
|
||||
intermediate_size = [3072, 9728, 12288]
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))
|
||||
|
||||
|
||||
benchmarks/kernels/benchmark_mla_k_concat.py (new file, 150 lines)
@ -0,0 +1,150 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation
|
||||
in MLA (Multi-head Latent Attention) prefill.
|
||||
|
||||
This validates that the optimization from commit 8d4142bd is beneficial across
|
||||
various batch sizes, not just the originally tested batch size of 32768.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
# DeepSeek-V3 MLA dimensions
|
||||
NUM_HEADS = 128
|
||||
QK_NOPE_HEAD_DIM = 128
|
||||
PE_DIM = 64
|
||||
|
||||
|
||||
def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
|
||||
"""Original torch.cat approach with expand."""
|
||||
return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
|
||||
|
||||
def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
|
||||
"""Optimized direct copy approach (avoids expand + cat overhead)."""
|
||||
k = torch.empty(
|
||||
(*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
|
||||
dtype=k_nope.dtype,
|
||||
device=k_nope.device,
|
||||
)
|
||||
k[..., : k_nope.shape[-1]] = k_nope
|
||||
k[..., k_nope.shape[-1] :] = k_pe
|
||||
return k
|
||||
|
||||
|
||||
def benchmark_method(
|
||||
method: Callable,
|
||||
k_nope: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
num_warmup: int = 10,
|
||||
num_iters: int = 100,
|
||||
) -> float:
|
||||
"""Benchmark a concatenation method and return mean latency in ms."""
|
||||
# Warmup
|
||||
for _ in range(num_warmup):
|
||||
_ = method(k_nope, k_pe)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
start = time.perf_counter()
|
||||
for _ in range(num_iters):
|
||||
_ = method(k_nope, k_pe)
|
||||
torch.cuda.synchronize()
|
||||
end = time.perf_counter()
|
||||
|
||||
return (end - start) / num_iters * 1000 # Convert to ms
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_benchmark(dtype: torch.dtype, dtype_name: str):
|
||||
"""Run benchmark for a specific dtype."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
# Batch sizes to test (powers of 2 from 32 to 65536)
|
||||
batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
|
||||
|
||||
print("=" * 80)
|
||||
print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation")
|
||||
print("=" * 80)
|
||||
print(
|
||||
f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], "
|
||||
f"k_pe=[B, 1, {PE_DIM}]"
|
||||
)
|
||||
print(f"dtype: {dtype_name}")
|
||||
print()
|
||||
print(
|
||||
f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | "
|
||||
f"{'Speedup':>8} | {'Reduction':>10}"
|
||||
)
|
||||
print("-" * 70)
|
||||
|
||||
results = []
|
||||
for batch_size in batch_sizes:
|
||||
# Create input tensors (generate in float32 then convert for FP8 compatibility)
|
||||
k_nope = torch.randn(
|
||||
batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda"
|
||||
).to(dtype)
|
||||
k_pe = torch.randn(
|
||||
batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda"
|
||||
).to(dtype)
|
||||
|
||||
# Benchmark both methods
|
||||
cat_time = benchmark_method(cat_method, k_nope, k_pe)
|
||||
direct_time = benchmark_method(direct_copy_method, k_nope, k_pe)
|
||||
|
||||
speedup = cat_time / direct_time
|
||||
reduction = (1 - direct_time / cat_time) * 100
|
||||
|
||||
results.append((batch_size, cat_time, direct_time, speedup, reduction))
|
||||
|
||||
print(
|
||||
f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | "
|
||||
f"{speedup:>7.2f}x | {reduction:>9.1f}%"
|
||||
)
|
||||
|
||||
print("=" * 80)
|
||||
|
||||
# Summary statistics
|
||||
speedups = [r[3] for r in results]
|
||||
print("\nSpeedup summary:")
|
||||
print(f" Min: {min(speedups):.2f}x")
|
||||
print(f" Max: {max(speedups):.2f}x")
|
||||
print(f" Mean: {sum(speedups) / len(speedups):.2f}x")
|
||||
|
||||
# Find crossover point
|
||||
crossover_batch = None
|
||||
for batch_size, _, _, speedup, _ in results:
|
||||
if speedup >= 1.0:
|
||||
crossover_batch = batch_size
|
||||
break
|
||||
|
||||
print("\nConclusion:")
|
||||
if crossover_batch:
|
||||
print(f" - Direct copy becomes beneficial at batch size >= {crossover_batch}")
|
||||
# Filter for large batches (>= 512 which is typical for prefill)
|
||||
large_batch_speedups = [r[3] for r in results if r[0] >= 512]
|
||||
if large_batch_speedups:
|
||||
avg_large = sum(large_batch_speedups) / len(large_batch_speedups)
|
||||
print(f" - For batch sizes >= 512: avg speedup = {avg_large:.2f}x")
|
||||
print(" - MLA prefill typically uses large batches, so optimization is effective")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
# Test bfloat16
|
||||
print("\n")
|
||||
run_benchmark(torch.bfloat16, "bfloat16")
|
||||
|
||||
# Test float8_e4m3fn
|
||||
print("\n")
|
||||
run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -99,7 +99,6 @@ def benchmark_mrope(
|
||||
# the parameters to compute the q k v size based on tp_size
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=rope_parameters,
|
||||
|
||||
@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
|
||||
def benchmark(batch_size, seq_len, num_heads, provider):
|
||||
dtype = torch.bfloat16
|
||||
max_position = 8192
|
||||
base = 10000
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
|
||||
rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
|
||||
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
|
||||
rope = rope.to(dtype=dtype, device=device)
|
||||
cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
|
||||
|
||||
|
||||
@ -251,17 +251,6 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
|
||||
endif()
|
||||
|
||||
# Build ACL with CMake
|
||||
set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF")
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
set(ARM_COMPUTE_ARCH "armv8.2-a")
|
||||
set(ARM_COMPUTE_ENABLE_ASSERTS "OFF")
|
||||
set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF")
|
||||
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
|
||||
set(ARM_COMPUTE_ENABLE_OPENMP "ON")
|
||||
set(ARM_COMPUTE_ENABLE_WERROR "OFF")
|
||||
set(ARM_COMPUTE_BUILD_EXAMPLES "OFF")
|
||||
set(ARM_COMPUTE_BUILD_TESTING "OFF")
|
||||
|
||||
set(_cmake_config_cmd
|
||||
${CMAKE_COMMAND} -G Ninja -B build
|
||||
-DARM_COMPUTE_BUILD_SHARED_LIB=OFF
|
||||
|
||||
@ -35,16 +35,21 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
|
||||
# sm90a
|
||||
|
||||
set(SUPPORT_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
|
||||
list(APPEND SUPPORT_ARCHS 9.0a)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND SUPPORT_ARCHS "9.0a")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
|
||||
list(APPEND SUPPORT_ARCHS 10.0a)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9)
|
||||
# CUDA 12.9 has introduced "Family-Specific Architecture Features"
|
||||
# this supports the whole compute_10x family
|
||||
list(APPEND SUPPORT_ARCHS "10.0f")
|
||||
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
list(APPEND SUPPORT_ARCHS "10.0a")
|
||||
endif()
|
||||
|
||||
|
||||
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
|
||||
if(FLASH_MLA_ARCHS)
|
||||
message(STATUS "FlashMLA CUDA architectures: ${FLASH_MLA_ARCHS}")
|
||||
set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
|
||||
list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
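A minimal Python sketch of the compiler-version cut-offs used above to build the FlashMLA architecture list; the thresholds mirror the CMake checks, everything else is illustrative.
def flashmla_support_archs(nvcc: tuple[int, int]) -> list[str]:
    archs = []
    if nvcc >= (12, 3):
        archs.append("9.0a")
    if nvcc >= (12, 9):
        archs.append("10.0f")   # family-specific feature set, covers all compute_10x
    elif nvcc >= (12, 8):
        archs.append("10.0a")
    return archs

print(flashmla_support_archs((12, 8)))  # ['9.0a', '10.0a']
print(flashmla_support_archs((12, 9)))  # ['9.0a', '10.0f']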
|
||||
|
||||
@ -126,7 +131,8 @@ if(FLASH_MLA_ARCHS)
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
|
||||
else()
|
||||
# Create empty targets for setup.py when not targeting sm90a systems
|
||||
message(STATUS "FlashMLA will not compile: unsupported CUDA architecture ${CUDA_ARCHS}")
|
||||
# Create empty targets for setup.py on unsupported systems
|
||||
add_custom_target(_flashmla_C)
|
||||
add_custom_target(_flashmla_extension_C)
|
||||
endif()
|
||||
|
||||
@ -140,16 +140,21 @@ function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR)
|
||||
run_python(_VLLM_TORCH_GOMP_PATH
|
||||
"
|
||||
import os, glob
|
||||
try:
|
||||
import torch
|
||||
torch_pkg = os.path.dirname(torch.__file__)
|
||||
site_root = os.path.dirname(torch_pkg)
|
||||
torch_libs = os.path.join(site_root, 'torch.libs')
|
||||
print(glob.glob(os.path.join(torch_libs, 'libgomp-*.so*'))[0])
|
||||
except:
|
||||
print('')
|
||||
import torch
|
||||
torch_pkg = os.path.dirname(torch.__file__)
|
||||
site_root = os.path.dirname(torch_pkg)
|
||||
|
||||
# Search both torch.libs and torch/lib
|
||||
roots = [os.path.join(site_root, 'torch.libs'), os.path.join(torch_pkg, 'lib')]
|
||||
candidates = []
|
||||
for root in roots:
|
||||
if not os.path.isdir(root):
|
||||
continue
|
||||
candidates.extend(glob.glob(os.path.join(root, 'libgomp*.so*')))
|
||||
|
||||
print(candidates[0] if candidates else '')
|
||||
"
|
||||
"failed to probe torch.libs for libgomp")
|
||||
"failed to probe for libgomp")
|
||||
|
||||
if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}")
|
||||
return()
|
||||
|
||||
@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
|
||||
const scalar_t& y) {
|
||||
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
|
||||
}
|
||||
// Activation and gating kernel template.
|
||||
|
||||
// Check if all pointers are 16-byte aligned for int4 vectorized access
|
||||
__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
|
||||
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
|
||||
}
|
||||
|
||||
// Activation and gating kernel template.
|
||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
||||
bool act_first>
|
||||
__global__ void act_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const int d) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
|
||||
const scalar_t* x_ptr = input + token_idx * 2 * d;
|
||||
const scalar_t* y_ptr = x_ptr + d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access.
|
||||
// All three pointers must be 16-byte aligned for safe int4 operations.
|
||||
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
|
||||
is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
|
||||
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
|
||||
auto* xp = reinterpret_cast<scalar_t*>(&x);
|
||||
auto* yp = reinterpret_cast<scalar_t*>(&y);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
|
||||
VLLM_LDG(&y_ptr[i]));
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
|
||||
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
|
||||
out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
|
||||
}
|
||||
}
|
||||
}
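A short sketch of the vector-width and alignment arithmetic assumed by the aligned fast path above: VEC_SIZE is the number of scalar_t elements covered by one 16-byte int4 load, and a pointer is usable only if its low 4 address bits are zero. The dtype sizes are the usual CUDA ones; the sample address is arbitrary.
def is_16byte_aligned(addr: int) -> bool:
    # mirrors the device-side check: low 4 address bits must be zero
    return (addr & 15) == 0

for dtype, nbytes in [("half/bfloat16", 2), ("float", 4)]:
    vec_size = 16 // nbytes  # scalar_t elements per int4 (128-bit) load
    print(f"{dtype}: VEC_SIZE = {vec_size}, scalar cleanup handles d % {vec_size} leftovers")

print(is_16byte_aligned(0x7F0000000010))  # True: address is a multiple of 16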
|
||||
|
||||
@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
|
||||
__global__ void act_and_mul_kernel_with_param(
|
||||
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
|
||||
const float param) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x, param) * y;
|
||||
const scalar_t* x_ptr = input + token_idx * 2 * d;
|
||||
const scalar_t* y_ptr = x_ptr + d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access
|
||||
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
|
||||
is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
|
||||
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
|
||||
auto* xp = reinterpret_cast<scalar_t*>(&x);
|
||||
auto* yp = reinterpret_cast<scalar_t*>(&y);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = ACT_FN(xp[j], param) * yp[j];
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
|
||||
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
|
||||
out_ptr[idx] = ACT_FN(x, param) * y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
|
||||
float alpha, float limit) {
|
||||
// clamp gate: min=None, max=limit
|
||||
const float gate_f = (float)gate;
|
||||
const float clamped_gate = gate_f > limit ? limit : gate_f;
|
||||
|
||||
// clamp up: min=-limit, max=limit
|
||||
const float up_f = (float)up;
|
||||
const float clamped_up =
|
||||
up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
|
||||
|
||||
// glu = gate * sigmoid(gate * alpha)
|
||||
const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
|
||||
const float glu = clamped_gate * sigmoid_val;
|
||||
|
||||
// (up + 1) * glu
|
||||
return (T)((clamped_up + 1.0f) * glu);
|
||||
// Clamp gate to (-inf, limit] and up to [-limit, limit]
|
||||
const float g = fminf((float)gate, limit);
|
||||
const float u = fmaxf(fminf((float)up, limit), -limit);
|
||||
// glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
|
||||
return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha)));
|
||||
}
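A plain-Python reference of swigluoai_and_mul as rewritten above, useful for checking the math outside the kernel; the alpha and limit values in the example call are illustrative only.
import math

def swigluoai_and_mul(gate: float, up: float, alpha: float, limit: float) -> float:
    g = min(gate, limit)                    # gate clamped to (-inf, limit]
    u = max(min(up, limit), -limit)         # up clamped to [-limit, limit]
    glu = g / (1.0 + math.exp(-g * alpha))  # gate * sigmoid(gate * alpha)
    return (u + 1.0) * glu

print(swigluoai_and_mul(1.5, 0.5, alpha=1.702, limit=7.0))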
|
||||
|
||||
// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
|
||||
template <typename scalar_t,
|
||||
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
|
||||
const float)>
|
||||
__global__ void swigluoai_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const scalar_t* __restrict__ input, // [..., 2 * d] (interleaved)
|
||||
const int d, const float alpha, const float limit) {
|
||||
// For interleaved data: input has 2*d elements per token (gate/up pairs)
|
||||
// output has d elements per token
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
constexpr int PAIRS = VEC_SIZE / 2; // Number of gate/up pairs per int4 load
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
// TODO: Vectorize loads and stores.
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
// gate = x[..., ::2] (even indices)
|
||||
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
|
||||
// up = x[..., 1::2] (odd indices)
|
||||
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
|
||||
const scalar_t* in_ptr = input + token_idx * 2 * d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
|
||||
// Check alignment for 128-bit vectorized access on input.
|
||||
// For output we use int2 (64-bit), which has an 8-byte alignment requirement.
|
||||
const bool in_aligned = is_16byte_aligned(in_ptr);
|
||||
const bool out_aligned =
|
||||
(reinterpret_cast<uintptr_t>(out_ptr) & 7) == 0; // 8-byte for int2
|
||||
|
||||
if (in_aligned && out_aligned && d >= PAIRS) {
|
||||
// Fast path: vectorized loop
|
||||
// Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs
|
||||
// Each int2 store writes PAIRS output elements
|
||||
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
|
||||
int2* out_vec = reinterpret_cast<int2*>(out_ptr);
|
||||
const int num_vecs = d / PAIRS;
|
||||
const int vec_end = num_vecs * PAIRS;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 v = VLLM_LDG(&in_vec[i]);
|
||||
int2 r;
|
||||
auto* vp = reinterpret_cast<scalar_t*>(&v);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < PAIRS; j++) {
|
||||
rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]),
|
||||
VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit);
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
// gate = x[..., ::2] (even indices)
|
||||
const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]);
|
||||
// up = x[..., 1::2] (odd indices)
|
||||
const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]);
|
||||
out_ptr[idx] = ACT_FN(gate, up, alpha, limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,10 +324,41 @@ __global__ void activation_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., d]
|
||||
const int d) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x);
|
||||
const scalar_t* in_ptr = input + token_idx * d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access
|
||||
const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 v = VLLM_LDG(&in_vec[i]), r;
|
||||
auto* vp = reinterpret_cast<scalar_t*>(&v);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = ACT_FN(vp[j]);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&in_ptr[idx]);
|
||||
out_ptr[idx] = ACT_FN(x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
csrc/cache.h (12 lines)
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
@ -58,6 +59,15 @@ void cp_gather_cache(
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
|
||||
|
||||
// Gather and upconvert FP8 KV cache to BF16 workspace
|
||||
void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, 576]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& seq_lens, // [BATCH]
|
||||
torch::Tensor const& workspace_starts, // [BATCH]
|
||||
int64_t batch_size);
|
||||
|
||||
// Indexer K quantization and cache function
|
||||
void indexer_k_quant_and_cache(
|
||||
torch::Tensor& k, // [num_tokens, head_dim]
|
||||
@ -72,4 +82,4 @@ void cp_gather_indexer_k_quant_cache(
|
||||
torch::Tensor& dst_k, // [num_tokens, head_dim]
|
||||
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
|
||||
const torch::Tensor& block_table, // [batch_size, num_blocks]
|
||||
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
|
||||
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
|
||||
|
||||
@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/Optional.h>

#include "cuda_utils.h"
#include "cuda_compat.h"
@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache
const bool use_ue8m0 // use ue8m0 scale format

const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
@ -1061,6 +1063,82 @@ void gather_and_maybe_dequant_cache(
}

namespace vllm {

// Gather and upconvert FP8 KV cache tokens to BF16 workspace
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ seq_lens, // [BATCH]
const int32_t* __restrict__ workspace_starts, // [BATCH]
const int32_t block_size, const int32_t head_dim,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = workspace_starts[bid];
const int32_t seq_len = seq_lens[bid];
const int32_t tot_slots = seq_len;
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);

const int32_t split_start = split * split_slots;
const int32_t split_end = min((split + 1) * split_slots, tot_slots);

const bool is_active_split = (split_start < tot_slots);

if (!is_active_split) return;

// Adjust the pointer for the block_table for this batch
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = split_start;
int32_t offset_div = offset / block_size;
offset = offset % block_size;
const int32_t* batch_block_table = block_table + batch_offset;

// Adjust dst pointer based on the cumulative sequence lengths
dst += seq_start * dst_entry_stride;

const int tid = threadIdx.x;

// Process each token in this split
for (int pid = split_start; pid < split_end; ++pid) {
auto block_id = batch_block_table[offset_div];
const uint8_t* token_ptr =
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;

// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
const uint8_t* no_pe_ptr = token_ptr;
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const __nv_bfloat16* rope_ptr =
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);

// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
if (tid < 512) {
// FP8 dequantization
const int tile = tid >> 7; // each tile is 128 elements
const float scale = scales_ptr[tile];
const uint8_t val = no_pe_ptr[tid];
dst_ptr[tid] =
fp8::scaled_convert<__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
} else if (tid < 576) {
// Rope copy (64 bf16 elements)
const int rope_idx = tid - 512;
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
}

// Move to next token
offset += 1;
if (offset == block_size) {
offset_div += 1;
offset = 0;
}
}
}

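For reference, the per-token layout this kernel walks is: 512 FP8 values, then four `float` scales (one per 128-element tile), then 64 BF16 rope values, i.e. 512 + 16 + 128 = 656 bytes in and 576 BF16 values out. A host-side sketch of that expansion (illustrative; `fp8_to_float` is a placeholder for the real e4m3 decode, and plain `float`/`uint16_t` stand in for the device types):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

static float fp8_to_float(std::uint8_t v) { return static_cast<float>(v); }  // placeholder decode

void upconvert_token(const std::uint8_t* token, float* dst /* 576 values */) {
  const std::uint8_t* fp8_part = token;                      // bytes [0, 512)
  float scales[4];
  std::memcpy(scales, token + 512, sizeof(scales));          // bytes [512, 528)
  const std::uint16_t* rope =
      reinterpret_cast<const std::uint16_t*>(token + 528);   // 64 values, 128 bytes

  for (int i = 0; i < 512; i++) {
    const int tile = i >> 7;                                 // 128 elements per scale tile
    dst[i] = fp8_to_float(fp8_part[i]) * scales[tile];
  }
  for (int i = 0; i < 64; i++) {
    dst[512 + i] = static_cast<float>(rope[i]);              // stand-in for the bf16 copy
  }
}

int main() {
  std::vector<std::uint8_t> token(656, 1);                   // dummy 656-byte token record
  float scales[4] = {1.f, 2.f, 3.f, 4.f};
  std::memcpy(token.data() + 512, scales, sizeof(scales));
  std::vector<float> dst(576);
  upconvert_token(token.data(), dst.data());
  return dst[0] == 1.f ? 0 : 1;                              // fp8 byte 1 * scale 1
}
```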

template <typename scalar_t>
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
// block_size.
@ -1202,6 +1280,57 @@ void cp_gather_cache(
}
}

void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int32_t block_size = src_cache.size(1);
int32_t head_dim = dst.size(1);

TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
"workspace_starts must be int32");

TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == seq_lens.device(),
"src_cache and seq_lens must be on the same device");
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
"src_cache and workspace_starts must be on the same device");

TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");

int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);

// Decide on the number of splits based on the batch size
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(576);

vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
src_cache.data_ptr<uint8_t>(),
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
block_table_stride, cache_block_stride, cache_entry_stride,
dst_entry_stride);
}

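The launch shape above pairs one block per (sequence, split) with 576 threads, so each thread writes exactly one of the 576 output values per token (512 dequantized plus 64 rope), and the split count shrinks as the batch grows so the total number of blocks stays roughly comparable. A small stand-alone sketch of that heuristic (illustrative only, not a vLLM API):

```cpp
#include <cstdio>

int num_splits_for(long batch_size) {
  return batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;  // fewer splits for large batches
}

int main() {
  for (long b : {8L, 96L, 256L}) {
    std::printf("batch=%ld -> grid=(%ld, %d), block=576\n", b, b, num_splits_for(b));
  }
}
```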
// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \

@ -117,7 +117,6 @@ torch::Tensor get_scheduler_metadata(
input.casual = casual;
input.isa = isa;
input.enable_kv_split = enable_kv_split;
TORCH_CHECK(casual, "Only supports casual mask for now.");

VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {

@ -186,7 +186,7 @@ struct AttentionMetadata {
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2
// * q_tile_size * 4, partial output, max + sum (float)
// Reduction scratchpad contains:
// - flags: bool array to indicate wether the split is finished
// - flags: bool array to indicate whether the split is finished
// - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size
// - max, sum: 2 * split_num * q_tile_size * 4
class AttentionScratchPad {

@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) {

template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) {
if constexpr (SF == SCORING_SIGMOID) {
if constexpr (SF == SCORING_NONE) {
return val;
} else if constexpr (SF == SCORING_SIGMOID) {
return apply_sigmoid(val);
} else {
static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
"Unsupported ScoringFunc in apply_scoring");
return val;
}
}

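The refactor above adds an explicit `SCORING_NONE` branch and a `static_assert` fallback, so an unexpected `ScoringFunc` fails at compile time. A stand-alone C++ sketch of the same `if constexpr` dispatch (plain `float`, illustrative only):

```cpp
#include <cassert>
#include <cmath>

enum ScoringFunc { SCORING_NONE, SCORING_SIGMOID };

template <ScoringFunc SF>
float apply_scoring(float val) {
  if constexpr (SF == SCORING_NONE) {
    return val;                              // identity
  } else if constexpr (SF == SCORING_SIGMOID) {
    return 1.0f / (1.0f + std::exp(-val));   // sigmoid
  } else {
    static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
                  "Unsupported ScoringFunc");
    return val;
  }
}

int main() {
  assert(apply_scoring<SCORING_NONE>(0.3f) == 0.3f);
  assert(std::fabs(apply_scoring<SCORING_SIGMOID>(0.0f) - 0.5f) < 1e-6f);
}
```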
@ -481,8 +485,6 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
largest = value;
}
}

__syncwarp(); // Ensure all threads have valid data before reduction
// Get the top2 warpwise
T max1 = cg::reduce(tile, largest, cg::greater<T>());

@ -589,7 +591,6 @@ __global__ void group_idx_and_topk_idx_kernel(
int pre_count_equal_to_top_value = 0;
// Use loop to find the largset top_group
while (count_equal_to_top_value < target_num_min) {
__syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) {
value = neg_inf<T>();
@ -644,10 +645,8 @@ __global__ void group_idx_and_topk_idx_kernel(
}
}
queue.done();
__syncwarp();
// Get the topk_idx
queue.dumpIdx(s_topk_idx);
__syncwarp();
}

// Load the valid score value
@ -675,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel(

if (case_id < num_tokens) {
if (if_proceed_next_topk) {
float scale = routed_scaling_factor;
if (renormalize) {
scale /= topk_sum;
}
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float base = cuda_cast<float, T>(s_topk_value[i]);
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
: (base * routed_scaling_factor);
float value = base * scale;
topk_indices[i] = s_topk_idx[i];
topk_values[i] = value;
}

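The change above hoists the routing scale out of the per-expert loop: `scale` starts as `routed_scaling_factor` and is divided once by `topk_sum` when renormalizing, instead of evaluating the ternary for every expert. A quick stand-alone check (illustrative) that the hoisted form matches the original expression up to floating-point rounding:

```cpp
#include <cassert>
#include <cmath>

int main() {
  const float routed_scaling_factor = 2.5f, topk_sum = 3.0f;
  const float base[4] = {0.1f, 0.7f, 1.3f, 0.9f};
  for (bool renormalize : {false, true}) {
    float scale = routed_scaling_factor;
    if (renormalize) scale /= topk_sum;             // hoisted once per token
    for (float b : base) {
      float old_value = renormalize ? (b / topk_sum * routed_scaling_factor)
                                    : (b * routed_scaling_factor);
      assert(std::fabs(b * scale - old_value) < 1e-6f);
    }
  }
  return 0;
}
```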
1  csrc/moe/marlin_moe_wna16/.gitignore  vendored
@ -1,2 +1,3 @@
sm*_kernel_*.cu
kernel_selector.h
kernel_*.cu

@ -10,6 +10,8 @@ import jinja2
|
||||
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
SUPPORT_SM75 = False
|
||||
SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
if arch == 75:
|
||||
SUPPORT_SM75 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
@ -157,6 +163,7 @@ def remove_old_kernels():
|
||||
|
||||
def generate_new_kernels():
|
||||
result_dict = {}
|
||||
sm_75_result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
@ -174,6 +181,8 @@ def generate_new_kernels():
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
|
||||
sm_75_result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
@ -197,78 +206,89 @@ def generate_new_kernels():
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"stages": 4,
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if SUPPORT_SM80:
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
|
||||
config_sm75 = config.copy()
|
||||
config_sm75["stages"] = 2
|
||||
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
for result_dict_tmp in [result_dict, sm_75_result_dict]:
|
||||
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
|
||||
all_template_str_list = []
|
||||
if not config_list:
|
||||
continue
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"stages == {config['stages']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
filename = filename.lower()
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
elif result_dict_tmp is sm_75_result_dict:
|
||||
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
#include "quantization/gptq_marlin/marlin.cuh"
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "quantization/gptq_marlin/dequant.h"
|
||||
#include "quantization/gptq_marlin/marlin_mma.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@ -35,7 +36,7 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
|
||||
@ -84,146 +85,6 @@ __global__ void Marlin(
|
||||
|
||||
#else
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
template <int count, vllm::ScalarTypeId type_id>
|
||||
@ -439,9 +300,20 @@ __global__ void Marlin(
|
||||
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
// Turing TensorCore only supports fp16 and int8
|
||||
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
|
||||
return;
|
||||
#endif
|
||||
|
||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
using Adtype = MarlinScalarType<a_type_id>;
|
||||
using Cdtype = MarlinScalarType<c_type_id>;
|
||||
|
||||
@ -618,7 +490,22 @@ __global__ void Marlin(
}
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750

if constexpr (moe_block_size >= 16)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16);
if constexpr (moe_block_size >= 8)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8);
if constexpr (moe_block_size >= 4)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4);
if constexpr (moe_block_size >= 2)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2);

local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1);
block_num_valid_tokens = local_count;
#else
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
#endif

if (lane_id == 0)
reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
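`__reduce_add_sync` requires SM80 or newer, so on SM75 the valid-token count is summed with a `__shfl_down_sync` tree instead, with the `moe_block_size` guards skipping shuffle offsets larger than the block size. A host-side sketch of the reduction pattern (illustrative; it emulates a full 32-lane warp, with out-of-range partners contributing zero):

```cpp
#include <cassert>

int warp_sum(const int lane_vals_in[32]) {
  int v[32];
  for (int i = 0; i < 32; i++) v[i] = lane_vals_in[i];
  for (int offset = 16; offset >= 1; offset >>= 1) {
    for (int lane = 0; lane < 32; lane++) {
      const int src = lane + offset;          // __shfl_down_sync partner
      v[lane] += (src < 32) ? v[src] : 0;     // out-of-range lanes contribute 0 here
    }
  }
  return v[0];                                // lane 0 ends up with the full sum
}

int main() {
  int vals[32];
  for (int i = 0; i < 32; i++) vals[i] = i + 1;
  assert(warp_sum(vals) == 32 * 33 / 2);
}
```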
@ -1018,10 +905,6 @@ __global__ void Marlin(
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
// shared memory reused by reduction should be smaller than
|
||||
// shared memory used by weight.
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
@ -1545,11 +1428,13 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
} else {
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
|
||||
frag_c[i][j][0]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
|
||||
frag_c[i][j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1583,10 +1468,12 @@ __global__ void Marlin(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
}
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
@ -2132,6 +2019,21 @@ __global__ void Marlin(
|
||||
// While this pattern may not be the most readable, other ways of writing
|
||||
// the loop seemed to noticeably worse performance after compilation.
|
||||
if (slice_iters == 0) {
|
||||
// convert fp16 accum to fp32 for reduction
|
||||
if constexpr (use_fp16_accum) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
|
||||
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
|
||||
scalar_t* frag_c_part_half =
|
||||
reinterpret_cast<scalar_t*>(frag_c_part_float);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_a_8bit) {
|
||||
float frag_a_s[2 * thread_m_blocks];
|
||||
|
||||
|
||||
@ -142,7 +142,7 @@ typedef struct {
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool has_act_order, bool is_k_full, int stages) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
@ -160,13 +160,13 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
return tb_scales * stages;
|
||||
}
|
||||
}
|
||||
|
||||
@ -174,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n,
|
||||
int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full, int has_zp,
|
||||
int is_zp_float, bool is_a_8bit) {
|
||||
int is_zp_float, bool is_a_8bit, int stages) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
@ -185,8 +185,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
|
||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
int sh_block_meta_size = tb_m * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
@ -195,8 +195,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
group_size, has_act_order, is_k_full, stages);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
@ -217,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
|
||||
int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, int has_zp, int is_zp_float,
|
||||
int max_shared_mem, bool is_a_8bit) {
|
||||
bool is_a_8bit, int stages, int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
@ -243,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int cache_size =
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
return cache_size <= max_shared_mem;
}

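Threading `stages` through `get_scales_cache_size`, `get_kernel_cache_size`, and `is_valid_config` lets the shared-memory estimate follow the pipeline depth, since the A and B tiles are buffered once per stage; with 2 stages on SM75 the dominant terms roughly halve. A rough stand-alone sketch of that scaling (tile sizes are illustrative, not the exact vLLM accounting):

```cpp
#include <cstdio>

int tile_cache_bytes(int tb_m, int tb_n, int tb_k, int num_bits,
                     bool is_a_8bit, int stages) {
  const int pack_factor = 32 / num_bits;                        // weights packed into int32
  const int sh_a = stages * tb_m * tb_k * (is_a_8bit ? 1 : 2);  // A tile per pipeline stage
  const int sh_b = stages * (tb_k * tb_n / pack_factor) * 4;    // packed B tile per stage
  return sh_a + sh_b;
}

int main() {
  // e.g. 4-bit weights, fp16 activations, 16x256x128 thread tile
  std::printf("4 stages: %d bytes, 2 stages: %d bytes\n",
              tile_cache_bytes(16, 256, 128, 4, false, 4),
              tile_cache_bytes(16, 256, 128, 4, false, 2));
}
```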
@ -252,7 +252,7 @@ MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int threads, bool is_zp_float, int stages) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
|
||||
@ -266,8 +266,8 @@ exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
|
||||
bool is_a_8bit) {
|
||||
bool is_k_full, bool has_zp, bool is_zp_float, bool is_a_8bit, int stages,
|
||||
int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@ -284,15 +284,15 @@ exec_config_t determine_exec_config(
|
||||
|
||||
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
|
||||
is_a_8bit)) {
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
is_a_8bit);
|
||||
is_a_8bit, stages);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
@ -303,7 +303,7 @@ exec_config_t determine_exec_config(
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
th_config.num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
@ -433,8 +433,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
"marlin kernel only support Turing or newer GPUs.");
int stages = 4;
if (major_capability == 7 && minor_capability == 5) {
stages = 2;
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
"Turing only support FP16 or INT8 activation.");
}
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
"FP8 only support Ada Lovelace or newer GPUs.");

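With Turing support added, the minimum compute capability drops from 8.0 to 7.5: SM75 devices get a 2-stage pipeline plus an FP16/INT8-only activation check, while newer GPUs keep the original 4 stages. A small sketch of that selection logic (illustrative only; not the vLLM API):

```cpp
#include <cstdio>
#include <stdexcept>

int pick_stages(int major, int minor, bool a_is_fp16_or_int8) {
  const int cc = major * 10 + minor;
  if (cc < 75) throw std::runtime_error("marlin requires Turing or newer");
  if (cc == 75) {
    if (!a_is_fp16_or_int8)
      throw std::runtime_error("Turing only supports FP16 or INT8 activation");
    return 2;  // smaller shared-memory budget on SM75
  }
  return 4;    // Ampere and newer keep the 4-stage pipeline
}

int main() {
  std::printf("SM75: %d stages, SM80: %d stages\n",
              pick_stages(7, 5, true), pick_stages(8, 0, true));
}
```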
@ -461,8 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
|
||||
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
|
||||
is_a_8bit);
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem, sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
}
|
||||
|
||||
@ -479,7 +485,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
|
||||
prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem, is_a_8bit),
|
||||
is_a_8bit, stages, max_shared_mem),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
@ -493,12 +499,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int sh_cache_size =
|
||||
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
@ -860,4 +866,4 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
|
||||
}
|
||||
}
|
||||
1
csrc/quantization/gptq_marlin/.gitignore
vendored
1
csrc/quantization/gptq_marlin/.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
kernel_*.cu
|
||||
|
||||
@ -67,7 +67,7 @@ where `scale_factor * multiplier` can be computed at weight loading.
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
|
||||
@ -10,6 +10,8 @@ import jinja2
|
||||
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
SUPPORT_SM75 = False
|
||||
SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
if arch == 75:
|
||||
SUPPORT_SM75 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
@ -166,6 +172,7 @@ def remove_old_kernels():
|
||||
|
||||
def generate_new_kernels():
|
||||
result_dict = {}
|
||||
sm_75_result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
@ -184,6 +191,8 @@ def generate_new_kernels():
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
|
||||
sm_75_result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
@ -207,78 +216,89 @@ def generate_new_kernels():
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"stages": 4,
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "true" if is_zp_float else "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if SUPPORT_SM80:
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
|
||||
config_sm75 = config.copy()
|
||||
config_sm75["stages"] = 2
|
||||
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
for result_dict_tmp in [result_dict, sm_75_result_dict]:
|
||||
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
|
||||
all_template_str_list = []
|
||||
if not config_list:
|
||||
continue
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"stages == {config['stages']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
filename = filename.lower()
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
elif result_dict_tmp is sm_75_result_dict:
|
||||
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
|
||||
@ -37,7 +37,7 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
|
||||
|
||||
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int const* __restrict__ perm_int_ptr,
|
||||
@ -148,7 +148,7 @@ typedef struct {
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool has_act_order, bool is_k_full, int stages) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
@ -166,28 +166,29 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
return tb_scales * stages;
|
||||
}
|
||||
}
|
||||
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float) {
|
||||
int has_zp, bool is_zp_float, bool is_a_8bit,
|
||||
int stages) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
@ -196,8 +197,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
group_size, has_act_order, is_k_full, stages);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
@ -217,7 +218,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float, int max_shared_mem) {
|
||||
int has_zp, bool is_zp_float, bool is_a_8bit, int stages,
|
||||
int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
@ -242,7 +244,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
// Check that pipeline fits into cache
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
@ -251,7 +253,7 @@ MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int threads, bool is_zp_float, int stages) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
|
||||
@ -265,7 +267,8 @@ exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
|
||||
int num_bits, int group_size, bool has_act_order, bool is_k_full,
|
||||
bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
|
||||
bool has_zp, bool is_zp_float, int is_a_8bit, int stages,
|
||||
int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@ -280,13 +283,15 @@ exec_config_t determine_exec_config(
|
||||
|
||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, max_shared_mem - 512)) {
|
||||
is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
int cache_size = get_kernel_cache_size(th_config, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, is_a_8bit, stages);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
@ -297,14 +302,10 @@ exec_config_t determine_exec_config(
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
th_config.num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
// int m_tiles = div_ceil(prob_m, thread_m_blocks * 16);
|
||||
// int n_tiles = prob_n / th_config.thread_n;
|
||||
// int k_tiles = prob_k / th_config.thread_k;
|
||||
|
||||
return {1, th_config};
|
||||
}
|
||||
|
||||
@ -321,6 +322,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k_init,
|
||||
int thread_n_init, int sms, bool use_atomic_add,
|
||||
bool use_fp32_reduce, bool is_zp_float) {
|
||||
bool is_a_8bit = a_type.size_bits() == 8;
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
@ -389,8 +391,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
|
||||
"marlin kernel only support Turing or newer GPUs.");
|
||||
int stages = 4;
|
||||
if (major_capability == 7 && minor_capability == 5) {
|
||||
stages = 2;
|
||||
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
|
||||
"Turing only support FP16 or INT8 activation.");
|
||||
}
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(
|
||||
major_capability * 10 + minor_capability == 89 ||
|
||||
@ -431,7 +439,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
|
||||
thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages, max_shared_mem,
|
||||
sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
if (thread_tfg.thread_n != -1) {
|
||||
if (prob_n / thread_tfg.thread_n *
|
||||
@ -440,7 +449,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem_new)) {
|
||||
is_a_8bit, stages, max_shared_mem_new)) {
|
||||
thread_tfg = {128, 64, 128};
|
||||
exec_cfg = {1, thread_tfg};
|
||||
}
|
||||
@ -466,7 +475,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
TORCH_CHECK(
|
||||
is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n,
|
||||
prob_k, num_bits, group_size, has_act_order, is_k_full,
|
||||
has_zp, is_zp_float, max_shared_mem_new),
|
||||
has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem_new),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
@ -475,12 +485,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
", prob_m_split = ", prob_m_split, ", group_size = ", group_size,
|
||||
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem_new = ", max_shared_mem_new);
|
||||
", stages = ", stages, ", max_shared_mem_new = ", max_shared_mem_new);
|
||||
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
|
||||
@ -1,17 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#ifndef _marlin_cuh
|
||||
#define _marlin_cuh
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
@ -51,9 +53,51 @@ using I4 = Vec<int, 4>;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int32_t*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int32_t*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int64_t*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int64_t*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
|
||||
__device__ inline void cp_async_fence() {}
|
||||
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {}
|
||||
|
||||
#else
|
||||
|
||||
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
@ -126,6 +170,8 @@ __device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
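The point of defining the cp_async helpers in both preprocessor branches is that kernel code can issue the same asynchronous-copy sequence on every architecture; below SM80 the calls simply degrade to ordinary synchronous loads and the fence/wait calls become no-ops. A minimal kernel sketch of that usage, assuming this header is included; the tile size of 128 and the names src/dst/smem_buf are illustrative assumptions, not values or names from the Marlin kernel:

  using namespace MARLIN_NAMESPACE_NAME;

  // Illustrative only: copy one int4 (16 bytes) per thread from global to shared
  // memory and back. cp_async4 is asynchronous on SM80+ and a plain load/store
  // in the fallback branch above.
  __global__ void copy_tile(const int4* __restrict__ src, int4* __restrict__ dst) {
    __shared__ int4 smem_buf[128];                         // hypothetical 2 KiB tile
    cp_async4(&smem_buf[threadIdx.x], &src[threadIdx.x]);  // issue the copy
    cp_async_fence();                                      // commit the copy group
    cp_async_wait<0>();                                    // wait for all outstanding groups
    __syncthreads();
    dst[threadIdx.x] = smem_buf[threadIdx.x];
  }

Launched as copy_tile<<<1, 128>>>(src, dst), the same source compiles unchanged whether or not cp.async hardware support is present.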
|
||||
csrc/quantization/gptq_marlin/marlin_mma.h (new file, 269 lines)
@ -0,0 +1,269 @@
|
||||
|
||||
#include "marlin_dtypes.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
|
||||
static_assert(!use_fp16_accum);
|
||||
}
|
||||
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
#else
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value &&
|
||||
use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
|
||||
#else
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[1]), "r"(b[0]), "r"(c[2]), "r"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[2]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[3]), "r"(b[1]), "r"(c[2]), "r"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
|
||||
static_assert(!use_fp16_accum);
|
||||
}
|
||||
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
#else
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value &&
|
||||
use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
|
||||
#else
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b2[1]), "r"(a[0]), "r"(c[2]), "r"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b2[1]), "r"(a[1]), "r"(c[2]), "r"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
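For readers who do not have the PTX shapes memorized: each mma.sync m16n8k16 instruction above multiplies a 16x16 tile of A by a 16x8 tile of B (B supplied column-major, hence the "row.col" suffix) and accumulates into a 16x8 tile of C, distributed across one warp; the m16n8k32 variants do the same with a K of 32 for the 8-bit operand paths. A plain scalar reference of one such instruction, purely illustrative (float stands in for the f16/bf16/e4m3/int8 operand types, and this function is not part of the kernel):

  // Scalar reference for mma.sync.aligned.m16n8k16.row.col:
  //   D (16x8) = A (16x16, row-major) * B (16x8, column-major) + C (16x8)
  void mma_m16n8k16_reference(const float* A, const float* B, const float* C,
                              float* D) {
    for (int m = 0; m < 16; ++m) {
      for (int n = 0; n < 8; ++n) {
        float acc = C[m * 8 + n];
        for (int k = 0; k < 16; ++k) {
          // B is column-major, so column n is contiguous
          acc += A[m * 16 + k] * B[n * 16 + k];
        }
        D[m * 8 + n] = acc;
      }
    }
  }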
|
||||
@ -26,6 +26,7 @@
|
||||
#include "marlin.cuh"
|
||||
#include "marlin_dtypes.cuh"
|
||||
#include "dequant.h"
|
||||
#include "marlin_mma.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@ -35,7 +36,7 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
|
||||
@ -75,137 +76,6 @@ __global__ void Marlin(
|
||||
|
||||
#else
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
template <int count, vllm::ScalarTypeId type_id>
|
||||
@ -415,6 +285,17 @@ __global__ void Marlin(
|
||||
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
// Turing TensorCore only supports fp16 and int8
|
||||
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
|
||||
return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
using Adtype = MarlinScalarType<a_type_id>;
|
||||
using Cdtype = MarlinScalarType<c_type_id>;
|
||||
const int4* A = A0;
|
||||
@ -873,10 +754,6 @@ __global__ void Marlin(
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
// shared memory reused by reduction should be smaller than
|
||||
// shared memory used by weight.
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
@ -1395,11 +1272,13 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
} else {
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
|
||||
frag_c[i][j][0]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
|
||||
frag_c[i][j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1433,10 +1312,12 @@ __global__ void Marlin(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
}
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
@ -1956,6 +1837,21 @@ __global__ void Marlin(
|
||||
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to give noticeably worse performance after compilation.
|
||||
if (slice_iters == 0) {
|
||||
// convert fp16 accum to fp32 for reduction
|
||||
if constexpr (use_fp16_accum) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
|
||||
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
|
||||
scalar_t* frag_c_part_half =
|
||||
reinterpret_cast<scalar_t*>(frag_c_part_float);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_a_8bit) {
|
||||
float frag_a_s[2 * thread_m_blocks];
|
||||
|
||||
|
||||
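One detail worth calling out in the fp16-accumulator conversion loop a few lines above: the values are widened from 16-bit to 32-bit inside the same register storage, so the inner loop runs from index 3 down to 0 to guarantee that each wide write lands past every narrow value that has not been read yet. A small host-side illustration of the same in-place widening pattern (the buffer layout is an assumption for illustration, not the kernel's fragment type, and plain integer-to-float conversion stands in for num2float):

  #include <cstdint>
  #include <cstring>

  // Widen four 16-bit values packed at the front of a 16-byte buffer into four
  // 32-bit floats filling the whole buffer, in place, back to front.
  void widen_in_place(float buf[4]) {
    for (int i = 3; i >= 0; --i) {
      std::uint16_t narrow;
      std::memcpy(&narrow, reinterpret_cast<const char*>(buf) + i * sizeof(narrow),
                  sizeof(narrow));
      buf[i] = static_cast<float>(narrow);  // the wide write never clobbers an unread narrow value
    }
  }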
@ -617,7 +617,7 @@ struct MacheteCollectiveMma {
|
||||
|
||||
// Same as upstream, should be kept the same when possible, not formatted for
|
||||
// easier comparison
|
||||
// with `SwapAB ? N : M -> M` since we dont support SwapAB
|
||||
// with `SwapAB ? N : M -> M` since we don't support SwapAB
|
||||
// clang-format off
|
||||
template<class ProblemShape>
|
||||
static bool
|
||||
|
||||
@ -22,6 +22,62 @@ __device__ __forceinline__ float GroupReduceMax(float val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T, bool SCALE_UE8M0>
|
||||
__device__ __forceinline__ float ComputeGroupScale(
|
||||
const T* __restrict__ group_input, T* __restrict__ smem_group,
|
||||
const int group_size, const int lane_id, const int threads_per_group,
|
||||
const float eps, const float max_8bit) {
|
||||
float local_absmax = eps;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
}
|
||||
|
||||
return y_s;
|
||||
}
|
||||
|
||||
template <typename T, typename DST_DTYPE>
|
||||
__device__ __forceinline__ void QuantizeGroup(
|
||||
const T* __restrict__ smem_group, DST_DTYPE* __restrict__ group_output,
|
||||
const int group_size, const int lane_id, const int threads_per_group,
|
||||
const float y_s, const float min_8bit, const float max_8bit) {
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
}
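To make the arithmetic in these two helpers concrete, here is a small self-contained host-side sketch of per-group absmax quantization to a signed 8-bit range, including the optional UE8M0 (power-of-two) scale rounding used by the packed variant. The int8 limits and eps are assumed illustrative values; the kernel receives them as parameters and targets whatever DST_DTYPE it is instantiated with:

  #include <algorithm>
  #include <cmath>
  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // Illustrative host-side version of the ComputeGroupScale + QuantizeGroup pair.
  std::vector<int8_t> quantize_group(const std::vector<float>& group, bool scale_ue8m0) {
    const float eps = 1e-10f, min_8bit = -127.0f, max_8bit = 127.0f;  // assumed limits
    float absmax = eps;
    for (float v : group) absmax = std::max(absmax, std::fabs(v));   // group absmax
    float y_s = absmax / max_8bit;                                    // per-group scale
    if (scale_ue8m0) {
      // UE8M0: round the scale up to the next power of two (exponent-only scale)
      y_s = std::exp2(std::ceil(std::log2(std::max(std::fabs(y_s), 1e-10f))));
    }
    std::vector<int8_t> out(group.size());
    for (std::size_t i = 0; i < group.size(); ++i) {
      float q = std::min(std::max(group[i] / y_s, min_8bit), max_8bit);
      out[i] = static_cast<int8_t>(q);                                // matches dst = DST_DTYPE(q)
    }
    return out;
  }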
|
||||
|
||||
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
|
||||
bool SCALE_UE8M0 = false, typename scale_packed_t = float>
|
||||
__global__ void per_token_group_quant_8bit_kernel(
|
||||
@ -38,8 +94,6 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
const int64_t global_group_id = block_group_id + local_group_id;
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
using scale_element_t = float;
|
||||
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
|
||||
|
||||
@ -68,30 +122,9 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
T* smem = reinterpret_cast<T*>(smem_raw);
|
||||
T* smem_group = smem + local_group_id * group_size;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
using vec_t = vllm::vec_n_t<T, vec_size>;
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
}
|
||||
const float y_s = ComputeGroupScale<T, SCALE_UE8M0>(
|
||||
group_input, smem_group, group_size, lane_id, threads_per_group, eps,
|
||||
max_8bit);
|
||||
|
||||
scale_element_t y_s_quant = y_s;
|
||||
|
||||
@ -101,19 +134,24 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
|
||||
threads_per_group, y_s, min_8bit, max_8bit);
|
||||
}
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
inline int GetGroupsPerBlock(int64_t num_groups) {
|
||||
if (num_groups % 16 == 0) {
|
||||
return 16;
|
||||
}
|
||||
if (num_groups % 8 == 0) {
|
||||
return 8;
|
||||
}
|
||||
if (num_groups % 4 == 0) {
|
||||
return 4;
|
||||
}
|
||||
if (num_groups % 2 == 0) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
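GetGroupsPerBlock simply picks the largest divisor of num_groups from {16, 8, 4, 2, 1}, and that choice fixes the launch geometry. A short illustrative fragment; the shared-memory sizing comment is an assumption inferred from how smem_group is indexed in the kernel above, not a copy of the actual launch code:

  // Illustrative only.
  int64_t num_groups = 24;                                 // hypothetical
  int groups_per_block = GetGroupsPerBlock(num_groups);    // largest divisor in {16,8,4,2,1} -> 8
  int64_t num_blocks = num_groups / groups_per_block;      // -> 3 blocks
  // Each block stages groups_per_block * group_size elements of type T in shared memory.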
|
||||
|
||||
void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
@ -133,17 +171,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 16;
|
||||
|
||||
int groups_per_block = 1;
|
||||
|
||||
if (num_groups % 16 == 0) {
|
||||
groups_per_block = 16;
|
||||
} else if (num_groups % 8 == 0) {
|
||||
groups_per_block = 8;
|
||||
} else if (num_groups % 4 == 0) {
|
||||
groups_per_block = 4;
|
||||
} else if (num_groups % 2 == 0) {
|
||||
groups_per_block = 2;
|
||||
}
|
||||
const int groups_per_block = GetGroupsPerBlock(num_groups);
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int num_blocks = num_groups / groups_per_block;
|
||||
@ -225,8 +253,6 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
|
||||
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
const T* group_input = input + block_group_offset;
|
||||
DST_DTYPE* group_output =
|
||||
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
|
||||
@ -235,29 +261,9 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
|
||||
extern __shared__ __align__(16) char smem_raw[];
|
||||
T* smem = reinterpret_cast<T*>(smem_raw);
|
||||
T* smem_group = smem + local_group_id * group_size;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
using vec_t = vllm::vec_n_t<T, vec_size>;
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
const float y_s =
|
||||
ComputeGroupScale<T, true>(group_input, smem_group, group_size, lane_id,
|
||||
threads_per_group, eps, max_8bit);
|
||||
|
||||
// pack 4 scales into a uint32
|
||||
if (lane_id == 0) {
|
||||
@ -284,19 +290,8 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
|
||||
threads_per_group, y_s, min_8bit, max_8bit);
|
||||
}
|
||||
|
||||
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
||||
@ -337,17 +332,7 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 16;
|
||||
|
||||
int groups_per_block = 1;
|
||||
|
||||
if (num_groups % 16 == 0) {
|
||||
groups_per_block = 16;
|
||||
} else if (num_groups % 8 == 0) {
|
||||
groups_per_block = 8;
|
||||
} else if (num_groups % 4 == 0) {
|
||||
groups_per_block = 4;
|
||||
} else if (num_groups % 2 == 0) {
|
||||
groups_per_block = 2;
|
||||
}
|
||||
const int groups_per_block = GetGroupsPerBlock(num_groups);
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int num_blocks = num_groups / groups_per_block;
|
||||
|
||||
@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
|
||||
int rowEnd = rowEnds[rowIdx];
|
||||
|
||||
// Local pointers to this block
|
||||
outIndices += rowIdx * topK;
|
||||
logits += rowIdx * stride0;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||
|
||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
|
||||
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
|
||||
@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
||||
|
||||
// Local pointers to this block
|
||||
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
|
||||
outIndices += rowIdx * topK;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
} else if constexpr (multipleBlocksPerRow) {
|
||||
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
|
||||
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
|
||||
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
|
||||
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||
outIndices +=
|
||||
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||
outLogits +=
|
||||
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||
} else if constexpr (mergeBlocks) {
|
||||
rowEnd = numBlocksToMerge * topK;
|
||||
indices += rowIdx * numBlocksToMerge * topK;
|
||||
outIndices += rowIdx * topK;
|
||||
indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
}
|
||||
logits += rowIdx * stride0;
|
||||
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||
|
||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
|
||||
multipleBlocksPerRow, mergeBlocks>(
|
||||
|
||||
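The static_cast<int64_t> additions in the hunks above guard against 32-bit overflow in the pointer-offset arithmetic: the row index, strides, and topK are 32-bit ints here, so without the cast their product is evaluated in 32 bits and can wrap before it is added to the pointer. A tiny self-contained illustration; the concrete sizes are made up:

  #include <climits>
  #include <cstdint>
  #include <cstdio>

  int main() {
    int rowIdx = 70000;    // hypothetical row index
    int stride0 = 40000;   // hypothetical row stride in elements
    long long full = static_cast<long long>(rowIdx) * stride0;  // 2'800'000'000
    std::printf("product = %lld, INT_MAX = %d\n", full, INT_MAX);
    // rowIdx * stride0 as plain ints would be evaluated in 32 bits and overflow here,
    // which is why the kernels promote one operand with static_cast<int64_t> first.
    return 0;
  }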
@ -754,6 +754,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
|
||||
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
|
||||
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
|
||||
"batch_size) -> ()");
|
||||
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
|
||||
&cp_gather_and_upconvert_fp8_kv_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
|
||||
"slot_mapping, "
|
||||
|
||||
@ -32,7 +32,7 @@ ARG DEADSNAKES_GPGKEY_URL
|
||||
|
||||
# The PyPA get-pip.py script is a self contained script+zip file, that provides
|
||||
# both the installer script and the pip base85-encoded zip archive. This allows
|
||||
# bootstrapping pip in environment where a dsitribution package does not exist.
|
||||
# bootstrapping pip in environment where a distribution package does not exist.
|
||||
#
|
||||
# By parameterizing the URL for get-pip.py installation script, we allow
|
||||
# third-party to use their own copy of the script stored in a private mirror.
|
||||
@ -73,15 +73,13 @@ ARG INSTALL_KV_CONNECTORS=false
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM ${BUILD_BASE_IMAGE} AS base
|
||||
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG GET_PIP_URL
|
||||
|
||||
# Install system dependencies and uv, then create Python virtual environment
|
||||
# Install system dependencies including build tools
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
@ -107,32 +105,30 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& ln -s /opt/venv/bin/pip /usr/bin/pip \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
|
||||
|
||||
# Activate virtual environment and add uv to PATH
|
||||
ENV PATH="/opt/venv/bin:/root/.local/bin:$PATH"
|
||||
ENV VIRTUAL_ENV="/opt/venv"
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
# Environment for uv
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
# Use copy mode to avoid hardlink failures with Docker cache mounts
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
RUN <<EOF
|
||||
gcc --version
|
||||
EOF
|
||||
# Verify GCC version
|
||||
RUN gcc --version
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
# this won't be needed for future versions of this docker image
|
||||
# or future versions of triton.
|
||||
# Workaround for triton/pytorch issues
|
||||
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
|
||||
# ============================================================
|
||||
# SLOW-CHANGING DEPENDENCIES BELOW
|
||||
# These are the expensive layers that we want to cache
|
||||
# ============================================================
|
||||
|
||||
# Install PyTorch and core CUDA dependencies
|
||||
# This is ~2GB and rarely changes
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# install build and runtime dependencies
|
||||
@ -142,13 +138,12 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/cuda.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# cuda arch list used by torch
|
||||
# can be useful for both `dev` and `test`
|
||||
# explicitly set the list to avoid issues with torch 2.2
|
||||
# see https://github.com/pytorch/pytorch/pull/123243
|
||||
# CUDA arch list used by torch
|
||||
# Explicitly set the list to avoid issues with torch 2.2
|
||||
# See https://github.com/pytorch/pytorch/pull/123243
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
#################### BUILD BASE IMAGE ####################
|
||||
|
||||
#################### CSRC BUILD IMAGE ####################
|
||||
FROM base AS csrc-build
|
||||
@ -241,6 +236,48 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
fi
|
||||
#################### CSRC BUILD IMAGE ####################
|
||||
|
||||
#################### EXTENSIONS BUILD IMAGE ####################
|
||||
# Build DeepGEMM, pplx-kernels, DeepEP - runs in PARALLEL with csrc-build
|
||||
# This stage is independent and doesn't affect csrc cache
|
||||
FROM base AS extensions-build
|
||||
ARG CUDA_VERSION
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Build DeepGEMM wheel
|
||||
ARG DEEPGEMM_GIT_REF
|
||||
COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
mkdir -p /tmp/deepgemm/dist && \
|
||||
VLLM_DOCKER_BUILD_CONTEXT=1 TORCH_CUDA_ARCH_LIST="9.0a 10.0a" /tmp/install_deepgemm.sh \
|
||||
--cuda-version "${CUDA_VERSION}" \
|
||||
${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"} \
|
||||
--wheel-dir /tmp/deepgemm/dist || \
|
||||
echo "DeepGEMM build skipped (CUDA version requirement not met)"
|
||||
|
||||
# Ensure the wheel dir exists so COPY won't fail when DeepGEMM is skipped
|
||||
RUN mkdir -p /tmp/deepgemm/dist && touch /tmp/deepgemm/dist/.deepgemm_skipped
|
||||
|
||||
# Build pplx-kernels and DeepEP wheels
|
||||
COPY tools/ep_kernels/install_python_libraries.sh /tmp/install_python_libraries.sh
|
||||
ARG PPLX_COMMIT_HASH
|
||||
ARG DEEPEP_COMMIT_HASH
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
mkdir -p /tmp/ep_kernels_workspace/dist && \
|
||||
export TORCH_CUDA_ARCH_LIST='9.0a 10.0a' && \
|
||||
/tmp/install_python_libraries.sh \
|
||||
--workspace /tmp/ep_kernels_workspace \
|
||||
--mode wheel \
|
||||
${PPLX_COMMIT_HASH:+--pplx-ref "$PPLX_COMMIT_HASH"} \
|
||||
${DEEPEP_COMMIT_HASH:+--deepep-ref "$DEEPEP_COMMIT_HASH"} && \
|
||||
find /tmp/ep_kernels_workspace/nvshmem -name '*.a' -delete
|
||||
#################### EXTENSIONS BUILD IMAGE ####################
|
||||
|
||||
#################### WHEEL BUILD IMAGE ####################
|
||||
FROM base AS build
|
||||
ARG TARGETPLATFORM
|
||||
@ -265,6 +302,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Copy pre-built csrc wheel directly
|
||||
COPY --from=csrc-build /workspace/dist /precompiled-wheels
|
||||
|
||||
COPY . .
|
||||
@ -286,27 +324,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
fi && \
|
||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38
|
||||
|
||||
# Install DeepGEMM from source
|
||||
ARG DEEPGEMM_GIT_REF
|
||||
COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
VLLM_DOCKER_BUILD_CONTEXT=1 TORCH_CUDA_ARCH_LIST="9.0a 10.0a" /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" ${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"} --wheel-dir /tmp/deepgemm/dist
|
||||
|
||||
# Ensure the wheel dir exists so later-stage COPY won't fail when DeepGEMM is skipped
|
||||
RUN mkdir -p /tmp/deepgemm/dist && touch /tmp/deepgemm/dist/.deepgemm_skipped
|
||||
|
||||
COPY tools/ep_kernels/install_python_libraries.sh /tmp/install_python_libraries.sh
|
||||
# Install EP kernels(pplx-kernels and DeepEP)
|
||||
ARG PPLX_COMMIT_HASH
|
||||
ARG DEEPEP_COMMIT_HASH
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
export TORCH_CUDA_ARCH_LIST='9.0a 10.0a' && \
|
||||
/tmp/install_python_libraries.sh \
|
||||
--workspace /tmp/ep_kernels_workspace \
|
||||
--mode wheel \
|
||||
${PPLX_COMMIT_HASH:+--pplx-ref "$PPLX_COMMIT_HASH"} \
|
||||
${DEEPEP_COMMIT_HASH:+--deepep-ref "$DEEPEP_COMMIT_HASH"} && \
|
||||
find /tmp/ep_kernels_workspace/nvshmem -name '*.a' -delete
|
||||
# Copy extension wheels from extensions-build stage for later use
|
||||
COPY --from=extensions-build /tmp/deepgemm/dist /tmp/deepgemm/dist
|
||||
COPY --from=extensions-build /tmp/ep_kernels_workspace/dist /tmp/ep_kernels_workspace/dist
|
||||
|
||||
# Check the size of the wheel if RUN_WHEEL_CHECK is true
|
||||
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
||||
@ -344,32 +364,25 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/dev.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
#################### DEV IMAGE ####################
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
FROM ${FINAL_BASE_IMAGE} AS vllm-base
|
||||
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
WORKDIR /vllm-workspace
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
# TODO (huydhn): There is no prebuilt gdrcopy package on 12.9 at the moment
|
||||
ARG GDRCOPY_CUDA_VERSION=12.8
|
||||
# Keep in line with FINAL_BASE_IMAGE
|
||||
ARG GDRCOPY_OS_VERSION=Ubuntu22_04
|
||||
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
ARG DEADSNAKES_GPGKEY_URL
|
||||
ARG GET_PIP_URL
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
WORKDIR /vllm-workspace
|
||||
|
||||
|
||||
# Python version string for paths (e.g., "312" for 3.12)
|
||||
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
|
||||
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
|
||||
|
||||
# Install Python and other dependencies
|
||||
# Install Python and system dependencies
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
@ -408,63 +421,104 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
# Install CUDA development tools and build essentials for runtime JIT compilation
|
||||
# Install CUDA development tools for runtime JIT compilation
|
||||
# (FlashInfer, DeepGEMM, EP kernels all require compilation at runtime)
|
||||
RUN CUDA_VERSION_DASH=$(echo $CUDA_VERSION | cut -d. -f1,2 | tr '.' '-') && \
|
||||
apt-get update -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_VERSION_DASH} \
|
||||
cuda-cudart-${CUDA_VERSION_DASH} \
|
||||
cuda-nvrtc-${CUDA_VERSION_DASH} \
|
||||
cuda-cuobjdump-${CUDA_VERSION_DASH} \
|
||||
# https://github.com/vllm-project/vllm/issues/29590
|
||||
libcurand-dev-${CUDA_VERSION_DASH} \
|
||||
libcublas-${CUDA_VERSION_DASH} \
|
||||
# Fixes nccl_allocator requiring nccl.h at runtime
|
||||
# https://github.com/vllm-project/vllm/blob/1336a1ea244fa8bfd7e72751cabbdb5b68a0c11a/vllm/distributed/device_communicators/pynccl_allocator.py#L22
|
||||
libnccl-dev && \
|
||||
cuda-nvcc-${CUDA_VERSION_DASH} \
|
||||
cuda-cudart-${CUDA_VERSION_DASH} \
|
||||
cuda-nvrtc-${CUDA_VERSION_DASH} \
|
||||
cuda-cuobjdump-${CUDA_VERSION_DASH} \
|
||||
libcurand-dev-${CUDA_VERSION_DASH} \
|
||||
libcublas-${CUDA_VERSION_DASH} \
|
||||
# Fixes nccl_allocator requiring nccl.h at runtime
|
||||
# https://github.com/vllm-project/vllm/blob/1336a1ea244fa8bfd7e72751cabbdb5b68a0c11a/vllm/distributed/device_communicators/pynccl_allocator.py#L22
|
||||
libnccl-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv for faster pip installs
|
||||
RUN python3 -m pip install uv
|
||||
|
||||
# Environment for uv
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Workaround for triton/pytorch issues
|
||||
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
|
||||
# ============================================================
|
||||
# SLOW-CHANGING DEPENDENCIES BELOW
|
||||
# These are the expensive layers that we want to cache
|
||||
# ============================================================
|
||||
|
||||
# Install PyTorch and core CUDA dependencies
|
||||
# This is ~2GB and rarely changes
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
COPY requirements/common.txt /tmp/common.txt
|
||||
COPY requirements/cuda.txt /tmp/requirements-cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r /tmp/requirements-cuda.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \
|
||||
rm /tmp/requirements-cuda.txt /tmp/common.txt
|
||||
|
||||
# Install FlashInfer pre-compiled kernel cache and binaries
|
||||
# This is ~1.1GB and only changes when FlashInfer version bumps
|
||||
# https://docs.flashinfer.ai/installation.html
|
||||
ARG FLASHINFER_VERSION=0.5.3
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system flashinfer-cubin==${FLASHINFER_VERSION} \
|
||||
&& uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
|
||||
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
&& flashinfer show-config
|
||||
|
||||
# ============================================================
|
||||
# OPENAI API SERVER DEPENDENCIES
|
||||
# Pre-install these to avoid reinstalling on every vLLM wheel rebuild
|
||||
# ============================================================
|
||||
|
||||
# Install gdrcopy (saves ~6s per build)
|
||||
# TODO (huydhn): There is no prebuilt gdrcopy package on 12.9 at the moment
|
||||
ARG GDRCOPY_CUDA_VERSION=12.8
|
||||
ARG GDRCOPY_OS_VERSION=Ubuntu22_04
|
||||
ARG TARGETPLATFORM
|
||||
COPY tools/install_gdrcopy.sh /tmp/install_gdrcopy.sh
|
||||
RUN set -eux; \
|
||||
case "${TARGETPLATFORM}" in \
|
||||
linux/arm64) UUARCH="aarch64" ;; \
|
||||
linux/amd64) UUARCH="x64" ;; \
|
||||
*) echo "Unsupported TARGETPLATFORM: ${TARGETPLATFORM}" >&2; exit 1 ;; \
|
||||
esac; \
|
||||
/tmp/install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}" && \
|
||||
rm /tmp/install_gdrcopy.sh
|
||||
|
||||
# Install vllm-openai dependencies (saves ~2.6s per build)
|
||||
# These are stable packages that don't depend on vLLM itself
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
BITSANDBYTES_VERSION="0.42.0"; \
|
||||
else \
|
||||
BITSANDBYTES_VERSION="0.46.1"; \
|
||||
fi; \
|
||||
uv pip install --system accelerate hf_transfer modelscope \
|
||||
"bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.3'
|
||||
|
||||
# ============================================================
|
||||
# VLLM INSTALLATION (depends on build stage)
|
||||
# ============================================================
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
|
||||
|
||||
# Install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
# Use copy mode to avoid hardlink failures with Docker cache mounts
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
# this won't be needed for future versions of this docker image
|
||||
# or future versions of triton.
|
||||
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
|
||||
# Install vllm wheel first, so that torch etc will be installed.
|
||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system dist/*.whl --verbose \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# Install FlashInfer pre-compiled kernel cache and binaries
|
||||
# https://docs.flashinfer.ai/installation.html
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system flashinfer-cubin==0.5.3 \
|
||||
&& uv pip install --system flashinfer-jit-cache==0.5.3 \
|
||||
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
&& flashinfer show-config
|
||||
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
uv pip list
|
||||
@ -478,7 +532,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
echo "No DeepGEMM wheels to install; skipping."; \
|
||||
fi'
|
||||
|
||||
# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH (https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/.ci/manywheel/build_cuda.sh#L141C14-L141C36)
|
||||
# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH
|
||||
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
# Install EP kernels wheels (pplx-kernels and DeepEP) that have been built in the `build` stage
|
||||
@ -487,23 +541,17 @@ RUN --mount=type=bind,from=build,src=/tmp/ep_kernels_workspace/dist,target=/vllm
|
||||
uv pip install --system ep_kernels/dist/*.whl --verbose \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
RUN --mount=type=bind,source=tools/install_gdrcopy.sh,target=/tmp/install_gdrcopy.sh,ro \
|
||||
set -eux; \
|
||||
case "${TARGETPLATFORM}" in \
|
||||
linux/arm64) UUARCH="aarch64" ;; \
|
||||
linux/amd64) UUARCH="x64" ;; \
|
||||
*) echo "Unsupported TARGETPLATFORM: ${TARGETPLATFORM}" >&2; exit 1 ;; \
|
||||
esac; \
|
||||
/tmp/install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}"
|
||||
|
||||
# CUDA image changed from /usr/local/nvidia to /usr/local/cuda in 12.8 but will
|
||||
# return to /usr/local/nvidia in 13.0 to allow container providers to mount drivers
|
||||
# consistently from the host (see https://github.com/vllm-project/vllm/issues/18859).
|
||||
# Until then, add /usr/local/nvidia/lib64 before the image cuda path to allow override.
|
||||
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64:${LD_LIBRARY_PATH}
|
||||
|
||||
# Copy examples and benchmarks at the end to minimize cache invalidation
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
#################### TEST IMAGE ####################
|
||||
# image to run unit testing suite
|
||||
# note that this uses vllm installed by `pip`
|
||||
@ -569,18 +617,12 @@ ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
# install kv_connectors if requested
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \
|
||||
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
|
||||
uv pip install --system -r /tmp/kv_connectors.txt; \
|
||||
fi; \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
BITSANDBYTES_VERSION="0.42.0"; \
|
||||
else \
|
||||
BITSANDBYTES_VERSION="0.46.1"; \
|
||||
fi; \
|
||||
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.3'
|
||||
uv pip install --system -r /tmp/kv_connectors.txt || true; \
|
||||
fi
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
|
||||
@ -76,6 +76,9 @@ RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
ENV NIXL_VERSION=0.7.0
|
||||
RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py
|
||||
|
||||
# PyJWT-2.7.0 will influence some wheel behaviors, remove its dist-info to avoid conflicts
|
||||
RUN rm /usr/lib/python3/dist-packages/PyJWT-2.7.0.dist-info/ -rf
|
||||
|
||||
# remove torch bundled oneccl to avoid conflicts
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip uninstall oneccl oneccl-devel -y
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 174 KiB After Width: | Height: | Size: 205 KiB |
@ -84,7 +84,7 @@ Total input tokens: 1369
|
||||
Total generated tokens: 2212
|
||||
Request throughput (req/s): 1.73
|
||||
Output token throughput (tok/s): 382.89
|
||||
Total Token throughput (tok/s): 619.85
|
||||
Total token throughput (tok/s): 619.85
|
||||
---------------Time to First Token----------------
|
||||
Mean TTFT (ms): 71.54
|
||||
Median TTFT (ms): 73.88
|
||||
|
||||
@ -24,11 +24,13 @@ Compute Resources:
|
||||
- Databricks
|
||||
- DeepInfra
|
||||
- Google Cloud
|
||||
- IBM
|
||||
- Intel
|
||||
- Lambda Lab
|
||||
- Nebius
|
||||
- Novita AI
|
||||
- NVIDIA
|
||||
- Red Hat
|
||||
- Replicate
|
||||
- Roblox
|
||||
- RunPod
|
||||
|
||||
@ -7,7 +7,7 @@ This guide covers optimization strategies and performance tuning for vLLM V1.
|
||||
|
||||
## Preemption
|
||||
|
||||
Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
|
||||
Due to the autoregressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
|
||||
In such cases, vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes
|
||||
available again. When this occurs, you may see the following warning:
|
||||
|
||||
|
||||
@ -82,7 +82,7 @@ DOCKER_BUILDKIT=1 docker build . \
|
||||
|
||||
## Building for Arm64/aarch64
|
||||
|
||||
A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper. At time of this writing, this should be considered **experimental**. Using the flag `--platform "linux/arm64"` will attempt to build for arm64.
|
||||
A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper and Grace-Blackwell. Using the flag `--platform "linux/arm64"` will build for arm64.
|
||||
|
||||
!!! note
|
||||
Multiple modules must be compiled, so this process can take a while. We recommend using `--build-arg max_jobs=` & `--build-arg nvcc_threads=`
|
||||
@ -104,6 +104,25 @@ A docker container can be built for aarch64 systems such as the Nvidia Grace-Hop
|
||||
--build-arg RUN_WHEEL_CHECK=false
|
||||
```
|
||||
|
||||
For (G)B300, we recommend using CUDA 13, as shown in the following command.
|
||||
|
||||
??? console "Command"
|
||||
|
||||
```bash
|
||||
DOCKER_BUILDKIT=1 docker build \
|
||||
--build-arg CUDA_VERSION=13.0.1 \
|
||||
--build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 \
|
||||
--build-arg max_jobs=256 \
|
||||
--build-arg nvcc_threads=2 \
|
||||
--build-arg RUN_WHEEL_CHECK=false \
|
||||
--build-arg torch_cuda_arch_list='9.0 10.0+PTX' \
|
||||
--platform "linux/arm64" \
|
||||
--tag vllm/vllm-gb300-openai:latest \
|
||||
--target vllm-openai \
|
||||
-f docker/Dockerfile \
|
||||
.
|
||||
```
|
||||
|
||||
!!! note
|
||||
If you are building the `linux/arm64` image on a non-ARM host (e.g., an x86_64 machine), you need to ensure your system is set up for cross-compilation using QEMU. This allows your host machine to emulate ARM64 execution.
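One common way to satisfy this (an assumption about your setup, not part of the official build instructions) is to register QEMU binfmt handlers through Docker before starting the build:

```bash
# Register QEMU emulation for arm64 binaries on an x86_64 host
docker run --privileged --rm tonistiigi/binfmt --install arm64
```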
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine le
|
||||
|
||||
* **Upstream vLLM compatibility** – It wraps around upstream vLLM without modifying its code.
|
||||
* **Ease of use** – Simplified deployment via Helm charts and observability through Grafana dashboards.
|
||||
* **High performance** – Optimized for LLM workloads with features like multi-model support, model-aware and prefix-aware routing, fast vLLM bootstrapping, and KV cache offloading with [LMCache](https://github.com/LMCache/LMCache), among others.
|
||||
* **High performance** – Optimized for LLM workloads with features like multimodel support, model-aware and prefix-aware routing, fast vLLM bootstrapping, and KV cache offloading with [LMCache](https://github.com/LMCache/LMCache), among others.
|
||||
|
||||
If you are new to Kubernetes, don't worry: in the vLLM production stack [repo](https://github.com/vllm-project/production-stack), we provide a step-by-step [guide](https://github.com/vllm-project/production-stack/blob/main/tutorials/00-install-kubernetes-env.md) and a [short video](https://www.youtube.com/watch?v=EsTJbQtzj0g) to set up everything and get started in **4 minutes**!
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ These features allow the most flexibility for cudagraph capture and compilation
|
||||
* `NONE` — turn CUDA Graphs off. Good for debugging.
|
||||
* `PIECEWISE` — a single-mode strategy (and past default). It is the most flexible: attention or other CUDA Graphs-incompatible operations stay eager, everything else goes into CUDA Graphs. Requires piecewise compilation.
|
||||
* `FULL` — a single-mode strategy, which only captures full CUDA Graphs for non-uniform batches, then uniform-decode batches reuse the CUDA Graph of non-uniform batch of the same batch_size, since they are compatible; can be good for small models or workloads with small prompts.
|
||||
* `FULL_DECODE_ONLY` — full CUDA Graph for uniform decode, no cudagraph for prefill/mixed etc; suitable for decode instances in a P/D setup where prefill is not as important, this way we can save the memory needed for `PIECEWISE` CUDA Graphs.
|
||||
* `FULL_DECODE_ONLY` — full CUDA Graph for uniform decode, no cudagraph for prefill/mixed etc.; suitable for decode instances in a P/D setup where prefill is not as important, this way we can save the memory needed for `PIECEWISE` CUDA Graphs.
|
||||
* `FULL_AND_PIECEWISE` — (default mode) full CUDA Graph for uniform decode, piecewise CUDA Graphs for others; generally the most performant setting, especially for low latency with small models or MoEs, but also requires the most memory and takes the longest to capture.
|
||||
|
||||
Defaults: If you're on v1 with piecewise compilation, we default to `FULL_AND_PIECEWISE` for better performance (for pooling models, it's still `PIECEWISE`). Otherwise, e.g. if piecewise compilation is unavailable, we default to `NONE`.
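For example, an explicit mode can be set through the compilation config (the model name below is only a placeholder):

```bash
# Force a specific CUDA Graph mode via the compilation config
vllm serve Qwen/Qwen2.5-1.5B-Instruct \
  --compilation_config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}'
```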
|
||||
@ -49,7 +49,7 @@ Defaults: If you’re on v1 with piecewise compilation, we default to `FULL_AND_
|
||||
While `NONE` , `PIECEWISE`, and `FULL` are single-mode configurations and simply equivalent to past implementations of eager execution, piecewise CUDA Graphs, and full CUDA Graphs respectively, `FULL_DECODE_ONLY` and `FULL_AND_PIECEWISE` are newly appended dual-mode configurations, which require dispatching to switch between concrete runtime modes according to runtime batches dynamically.
|
||||
|
||||
!!! note
|
||||
Here, the single-modes `NONE`, `PIECEWISE`, and `FULL` are treated as the runtime modes for CUDA Graphs dispatching. If using a dual-mode, the dispatcher will always dispatch to one of its member modes (plus a potantial `NONE` if no suitable CUDA Graph available), depending on the batch composition.
|
||||
Here, the single-modes `NONE`, `PIECEWISE`, and `FULL` are treated as the runtime modes for CUDA Graphs dispatching. If using a dual-mode, the dispatcher will always dispatch to one of its member modes (plus a potential `NONE` if no suitable CUDA Graph available), depending on the batch composition.
|
||||
|
||||
While cascade attention is not cudagraph compatible, it is now compatible with all possible cudagraph mode configurations. If a batch uses cascade attention, it always gets dispatched to `PIECEWISE` mode if available (otherwise `NONE`).
|
||||
|
||||
|
||||
@ -21,30 +21,20 @@ The mental model is that server-level metrics help explain the values of request
|
||||
|
||||
### v1 Metrics
|
||||
|
||||
In v1, the following metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix:
|
||||
In v1, an extensive set of metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix, for example:
|
||||
|
||||
- `vllm:num_requests_running` (Gauge) - Number of requests currently running.
|
||||
- `vllm:num_requests_waiting` (Gauge) - Number of requests currently waiting.
|
||||
- `vllm:kv_cache_usage_perc` (Gauge) - Fraction of used KV cache blocks (0–1).
|
||||
- `vllm:prefix_cache_queries` (Counter) - Number of prefix cache queries.
|
||||
- `vllm:prefix_cache_hits` (Counter) - Number of prefix cache hits.
|
||||
- `vllm:mm_cache_queries` (Counter) - (For multimodal models) Number of multimodal cache queries.
|
||||
- `vllm:mm_cache_hits` (Counter) - (For multimodal models) Number of multimodal cache hits.
|
||||
- `vllm:num_preemptions_total` (Counter) - Number of preemptions.
|
||||
- `vllm:prompt_tokens_total` (Counter) - Total number of prompt tokens processed.
|
||||
- `vllm:generation_tokens_total` (Counter) - Total number of generated tokens.
|
||||
- `vllm:iteration_tokens_total` (Histogram) - Histogram of tokens processed in each engine step.
|
||||
- `vllm:cache_config_info` (Gauge) - Information about the cache configuration.
|
||||
- `vllm:request_success_total` (Counter) - Number of finished requests (by finish reason).
|
||||
- `vllm:request_prompt_tokens` (Histogram) - Histogram of input prompt token counts.
|
||||
- `vllm:request_generation_tokens` (Histogram) - Histogram of generation token counts.
|
||||
- `vllm:request_params_n` (Histogram) - Histogram of request parameter n.
|
||||
- `vllm:request_params_max_tokens` - (Histogram) - Histogram of max_tokens parameter in requests.
|
||||
- `vllm:time_to_first_token_seconds` (Histogram) - Time to first token (TTFT).
|
||||
- `vllm:inter_token_latency_seconds` (Histogram) - Inter-token latency.
|
||||
- `vllm:e2e_request_latency_seconds` (Histogram) - End-to-end request latency.
|
||||
- `vllm:request_queue_time_seconds` (Histogram) - Time spent in the queue.
|
||||
- `vllm:request_inference_time_seconds` (Histogram) - Request inference time.
|
||||
- `vllm:request_prefill_time_seconds` (Histogram) - Request prefill time.
|
||||
- `vllm:request_decode_time_seconds` (Histogram) - Request decode time.
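As a quick sanity check, a few of these can be sampled directly from a running server (assuming the default local port 8000; adjust for your deployment):

```bash
# Inspect scheduler and KV cache gauges from the Prometheus endpoint
curl -s http://localhost:8000/metrics | \
  grep -E "vllm:(num_requests_running|num_requests_waiting|kv_cache_usage_perc)"
```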
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
## Overview
|
||||
|
||||
vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechnaism for users to trade startup time for performance. Higher levels have better performance but worse startup time. These optimization levels have associated defaults to help users get desired out of the box performance. Importantly, defaults set by optimization levels are purely defaults; explicit user settings will not be overwritten.
|
||||
vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechanism for users to trade startup time for performance. Higher levels have better performance but worse startup time. These optimization levels have associated defaults to help users get desired out-of-the-box performance. Importantly, defaults set by optimization levels are purely defaults; explicit user settings will not be overwritten.
|
||||
|
||||
## Level Summaries and Usage Examples
|
||||
```bash
|
||||
|
||||
@ -36,7 +36,7 @@ the input pointers `q`, `k_cache`, and `v_cache`, which point
|
||||
to query, key, and value data on global memory that need to be read
|
||||
and processed. The output pointer `out` points to global memory
|
||||
where the result should be written. These four pointers actually
|
||||
refer to multi-dimensional arrays, but each thread only accesses the
|
||||
refer to multidimensional arrays, but each thread only accesses the
|
||||
portion of data assigned to it. I have omitted all other runtime
|
||||
parameters here for simplicity.
|
||||
|
||||
@ -229,7 +229,7 @@ manner.
|
||||
|
||||
## QK
|
||||
|
||||
As shown the pseudo code below, before the entire for loop block, we
|
||||
As shown in the pseudocode below, before the entire for loop block, we
|
||||
fetch the query data for one token and store it in `q_vecs`. Then,
|
||||
in the outer for loop, we iterate through different `k_ptrs` that
|
||||
point to different tokens and prepare the `k_vecs` in the inner for
|
||||
@ -403,7 +403,7 @@ for ... { // Iteration over different blocks.
|
||||
}
|
||||
```
|
||||
|
||||
As shown in the above pseudo code, in the outer loop, similar to
|
||||
As shown in the above pseudocode, in the outer loop, similar to
|
||||
`k_ptr`, `logits_vec` iterates over different blocks and reads
|
||||
`V_VEC_SIZE` elements from `logits`. In the inner loop, each
|
||||
thread reads `V_VEC_SIZE` elements from the same tokens as a
|
||||
|
||||
@ -152,5 +152,5 @@ The interface for the model/module may change during vLLM's development. If you
|
||||
## Deprecation announcement
|
||||
|
||||
!!! warning "Deprecations"
|
||||
- `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It will be removed in v0.13.0 or v1.0.0.
|
||||
- `_Backend` in `vllm.attention` is deprecated. It will be removed in v0.13.0 or v1.0.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
|
||||
- `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It has been removed in v0.13.0.
|
||||
- `_Backend` in `vllm.attention` is deprecated. It has been removed in v0.13.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
|
||||
|
||||
@ -22,7 +22,7 @@ python tools/install_nixl_from_source_ubuntu.py
|
||||
NixlConnector uses NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables:
|
||||
|
||||
```bash
|
||||
# Example UCX configuration, adjust according to your enviroment
|
||||
# Example UCX configuration, adjust according to your environment
|
||||
export UCX_TLS=all # or specify specific transports like "rc,ud,sm,^cuda_ipc" ..etc
|
||||
export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1"
|
||||
```
|
||||
|
||||
@ -61,7 +61,7 @@ Now let´s see an example for each of the cases, starting with the `choice`, as
|
||||
print(completion.choices[0].message.content)
|
||||
```
|
||||
|
||||
The next example shows how to use the `regex`. The idea is to generate an email address, given a simple regex template:
|
||||
The next example shows how to use the `regex`. The supported regex syntax depends on the structured output backend. For example, `xgrammar`, `guidance`, and `outlines` use Rust-style regex, while `lm-format-enforcer` uses Python's `re` module. The idea is to generate an email address, given a simple regex template:
|
||||
|
||||
??? code
|
||||
|
||||
|
||||
@ -420,7 +420,7 @@ Flags: `--tool-call-parser pythonic --chat-template {see_above}`
|
||||
|
||||
## How to Write a Tool Parser Plugin
|
||||
|
||||
A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in [vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py](../../vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py).
|
||||
A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in [vllm/tool_parsers/hermes_tool_parser.py](../../vllm/tool_parsers/hermes_tool_parser.py).
|
||||
|
||||
Here is a summary of a plugin file:
|
||||
|
||||
@ -468,7 +468,7 @@ Here is a summary of a plugin file:
|
||||
# register the tool parser to ToolParserManager
|
||||
ToolParserManager.register_lazy_module(
|
||||
name="example",
|
||||
module_path="vllm.entrypoints.openai.tool_parsers.example",
|
||||
module_path="vllm.tool_parsers.example",
|
||||
class_name="ExampleToolParser",
|
||||
)
|
||||
|
||||
|
||||
@ -26,3 +26,4 @@ The backends below live **outside** the main `vllm` repository and follow the
|
||||
| Rebellions ATOM / REBEL NPU | `vllm-rbln` | <https://github.com/rebellions-sw/vllm-rbln> |
|
||||
| IBM Spyre AIU | `vllm-spyre` | <https://github.com/vllm-project/vllm-spyre> |
|
||||
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
|
||||
| Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> |
|
||||
|
||||
@ -16,21 +16,48 @@ vLLM offers basic model inferencing and serving on Arm CPU platform, with suppor
|
||||
# --8<-- [start:pre-built-wheels]
|
||||
|
||||
Pre-built vLLM wheels for Arm are available since version 0.11.2. These wheels contain pre-compiled C++ binaries.
|
||||
Please replace `<version>` in the commands below with a specific version string (e.g., `0.11.2`).
|
||||
|
||||
```bash
|
||||
uv pip install --pre vllm==<version>+cpu --extra-index-url https://wheels.vllm.ai/<version>%2Bcpu/
|
||||
export VLLM_VERSION=$(curl -s https://api.github.com/repos/vllm-project/vllm/releases/latest | jq -r .tag_name | sed 's/^v//')
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_VERSION}/cpu
|
||||
```
|
||||
|
||||
??? console "pip"
|
||||
```bash
|
||||
pip install --pre vllm==<version>+cpu --extra-index-url https://wheels.vllm.ai/<version>%2Bcpu/
|
||||
pip install vllm==${VLLM_VERSION}+cpu --extra-index-url https://wheels.vllm.ai/${VLLM_VERSION}/cpu
|
||||
```
|
||||
|
||||
The `uv` approach works for vLLM `v0.6.6` and later. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version.
|
||||
|
||||
!!! note
|
||||
Nightly wheels are currently unsupported for this architecture. (e.g. to bisect the behavior change, performance regression).
|
||||
**Install the latest code**
|
||||
|
||||
LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides working pre-built Arm CPU wheels for every commit since `v0.11.2` on <https://wheels.vllm.ai/nightly>. For native CPU wheels, this index should be used:
|
||||
|
||||
* `https://wheels.vllm.ai/nightly/cpu/vllm`
|
||||
|
||||
To install from nightly index, run:
|
||||
```bash
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly/cpu
|
||||
```
|
||||
|
||||
??? console "pip (there's a caveat)"
|
||||
|
||||
Using `pip` to install from nightly indices is _not supported_, because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes).
|
||||
|
||||
If you insist on using `pip`, you have to specify the full URL (link address) of the wheel file (which can be obtained from https://wheels.vllm.ai/nightly/cpu/vllm).
|
||||
|
||||
```bash
|
||||
pip install https://wheels.vllm.ai/4fa7ce46f31cbd97b4651694caf9991cc395a259/vllm-0.13.0rc2.dev104%2Bg4fa7ce46f.cpu-cp38-abi3-manylinux_2_35_aarch64.whl # current nightly build (the filename will change!)
|
||||
```
|
||||
|
||||
**Install specific revisions**
|
||||
|
||||
If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), you can specify the commit hash in the URL:
|
||||
|
||||
```bash
|
||||
export VLLM_COMMIT=730bd35378bf2a5b56b6d3a45be28b3092d26519 # use full commit hash from the main branch
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT}/cpu
|
||||
```
|
||||
|
||||
# --8<-- [end:pre-built-wheels]
|
||||
# --8<-- [start:build-wheel-from-source]
|
||||
@ -81,7 +108,23 @@ Testing has been conducted on AWS Graviton3 instances for compatibility.
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:pre-built-images]
|
||||
|
||||
Currently, there are no pre-built Arm CPU images.
|
||||
See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image.
|
||||
|
||||
Stable vLLM Docker images have been pre-built for Arm since version 0.12.0. Available image tags are here: [https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo).
|
||||
|
||||
```bash
|
||||
export VLLM_VERSION=$(curl -s https://api.github.com/repos/vllm-project/vllm/releases/latest | jq -r .tag_name | sed 's/^v//')
|
||||
docker pull public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:v${VLLM_VERSION}
|
||||
```
|
||||
|
||||
You can also access the latest code with Docker images. These are not intended for production use and are meant for CI and testing only. They will expire after several days.
|
||||
|
||||
The latest code can contain bugs and may not be stable. Please use it with caution.
|
||||
|
||||
```bash
|
||||
export VLLM_COMMIT=6299628d326f429eba78736acb44e76749b281f5 # use full commit hash from the main branch
|
||||
docker pull public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:${VLLM_COMMIT}-arm64-cpu
|
||||
```
|
||||
|
||||
# --8<-- [end:pre-built-images]
|
||||
# --8<-- [start:build-image-from-source]
|
||||
|
||||
@ -281,17 +281,27 @@ Alternatively, you can use the `openai` Python package:
|
||||
|
||||
Currently, vLLM supports multiple backends for efficient Attention computation across different platforms and accelerator architectures. It automatically selects the most performant backend compatible with your system and model specifications.
|
||||
|
||||
If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options:
|
||||
If desired, you can also manually set the backend of your choice using the `--attention-backend` CLI argument:
|
||||
|
||||
```bash
|
||||
# For online serving
|
||||
vllm serve Qwen/Qwen2.5-1.5B-Instruct --attention-backend FLASH_ATTN
|
||||
|
||||
# For offline inference
|
||||
python script.py --attention-backend FLASHINFER
|
||||
```
|
||||
|
||||
Some of the available backend options include:
|
||||
|
||||
- On NVIDIA CUDA: `FLASH_ATTN` or `FLASHINFER`.
|
||||
- On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`.
|
||||
|
||||
For AMD ROCm, you can further control the specific Attention implementation using the following variables:
|
||||
For AMD ROCm, you can further control the specific Attention implementation using the following options:
|
||||
|
||||
- Triton Unified Attention: `VLLM_ROCM_USE_AITER=0 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0`
|
||||
- AITER Unified Attention: `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0`
|
||||
- Triton Prefill-Decode Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0`
|
||||
- AITER Multi-head Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=1`
|
||||
- Triton Unified Attention: Set the environment variables `VLLM_ROCM_USE_AITER=0 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument.
|
||||
- AITER Unified Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument.
|
||||
- Triton Prefill-Decode Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=true` as a CLI argument.
|
||||
- AITER Multi-head Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=1` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument.
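For instance, selecting AITER Unified Attention combines the environment variables and CLI flag from the list above (the model name is only illustrative):

```bash
# AITER Unified Attention on ROCm
VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0 \
  vllm serve Qwen/Qwen2.5-1.5B-Instruct \
  --attention-config.use_prefill_decode_attention=false
```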
|
||||
|
||||
!!! warning
|
||||
There are no pre-built vLLM wheels containing FlashInfer, so you must install it in your environment first. Refer to the [FlashInfer official docs](https://docs.flashinfer.ai/) or see [docker/Dockerfile](../../docker/Dockerfile) for instructions on how to install it.
|
||||
|
||||
149
docs/mkdocs/hooks/generate_metrics.py
Normal file
149
docs/mkdocs/hooks/generate_metrics.py
Normal file
@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
logger = logging.getLogger("mkdocs")
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent.parent
|
||||
DOCS_DIR = ROOT_DIR / "docs"
|
||||
GENERATED_METRICS_DIR = DOCS_DIR / "generated" / "metrics"
|
||||
|
||||
# Files to scan for metric definitions - each will generate a separate table
|
||||
METRIC_SOURCE_FILES = [
|
||||
{"path": "vllm/v1/metrics/loggers.py", "output": "general.md"},
|
||||
{
|
||||
"path": "vllm/v1/spec_decode/metrics.py",
|
||||
"output": "spec_decode.md",
|
||||
},
|
||||
{
|
||||
"path": "vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py",
|
||||
"output": "nixl_connector.md",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class MetricExtractor(ast.NodeVisitor):
|
||||
"""AST visitor to extract metric definitions."""
|
||||
|
||||
def __init__(self):
|
||||
self.metrics: list[dict[str, str]] = []
|
||||
|
||||
def visit_Call(self, node: ast.Call) -> None:
|
||||
"""Visit function calls to find metric class instantiations."""
|
||||
metric_type = self._get_metric_type(node)
|
||||
if metric_type:
|
||||
name = self._extract_kwarg(node, "name")
|
||||
documentation = self._extract_kwarg(node, "documentation")
|
||||
|
||||
if name:
|
||||
self.metrics.append(
|
||||
{
|
||||
"name": name,
|
||||
"type": metric_type,
|
||||
"documentation": documentation or "",
|
||||
}
|
||||
)
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
def _get_metric_type(self, node: ast.Call) -> str | None:
|
||||
"""Determine if this call creates a metric and return its type."""
|
||||
metric_type_map = {
|
||||
"_gauge_cls": "gauge",
|
||||
"_counter_cls": "counter",
|
||||
"_histogram_cls": "histogram",
|
||||
}
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
return metric_type_map.get(node.func.attr)
|
||||
return None
|
||||
|
||||
def _extract_kwarg(self, node: ast.Call, key: str) -> str | None:
|
||||
"""Extract a keyword argument value from a function call."""
|
||||
for keyword in node.keywords:
|
||||
if keyword.arg == key:
|
||||
return self._get_string_value(keyword.value)
|
||||
return None
|
||||
|
||||
def _get_string_value(self, node: ast.AST) -> str | None:
|
||||
"""Extract string value from an AST node."""
|
||||
if isinstance(node, ast.Constant):
|
||||
return str(node.value) if node.value is not None else None
|
||||
return None
|
||||
|
||||
|
||||
def extract_metrics_from_file(filepath: Path) -> list[dict[str, str]]:
|
||||
"""Parse a Python file and extract all metric definitions."""
|
||||
try:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
source = f.read()
|
||||
|
||||
tree = ast.parse(source, filename=str(filepath))
|
||||
extractor = MetricExtractor()
|
||||
extractor.visit(tree)
|
||||
return extractor.metrics
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to parse {filepath}: {e}") from e
|
||||
|
||||
|
||||
def generate_markdown_table(metrics: list[dict[str, str]]) -> str:
|
||||
"""Generate a markdown table from extracted metrics."""
|
||||
if not metrics:
|
||||
return "No metrics found.\n"
|
||||
|
||||
# Sort by type, then by name
|
||||
metrics_sorted = sorted(metrics, key=lambda m: (m["type"], m["name"]))
|
||||
|
||||
lines = []
|
||||
lines.append("| Metric Name | Type | Description |")
|
||||
lines.append("|-------------|------|-------------|")
|
||||
|
||||
for metric in metrics_sorted:
|
||||
name = metric["name"]
|
||||
metric_type = metric["type"].capitalize()
|
||||
doc = metric["documentation"].replace("\n", " ").strip()
|
||||
lines.append(f"| `{name}` | {metric_type} | {doc} |")
|
||||
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
||||
"""Generate metrics documentation tables from source files."""
|
||||
logger.info("Generating metrics documentation")
|
||||
|
||||
# Create generated directory if it doesn't exist
|
||||
GENERATED_METRICS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
total_metrics = 0
|
||||
for source_config in METRIC_SOURCE_FILES:
|
||||
source_path = source_config["path"]
|
||||
output_file = source_config["output"]
|
||||
|
||||
filepath = ROOT_DIR / source_path
|
||||
if not filepath.exists():
|
||||
raise FileNotFoundError(f"Metrics source file not found: {filepath}")
|
||||
|
||||
logger.debug("Extracting metrics from: %s", source_path)
|
||||
metrics = extract_metrics_from_file(filepath)
|
||||
logger.debug("Found %d metrics in %s", len(metrics), source_path)
|
||||
|
||||
# Generate and write the markdown table for this source
|
||||
table_content = generate_markdown_table(metrics)
|
||||
output_path = GENERATED_METRICS_DIR / output_file
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(table_content)
|
||||
|
||||
total_metrics += len(metrics)
|
||||
logger.info(
|
||||
"Generated metrics table: %s (%d metrics)",
|
||||
output_path.relative_to(ROOT_DIR),
|
||||
len(metrics),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Total metrics generated: %d across %d files",
|
||||
total_metrics,
|
||||
len(METRIC_SOURCE_FILES),
|
||||
)
|
||||
@ -316,10 +316,13 @@ We have split the `encode` task into two more specific token-wise tasks: `token_
|
||||
|
||||
### Remove softmax from PoolingParams
|
||||
|
||||
We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
|
||||
We are going to remove `softmax` and `activation` from `PoolingParams` in v0.15. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
|
||||
|
||||
### as_reward_model
|
||||
|
||||
!!! warning
|
||||
We are going to remove `--convert reward` in v0.15; use `--convert embed` instead.
|
||||
|
||||
Pooling models now support all pooling tasks by default; you can use them without any settings.
|
||||
|
||||
- Extracting hidden states prefers using the `token_embed` task.
|
||||
|
||||
@ -568,7 +568,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
```
|
||||
|
||||
!!! note
|
||||
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/qwen3_reranker.py](../../examples/pooling/score/qwen3_reranker.py).
|
||||
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/offline_reranker.py](../../examples/pooling/score/offline_reranker.py).
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
|
||||
@ -659,7 +659,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||
|--------------|--------|--------|-------------------|----------------------|---------------------------|
|
||||
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
|
||||
| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A<sup>+</sup> | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ |
|
||||
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ |
|
||||
| `BagelForConditionalGeneration` | BAGEL | T + I<sup>+</sup> | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ |
|
||||
| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
|
||||
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ |
|
||||
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ |
|
||||
@ -743,7 +745,7 @@ Some models are supported only via the [Transformers modeling backend](#transfor
|
||||
- There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
|
||||
|
||||
!!! note
|
||||
For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc), InternVL3 and InternVL3.5 have video inputs support currently.
|
||||
For `InternVLChatModel`, only InternVL2.5 with a Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc.), InternVL3, and InternVL3.5 currently support video inputs.
|
||||
|
||||
!!! note
|
||||
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
|
||||
|
||||
@ -8,11 +8,11 @@ For MoE models, particularly those like DeepSeek that employ MLA (Multi-head Lat
|
||||
|
||||
In these cases, the data parallel ranks are not completely independent. Forward passes must be aligned, and expert layers across all ranks are required to synchronize during every forward pass, even when there are fewer requests to be processed than DP ranks.
|
||||
|
||||
The expert layers will by default form a (DP x TP) sized tensor parallel group. To enable expert parallelism, include the `--enable-expert-parallel` CLI arg (on all nodes in the multi-node case).
|
||||
By default, expert layers form a tensor parallel group of size `DP × TP`. To use expert parallelism instead, include the `--enable-expert-parallel` CLI arg (on all nodes in the multi-node case). See [Expert Parallel Deployment](expert_parallel_deployment.md) for details on how attention and expert layers behave differently with EP enabled.
|
||||
|
||||
In vLLM, each DP rank is deployed as a separate "core engine" process that communicates with front-end process(es) via ZMQ sockets. Data Parallel attention can be combined with Tensor Parallel attention, in which case each DP engine owns a number of per-GPU worker processes equal to the configured TP size.
|
||||
|
||||
For MoE models, when any requests are in progress in any rank, we must ensure that empty "dummy" forward passes are performed in all ranks that don't currently have any requests scheduled. This is handled via a separate DP Coordinator process that communicates with all ranks, and a collective operation performed every N steps to determine when all ranks become idle and can be paused. When TP is used in conjunction with DP, expert layers form an EP or TP group of size (DP x TP).
|
||||
For MoE models, when any requests are in progress in any rank, we must ensure that empty "dummy" forward passes are performed in all ranks that don't currently have any requests scheduled. This is handled via a separate DP Coordinator process that communicates with all ranks, and a collective operation performed every N steps to determine when all ranks become idle and can be paused. When TP is used in conjunction with DP, expert layers form a group of size `DP × TP` (using either tensor parallelism by default, or expert parallelism if `--enable-expert-parallel` is set).
|
||||
|
||||
In all cases, it is beneficial to load-balance requests between DP ranks. For online deployments, this balancing can be optimized by taking into account the state of each DP engine - in particular its currently scheduled and waiting (queued) requests, and KV cache state. Each DP engine has an independent KV cache, and the benefit of prefix caching can be maximized by directing prompts intelligently.
|
||||
|
||||
@ -24,7 +24,7 @@ There are two distinct modes supported for online deployments - self-contained w
|
||||
|
||||
vLLM supports "self-contained" data parallel deployments that expose a single API endpoint.
|
||||
|
||||
It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs.
|
||||
It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs. When sizing DP deployments, remember that `--max-num-seqs` applies per DP rank.
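For example, a self-contained single-node deployment of the 8-GPU case above might look like this (the model name is only illustrative):

```bash
# 4 DP ranks x 2-way TP = 8 GPUs behind a single API endpoint
vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --data-parallel-size 4 \
  --tensor-parallel-size 2
```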
|
||||
|
||||
Running a single data parallel deployment across multiple nodes requires a different `vllm serve` to be run on each node, specifying which DP ranks should run on that node. In this case, there will still be a single HTTP entrypoint - the API server(s) will run only on one node, but it doesn't necessarily need to be co-located with the DP ranks.
|
||||
|
||||
@ -80,6 +80,18 @@ When deploying large DP sizes using this method, the API server process can beco
|
||||

|
||||
</figure>
|
||||
|
||||
## Hybrid Load Balancing
|
||||
|
||||
Hybrid load balancing sits between the internal and external approaches. Each node runs its own API server(s) that only queue requests to the data-parallel engines colocated on that node. An upstream load balancer (for example, an ingress controller or traffic router) spreads user requests across those per-node endpoints.
|
||||
|
||||
Enable this mode with `--data-parallel-hybrid-lb` while still launching every node with the global data-parallel size. The key differences from internal load balancing are:
|
||||
|
||||
- You must provide `--data-parallel-size-local` and `--data-parallel-start-rank` so each node knows which ranks it owns.
|
||||
- Not compatible with `--headless` since every node exposes an API endpoint.
|
||||
- Scale `--api-server-count` per node based on the number of local ranks.
|
||||
|
||||
In this configuration, each node keeps scheduling decisions local, which reduces cross-node traffic and avoids single node bottlenecks at larger DP sizes.
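A minimal two-node sketch, assuming 2 local ranks per node and an illustrative model name (depending on your network setup, you may also need `--data-parallel-address` and `--data-parallel-rpc-port`, as in the internal load balancing example):

```bash
# Node 0 (owns DP ranks 0-1)
vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --data-parallel-size 4 \
  --data-parallel-size-local 2 \
  --data-parallel-start-rank 0 \
  --data-parallel-hybrid-lb

# Node 1 (owns DP ranks 2-3)
vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --data-parallel-size 4 \
  --data-parallel-size-local 2 \
  --data-parallel-start-rank 2 \
  --data-parallel-hybrid-lb
```

The upstream load balancer then distributes user requests across the two per-node endpoints.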
|
||||
|
||||
## External Load Balancing
|
||||
|
||||
For larger scale deployments especially, it can make sense to handle the orchestration and load balancing of data parallel ranks externally.
|
||||
|
||||
@ -40,10 +40,32 @@ EP_SIZE = TP_SIZE × DP_SIZE
|
||||
|
||||
Where:
|
||||
|
||||
- `TP_SIZE`: Tensor parallel size (always 1 for now)
|
||||
- `TP_SIZE`: Tensor parallel size
|
||||
- `DP_SIZE`: Data parallel size
|
||||
- `EP_SIZE`: Expert parallel size (computed automatically)
|
||||
|
||||
### Layer Behavior with EP Enabled
|
||||
|
||||
When EP is enabled, different layers in MoE models behave differently:
|
||||
|
||||
| Layer Type | Behavior | Parallelism Used |
|
||||
|------------|----------|------------------|
|
||||
| **Expert (MoE) Layers** | Sharded across all EP ranks | Expert Parallel (EP) of size `TP × DP` |
|
||||
| **Attention Layers** | Behavior depends on TP size | See below |
|
||||
|
||||
**Attention layer parallelism:**
|
||||
|
||||
- **When `TP = 1`**: Attention weights are **replicated** across all DP ranks (data parallelism)
|
||||
- **When `TP > 1`**: Attention weights are **sharded** using tensor parallelism across TP ranks within each DP group
|
||||
|
||||
For example, with `TP=2, DP=4` (8 GPUs total):
|
||||
|
||||
- Expert layers form an EP group of size 8, with experts distributed across all GPUs
|
||||
- Attention layers use TP=2 within each of the 4 DP groups
|
||||
|
||||
!!! note "Key Difference from Data Parallel Deployment"
|
||||
Without `--enable-expert-parallel`, MoE layers would use tensor parallelism (forming a TP group of size `TP × DP`), similar to dense models. With EP enabled, expert layers switch to expert parallelism, which can provide better efficiency and locality for MoE models.
|
||||
|
||||
### Example Command
|
||||
|
||||
The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parallel, 8-way (attention) data parallel, and 8-way expert parallel. The attention weights are replicated across all GPUs, while the expert weights are split across GPUs. It will work on an H200 (or H20) node with 8 GPUs. For H100, you can try to serve a smaller model or refer to the multi-node deployment section.
|
||||
@ -81,7 +103,7 @@ vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
||||
--data-parallel-size-local 8 \ # Local DP size on this node (8 GPUs per node)
|
||||
--data-parallel-address 192.168.1.100 \ # Replace with actual IP of Node 1
|
||||
--data-parallel-rpc-port 13345 \ # RPC communication port, can be any port as long as reachable by all nodes
|
||||
--api-server-count=8 # Number of API servers for load handling (scaling this out to total ranks are recommended)
|
||||
--api-server-count=8 # Number of API servers for load handling (scaling this out to the number of local ranks is recommended)
|
||||
|
||||
# Node 2 (Secondary - headless mode, no API server)
|
||||
vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
||||
@ -119,9 +141,6 @@ While MoE models are typically trained so that each expert receives a similar nu
|
||||
|
||||
Enable EPLB with the `--enable-eplb` flag.
|
||||
|
||||
!!! note "Model Support"
|
||||
Currently only DeepSeek V3 architecture is supported.
|
||||
|
||||
When enabled, vLLM collects load statistics with every forward pass and periodically rebalances expert distribution.
|
||||
|
||||
### EPLB Parameters
|
||||
@ -134,6 +153,8 @@ Configure EPLB with the `--eplb-config` argument, which accepts a JSON string. T
|
||||
| `step_interval`| Frequency of rebalancing (every N engine steps) | 3000 |
|
||||
| `log_balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` |
|
||||
| `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` |
|
||||
| `use_async` | Use non-blocking EPLB for reduced latency overhead | `false` |
|
||||
| `policy` | The policy type for expert parallel load balancing | `"default"` |
|
||||
|
||||
For example:
|
||||
|
||||
@ -183,6 +204,26 @@ vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
||||
|
||||
For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--eplb-config '{"num_redundant_experts":32}'` in large-scale use cases so the most popular experts are always available.
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Performance Optimization
|
||||
|
||||
- **DeepEP kernels**: The `high_throughput` and `low_latency` kernels are optimized for disaggregated serving and may show poor performance for mixed workloads
|
||||
- **Dual Batch Overlap**: Use `--enable-dbo` to overlap all-to-all communication with compute. See [Dual Batch Overlap](../design/dbo.md) for more details.
|
||||
- **Async scheduling (experimental)**: Try `--async-scheduling` to overlap scheduling with model execution.
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
- **`non-zero status: 7 cannot register cq buf`**: When using InfiniBand/RoCE, make sure the host VM and pods show `ulimit -l` as "unlimited".
|
||||
- **`init failed for transport: IBGDA`**: The InfiniBand GDA kernel modules are missing. Run `tools/ep_kernels/configure_system_drivers.sh` on each GPU node and reboot. Also fixes error `NVSHMEM API called before NVSHMEM initialization has completed`.
|
||||
- **NVSHMEM peer disconnect**: Usually a networking misconfiguration. If deploying via Kubernetes, verify that every pod runs with `hostNetwork: true` and `securityContext.privileged: true` to access InfiniBand.
|
||||
|
||||
### Benchmarking
|
||||
|
||||
- Use simulator flags `VLLM_MOE_ROUTING_SIMULATION_STRATEGY=uniform_random` and `VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1` so token routing is balanced across EP ranks.
|
||||
|
||||
- Increasing `VLLM_MOE_DP_CHUNK_SIZE` may increase throughput by increasing the maximum batch size for inter-rank token transfers. This may cause DeepEP to throw `assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2`, which can be fixed by increasing environment variable `NVSHMEM_QP_DEPTH`.
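For example, these knobs might be set together as follows (the values are purely illustrative, not tuned recommendations):

```bash
export VLLM_MOE_ROUTING_SIMULATION_STRATEGY=uniform_random
export VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1
export VLLM_MOE_DP_CHUNK_SIZE=512   # illustrative; larger values increase the per-transfer token batch
export NVSHMEM_QP_DEPTH=1024        # illustrative; raise this if the DeepEP assertion above fires
```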
|
||||
|
||||
## Disaggregated Serving (Prefill/Decode Split)
|
||||
|
||||
For production deployments requiring strict SLA guarantees for time-to-first-token and inter-token latency, disaggregated serving allows independent scaling of prefill and decode operations.
|
||||
@ -273,3 +314,9 @@ except Exception as e:
|
||||
print(f"❌ Error during disaggregated serving: {e}")
|
||||
print("Check that both prefill and decode instances are running and accessible")
|
||||
```
|
||||
|
||||
### Benchmarking
|
||||
|
||||
- To simulate the decode deployment of disaggregated serving, pass `--kv-transfer-config '{"kv_connector":"DecodeBenchConnector","kv_role":"kv_both"}'` to the `vllm serve` invocation. The connector populates KV cache with random values so decode can be profiled in isolation.
|
||||
|
||||
- **CUDAGraph capture**: Use `--compilation_config '{"cudagraph_mode": "FULL_DECODE_ONLY"}'` to enable CUDA graph capture for decode only and save KV cache.
|
||||
|
||||
@ -851,7 +851,7 @@ endpoints are compatible with both [Jina AI's re-rank API interface](https://jin
|
||||
[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with
|
||||
popular open-source tools.
|
||||
|
||||
Code example: [examples/pooling/score/jinaai_rerank_client.py](../../examples/pooling/score/jinaai_rerank_client.py)
|
||||
Code example: [examples/pooling/score/openai_reranker.py](../../examples/pooling/score/openai_reranker.py)
|
||||
|
||||
#### Example Request
|
||||
|
||||
|
||||
@ -62,7 +62,7 @@ If a single node lacks sufficient GPUs to hold the model, deploy vLLM across mul
|
||||
|
||||
### What is Ray?
|
||||
|
||||
Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM deployments require Ray as the runtime engine.
|
||||
Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM deployments can use Ray as the runtime engine.
|
||||
|
||||
vLLM uses Ray to manage the distributed execution of tasks across multiple nodes and control where execution happens.
|
||||
|
||||
@ -130,9 +130,31 @@ vllm serve /path/to/the/model/in/the/container \
|
||||
--distributed-executor-backend ray
|
||||
```
|
||||
|
||||
### Running vLLM with MultiProcessing
|
||||
|
||||
Besides Ray, multi-node vLLM deployments can also use `multiprocessing` as the runtime engine. Here's an example that deploys a model across 2 nodes (8 GPUs per node) with `tp_size=8` and `pp_size=2`.
|
||||
|
||||
Choose one node as the head node and run:
|
||||
|
||||
```bash
|
||||
vllm serve /path/to/the/model/in/the/container \
|
||||
--tensor-parallel-size 8 --pipeline-parallel-size 2 \
|
||||
--nnodes 2 --node-rank 0 \
|
||||
--master-addr <HEAD_NODE_IP>
|
||||
```
|
||||
|
||||
On the other worker node, run:
|
||||
|
||||
```bash
|
||||
vllm serve /path/to/the/model/in/the/container \
|
||||
--tensor-parallel-size 8 --pipeline-parallel-size 2 \
|
||||
--nnodes 2 --node-rank 1 \
|
||||
--master-addr <HEAD_NODE_IP> --headless
|
||||
```
|
||||
|
||||
## Optimizing network communication for tensor parallelism
|
||||
|
||||
Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand.
|
||||
Efficient tensor parallelism requires fast internode communication, preferably through high-speed network adapters such as InfiniBand.
|
||||
To set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the
|
||||
[examples/online_serving/run_cluster.sh](../../examples/online_serving/run_cluster.sh) helper script.
|
||||
Contact your system administrator for more information about the required flags.
|
||||
|
||||
@ -33,11 +33,19 @@ Then query the endpoint to get the latest metrics from the server:
|
||||
|
||||
The following metrics are exposed:
|
||||
|
||||
??? code
|
||||
## General Metrics
|
||||
|
||||
```python
|
||||
--8<-- "vllm/engine/metrics.py:metrics-definitions"
|
||||
```
|
||||
--8<-- "docs/generated/metrics/general.md"
|
||||
|
||||
## Speculative Decoding Metrics
|
||||
|
||||
--8<-- "docs/generated/metrics/spec_decode.md"
|
||||
|
||||
## NIXL KV Connector Metrics
|
||||
|
||||
--8<-- "docs/generated/metrics/nixl_connector.md"
|
||||
|
||||
## Deprecation Policy
|
||||
|
||||
Note: when metrics are deprecated in version `X.Y`, they are hidden in version `X.Y+1`
|
||||
but can be re-enabled using the `--show-hidden-metrics-for-version=X.Y` escape hatch,
|
||||
|
||||
@ -10,7 +10,7 @@ All communications between nodes in a multi-node vLLM deployment are **insecure
|
||||
|
||||
### Configuration Options for Inter-Node Communications
|
||||
|
||||
The following options control inter-node communications in vLLM:
|
||||
The following options control internode communications in vLLM:
|
||||
|
||||
#### 1. **Environment Variables:**
|
||||
|
||||
@ -28,7 +28,7 @@ The following options control inter-node communications in vLLM:
|
||||
|
||||
### Notes on PyTorch Distributed
|
||||
|
||||
vLLM uses PyTorch's distributed features for some inter-node communication. For
|
||||
vLLM uses PyTorch's distributed features for some internode communication. For
|
||||
detailed information about PyTorch Distributed security considerations, please
|
||||
refer to the [PyTorch Security
|
||||
Guide](https://github.com/pytorch/pytorch/security/policy#using-distributed-features).
|
||||
|
||||
@ -42,60 +42,31 @@ class ModelRequestData(NamedTuple):
|
||||
# Unless specified, these settings have been tested to work on a single L4.
|
||||
|
||||
|
||||
# Voxtral
|
||||
# Make sure to install mistral-common[audio].
|
||||
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
|
||||
from mistral_common.audio import Audio
|
||||
from mistral_common.protocol.instruct.chunk import (
|
||||
AudioChunk,
|
||||
RawAudio,
|
||||
TextChunk,
|
||||
)
|
||||
from mistral_common.protocol.instruct.messages import (
|
||||
UserMessage,
|
||||
)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
model_name = "mistralai/Voxtral-Mini-3B-2507"
|
||||
tokenizer = MistralTokenizer.from_hf_hub(model_name)
|
||||
|
||||
# AudioFlamingo3
|
||||
def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "nvidia/audio-flamingo-3-hf"
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
config_format="mistral",
|
||||
load_format="mistral",
|
||||
tokenizer_mode="mistral",
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=False,
|
||||
)
|
||||
|
||||
text_chunk = TextChunk(text=question)
|
||||
audios = [
|
||||
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
|
||||
for i in range(audio_count)
|
||||
]
|
||||
audio_chunks = [
|
||||
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
|
||||
]
|
||||
# AudioFlamingo3 uses <sound> token for audio
|
||||
audio_placeholder = "<sound>" * audio_count
|
||||
|
||||
messages = [UserMessage(content=[*audio_chunks, text_chunk])]
|
||||
|
||||
req = ChatCompletionRequest(messages=messages, model=model_name)
|
||||
|
||||
tokens = tokenizer.encode_chat_completion(req)
|
||||
prompt_ids, audios = tokens.tokens, tokens.audios
|
||||
|
||||
audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
|
||||
|
||||
multi_modal_data = {"audio": audios_and_sr}
|
||||
prompt = (
|
||||
"<|im_start|>system\n"
|
||||
"You are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
f"{audio_placeholder}{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt_token_ids=prompt_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
@ -361,6 +332,63 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
    )


# Voxtral
# Make sure to install mistral-common[audio].
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    from mistral_common.audio import Audio
    from mistral_common.protocol.instruct.chunk import (
        AudioChunk,
        RawAudio,
        TextChunk,
    )
    from mistral_common.protocol.instruct.messages import (
        UserMessage,
    )
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    model_name = "mistralai/Voxtral-Mini-3B-2507"
    tokenizer = MistralTokenizer.from_hf_hub(model_name)

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        enforce_eager=True,
        enable_chunked_prefill=False,
    )

    text_chunk = TextChunk(text=question)
    audios = [
        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
        for i in range(audio_count)
    ]
    audio_chunks = [
        AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
    ]

    messages = [UserMessage(content=[*audio_chunks, text_chunk])]

    req = ChatCompletionRequest(messages=messages, model=model_name)

    tokens = tokenizer.encode_chat_completion(req)
    prompt_ids, audios = tokens.tokens, tokens.audios

    audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]

    multi_modal_data = {"audio": audios_and_sr}

    return ModelRequestData(
        engine_args=engine_args,
        prompt_token_ids=prompt_ids,
        multi_modal_data=multi_modal_data,
    )

# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
    assert audio_count == 1, "Whisper only support single audio input per prompt"
@ -382,7 +410,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:


model_example_map = {
    "voxtral": run_voxtral,
    "audioflamingo3": run_audioflamingo3,
    "gemma3n": run_gemma3n,
    "granite_speech": run_granite_speech,
    "midashenglm": run_midashenglm,
@ -392,6 +420,7 @@ model_example_map = {
    "qwen2_audio": run_qwen2_audio,
    "qwen2_5_omni": run_qwen2_5_omni,
    "ultravox": run_ultravox,
    "voxtral": run_voxtral,
    "whisper": run_whisper,
}

@ -422,7 +451,7 @@ def parse_args():
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    parser.add_argument(
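For reference, a minimal sketch of how the example map above is dispatched by the driver code (the driver itself is not shown in this hunk; argument names follow the signatures above):

    audio_count = 1
    question = "What is recited in the audio?"
    req_data = model_example_map["audioflamingo3"](question, audio_count)
    # req_data is a ModelRequestData carrying engine_args, a prompt or prompt_token_ids,
    # and the multi_modal_data passed to the engine.
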
@ -4,6 +4,9 @@
from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser


@ -20,6 +23,11 @@ def parse_args():


def main(args: Namespace):
    if current_platform.is_rocm():
        args.attention_config = AttentionConfig(
            backend=AttentionBackendEnum.FLEX_ATTENTION
        )

    # Sample prompts.
    prompts = [
        "Hello, my name is",

@ -4,6 +4,9 @@
from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser


@ -20,6 +23,11 @@ def parse_args():


def main(args: Namespace):
    if current_platform.is_rocm():
        args.attention_config = AttentionConfig(
            backend=AttentionBackendEnum.FLEX_ATTENTION
        )

    # Sample prompts.
    text_1 = "What is the capital of France?"
    texts_2 = [
@ -33,6 +33,7 @@ import os
from time import sleep

from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port


@ -222,6 +223,11 @@ if __name__ == "__main__":

    from multiprocessing import Process

    if current_platform.is_rocm():
        from multiprocessing import set_start_method

        set_start_method("spawn", force=True)

    procs = []
    for local_dp_rank, global_dp_rank in enumerate(
        range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
@ -77,7 +77,7 @@ def parse_args():
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    return parser.parse_args()

@ -158,7 +158,7 @@ def parse_args():
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )

@ -158,7 +158,7 @@ def parse_args():
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )

@ -118,6 +118,32 @@ def run_bee(questions: list[str], modality: str) -> ModelRequestData:
    )


def run_bagel(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
    model_name = "ByteDance-Seed/BAGEL-7B-MoT"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={modality: 1},
    )

    prompts = [
        (
            f"<|im_start|>user\n<|image_pad|>\n{question}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )
        for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


# BLIP-2
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
@ -1832,6 +1858,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
model_example_map = {
    "aria": run_aria,
    "aya_vision": run_aya_vision,
    "bagel": run_bagel,
    "bee": run_bee,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
@ -2031,7 +2058,7 @@ def parse_args():
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )

@ -1382,7 +1382,7 @@ def run_generate(
    model,
    question: str,
    image_urls: list[str],
    seed: int | None,
    seed: int,
    tensor_parallel_size: int | None,
):
    req_data = model_example_map[model](question, image_urls)
@ -1416,7 +1416,7 @@ def run_chat(
    model: str,
    question: str,
    image_urls: list[str],
    seed: int | None,
    seed: int,
    tensor_parallel_size: int | None,
):
    req_data = model_example_map[model](question, image_urls)
@ -1494,7 +1494,7 @@ def parse_args():
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    parser.add_argument(
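The hunks above switch the --seed flag from an optional int defaulting to None to a plain int defaulting to 0, so the examples are always seeded. A minimal sketch of the assumed wiring (not part of this diff; the model name is a placeholder):

    from vllm import LLM

    args = parse_args()  # --seed now defaults to 0
    llm = LLM(model="facebook/opt-125m", seed=args.seed)  # engine is seeded deterministically
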
@ -21,7 +21,7 @@
# --worker \
# /abs/path/to/huggingface/cache \
# -e VLLM_HOST_IP=<worker_node_ip>
#
#
# Each worker requires a unique VLLM_HOST_IP value.
# Keep each terminal session open. Closing a session stops the associated Ray
# node and thereby shuts down the entire cluster.
@ -59,6 +59,34 @@ if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then
    exit 1
fi

# Extract VLLM_HOST_IP from ADDITIONAL_ARGS (e.g. "-e VLLM_HOST_IP=...").
VLLM_HOST_IP=""
for ((i = 0; i < ${#ADDITIONAL_ARGS[@]}; i++)); do
    arg="${ADDITIONAL_ARGS[$i]}"
    case "${arg}" in
    -e)
        next="${ADDITIONAL_ARGS[$((i + 1))]:-}"
        if [[ "${next}" == VLLM_HOST_IP=* ]]; then
            VLLM_HOST_IP="${next#VLLM_HOST_IP=}"
            break
        fi
        ;;
    -eVLLM_HOST_IP=* | VLLM_HOST_IP=*)
        VLLM_HOST_IP="${arg#*=}"
        break
        ;;
    esac
done

# For the head node, HEAD_NODE_ADDRESS and VLLM_HOST_IP should be consistent.
if [[ "${NODE_TYPE}" == "--head" && -n "${VLLM_HOST_IP}" ]]; then
    if [[ "${VLLM_HOST_IP}" != "${HEAD_NODE_ADDRESS}" ]]; then
        echo "Warning: VLLM_HOST_IP (${VLLM_HOST_IP}) differs from head_node_ip (${HEAD_NODE_ADDRESS})."
        echo "Using VLLM_HOST_IP as the head node address."
        HEAD_NODE_ADDRESS="${VLLM_HOST_IP}"
    fi
fi

# Generate a unique container name with random suffix.
# Docker container names must be unique on each host.
# The random suffix allows multiple Ray containers to run simultaneously on the same machine,
@ -74,36 +102,17 @@ cleanup() {
trap cleanup EXIT

# Build the Ray start command based on the node role.
# The head node manages the cluster and accepts connections on port 6379,
# The head node manages the cluster and accepts connections on port 6379,
# while workers connect to the head's address.
RAY_START_CMD="ray start --block"
if [ "${NODE_TYPE}" == "--head" ]; then
    RAY_START_CMD+=" --head --port=6379"
    RAY_START_CMD+=" --head --node-ip-address=${HEAD_NODE_ADDRESS} --port=6379"
else
    RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379"
fi

# Parse VLLM_HOST_IP from additional args if present.
# This is needed for multi-NIC configurations where Ray needs explicit IP bindings.
VLLM_HOST_IP=""
for arg in "${ADDITIONAL_ARGS[@]}"; do
    if [[ $arg == "-e" ]]; then
        continue
    if [ -n "${VLLM_HOST_IP}" ]; then
        RAY_START_CMD+=" --node-ip-address=${VLLM_HOST_IP}"
    fi
    if [[ $arg == VLLM_HOST_IP=* ]]; then
        VLLM_HOST_IP="${arg#VLLM_HOST_IP=}"
        break
    fi
done

# Build Ray IP environment variables if VLLM_HOST_IP is set.
# These variables ensure Ray binds to the correct network interface on multi-NIC systems.
RAY_IP_VARS=()
if [ -n "${VLLM_HOST_IP}" ]; then
    RAY_IP_VARS=(
        -e "RAY_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
        -e "RAY_OVERRIDE_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
    )
fi

# Launch the container with the assembled parameters.
@ -118,6 +127,5 @@ docker run \
    --shm-size 10.24g \
    --gpus all \
    -v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \
    "${RAY_IP_VARS[@]}" \
    "${ADDITIONAL_ARGS[@]}" \
    "${DOCKER_IMAGE}" -c "${RAY_START_CMD}"

@ -112,7 +112,7 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = {
        "messages": [
            {
                "role": "user",
                "content": "Generate an SQL query to show the 'username' and 'email'from the 'users' table.",
                "content": "Generate an SQL query to show the 'username' and 'email' from the 'users' table.",
            }
        ],
        "extra_body": {

@ -16,7 +16,7 @@ import requests
# - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
# --model-impl terratorch
# --task embed --trust-remote-code
# --trust-remote-code
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin terratorch_segmentation
# --enable-mm-embeds

@ -305,7 +305,7 @@ def get_query(modality: QueryModality):
    raise ValueError(msg)


def run_encode(model: str, modality: QueryModality, seed: int | None):
def run_encode(model: str, modality: QueryModality, seed: int):
    query = get_query(modality)
    req_data = model_example_map[model](query)

@ -335,7 +335,7 @@ def run_encode(model: str, modality: QueryModality, seed: int | None):
    print("-" * 50)


def run_score(model: str, modality: QueryModality, seed: int | None):
def run_score(model: str, modality: QueryModality, seed: int):
    query = get_query(modality)
    req_data = model_example_map[model](query)

@ -390,7 +390,7 @@ def parse_args():
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    return parser.parse_args()

@ -51,6 +51,7 @@ hooks:
  - docs/mkdocs/hooks/remove_announcement.py
  - docs/mkdocs/hooks/generate_examples.py
  - docs/mkdocs/hooks/generate_argparse.py
  - docs/mkdocs/hooks/generate_metrics.py
  - docs/mkdocs/hooks/url_schemes.py

plugins:

@ -50,4 +50,5 @@ ijson # Required for mistral streaming tool parser
setproctitle # Used to set process names for better debugging and monitoring
openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic == 0.71.0
model-hosting-container-standards >= 0.1.9, < 1.0.0
model-hosting-container-standards >= 0.1.10, < 1.0.0
mcp

@ -75,7 +75,7 @@ torchgeo==0.7.0
mteb==2.1.2

# Data processing
xgrammar==0.1.27
xgrammar @ git+https://github.com/divakar-amd/xgrammar@3272f7c520564858056a60480d5afdf69ae79c84
# Test async scheduling

# Utilities

@ -23,14 +23,6 @@ class TestParameterSweepItem:
            {"compilation_config.use_inductor_graph_partition": True},
            "--compilation-config.use_inductor_graph_partition=true",
        ),
        (
            {"compilation_config.use_inductor": False},
            "--compilation-config.use_inductor=false",
        ),
        (
            {"compilation_config.use_inductor": True},
            "--compilation-config.use_inductor=true",
        ),
    ],
)
def test_nested_boolean_params(self, input_dict, expected):
@ -20,13 +20,14 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ...utils import flat_product, multi_gpu_test
|
||||
|
||||
is_blackwell = lambda: current_platform.is_device_capability(100)
|
||||
is_blackwell = lambda: current_platform.is_device_capability_family(100)
|
||||
"""Are we running on Blackwell, a lot of tests depend on it"""
|
||||
|
||||
|
||||
class Matches(NamedTuple):
|
||||
attention_fusion: int = 0
|
||||
allreduce_fusion: int = 0
|
||||
rms_quant_norm_fusion: int = 0
|
||||
sequence_parallel: int = 0
|
||||
async_tp: int = 0
|
||||
|
||||
@ -40,6 +41,7 @@ class ModelBackendTestCase(NamedTuple):
|
||||
|
||||
MODELS_FP8: list[ModelBackendTestCase] = []
|
||||
MODELS_FP4: list[ModelBackendTestCase] = []
|
||||
MODELS_GROUP_FP8: list[ModelBackendTestCase] = []
|
||||
MODELS: list[ModelBackendTestCase] = [] # tp-only
|
||||
|
||||
if current_platform.is_cuda():
|
||||
@ -138,6 +140,17 @@ elif current_platform.is_rocm():
|
||||
CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
|
||||
|
||||
def has_cuda_graph_wrapper_metadata() -> bool:
|
||||
from importlib import import_module
|
||||
|
||||
try:
|
||||
module = import_module("torch._inductor.utils")
|
||||
module.CUDAGraphWrapperMetadata # noqa B018
|
||||
except AttributeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
|
||||
@ -145,7 +158,20 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
# quant_fp4 only has the custom impl
|
||||
+ list(flat_product(MODELS_FP4, [""])),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"inductor_graph_partition",
|
||||
[
|
||||
pytest.param(
|
||||
True,
|
||||
marks=pytest.mark.skipif(
|
||||
not has_cuda_graph_wrapper_metadata(),
|
||||
reason="This test requires"
|
||||
"torch._inductor.utils.CUDAGraphWrapperMetadata to run",
|
||||
),
|
||||
),
|
||||
False,
|
||||
],
|
||||
)
|
||||
def test_attn_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
@ -474,3 +500,81 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
|
||||
compilation_config.compile_ranges_split_points = (
|
||||
llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
|
||||
)
|
||||
|
||||
|
||||
if current_platform.is_cuda():
|
||||
MODELS_GROUP_FP8 = [
|
||||
ModelBackendTestCase(
|
||||
model_name="Qwen/Qwen3-30B-A3B-FP8",
|
||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
matches=Matches(
|
||||
rms_quant_norm_fusion=48,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||
# Test rms norm+group quant_fp8 fusion
|
||||
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
# TODO: remove skip after we fix the fusion thoroughly
|
||||
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
|
||||
def test_rms_group_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
backend: AttentionBackendEnum,
|
||||
matches: Matches,
|
||||
custom_ops: str,
|
||||
inductor_graph_partition: bool,
|
||||
caplog_mp_spawn,
|
||||
monkeypatch,
|
||||
):
|
||||
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("Inductor graph partition requires torch>=2.9")
|
||||
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
|
||||
if inductor_graph_partition:
|
||||
mode = CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
splitting_ops: list[str] | None = None
|
||||
else:
|
||||
mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||
splitting_ops = []
|
||||
|
||||
# Disable compile cache to make sure custom passes run.
|
||||
# Otherwise, we can't verify fusion happened through the logs.
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
custom_ops=custom_ops_list,
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
cudagraph_mode=mode,
|
||||
splitting_ops=splitting_ops,
|
||||
# Common
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(eliminate_noops=True, fuse_norm_quant=True),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
)
|
||||
|
||||
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||
run_model(compilation_config, model_name, **model_kwargs)
|
||||
|
||||
log_matches = re.findall(
|
||||
r"\[fusion.py:\d+] Replaced (\d+) patterns",
|
||||
log_holder.text,
|
||||
)
|
||||
assert len(log_matches) == 1, log_holder.text
|
||||
assert int(log_matches[0]) == matches.rms_quant_norm_fusion
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -13,7 +12,6 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.config.compilation import CompilationMode, PassConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.logger import _print_warning_once
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
||||
|
||||
@ -290,7 +288,7 @@ def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
pass_config={"fuse_attn_quant": True, "eliminate_noops": True},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
),
|
||||
@ -442,62 +440,3 @@ def test_cudagraph_sizes_post_init(
|
||||
vllm_config.compilation_config.max_cudagraph_capture_size
|
||||
== expected_max_size
|
||||
)
|
||||
|
||||
|
||||
def test_pass_config_deprecation(caplog_vllm):
|
||||
caplog_vllm.set_level(logging.WARNING)
|
||||
|
||||
# Clear cache to ensure warnings are re-issued
|
||||
_print_warning_once.cache_clear()
|
||||
|
||||
# Test enable_fusion -> fuse_norm_quant, fuse_act_quant
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_fusion=True)
|
||||
assert "enable_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_norm_quant is True
|
||||
assert config.fuse_act_quant is True
|
||||
assert config.enable_fusion is True
|
||||
|
||||
# Test enable_attn_fusion -> fuse_attn_quant
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_attn_fusion=True)
|
||||
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_attn_quant is True
|
||||
assert config.enable_attn_fusion is True
|
||||
|
||||
# Test enable_noop -> eliminate_noops
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_noop=True)
|
||||
assert "enable_noop is deprecated" in caplog_vllm.text
|
||||
assert config.eliminate_noops is True
|
||||
assert config.enable_noop is True
|
||||
|
||||
# Test enable_sequence_parallelism -> enable_sp
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_sequence_parallelism=True)
|
||||
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
|
||||
assert config.enable_sp is True
|
||||
assert config.enable_sequence_parallelism is True
|
||||
|
||||
# Test enable_async_tp -> fuse_gemm_comms
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_async_tp=True)
|
||||
assert "enable_async_tp is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_gemm_comms is True
|
||||
assert config.enable_async_tp is True
|
||||
|
||||
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_fi_allreduce_fusion=True)
|
||||
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_allreduce_rms is True
|
||||
assert config.enable_fi_allreduce_fusion is True
|
||||
|
||||
# Test hash consistency
|
||||
config_old = PassConfig(enable_fusion=True)
|
||||
config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True)
|
||||
assert config_old.compute_hash() == config_new.compute_hash()
|
||||
|
||||
config_old = PassConfig(enable_async_tp=True)
|
||||
config_new = PassConfig(fuse_gemm_comms=True)
|
||||
assert config_old.compute_hash() == config_new.compute_hash()
|
||||
|
||||
@ -36,7 +36,7 @@ def get_test_models():
|
||||
DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("use_aot_compile", ["0"])
|
||||
@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
|
||||
@pytest.mark.parametrize("evaluate_guards", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
@ -54,6 +54,12 @@ def test_dynamic_shapes_compilation(
|
||||
if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
|
||||
pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
|
||||
|
||||
if evaluate_guards and shapes_type == DynamicShapesType.UNBACKED:
|
||||
pytest.skip("unbacked dynamic shapes do not add guards")
|
||||
|
||||
if evaluate_guards and use_aot_compile:
|
||||
pytest.skip("evaluate_guards requires use_aot_compile=0")
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
|
||||
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
|
||||
|
||||
@ -120,7 +126,7 @@ def test_model_specialization_with_evaluate_guards(
|
||||
and dynamic_shapes_type == DynamicShapesType.BACKED
|
||||
and evaluate_guards
|
||||
):
|
||||
pytest.skip("evaluate_guards for backed does not work with aot_compile =1")
|
||||
pytest.skip("evaluate_guards for backed does not work with aot_compile=1")
|
||||
|
||||
@support_torch_compile
|
||||
class ModelWithSizeCheck(torch.nn.Module):
|
||||
|
||||
@ -128,14 +128,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class TestRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
|
||||
def __init__(self, head_dim=64, max_position=2048, base=10000):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.rotary_dim = rotary_dim or head_dim
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters={"rope_type": "default", "rope_theta": base},
|
||||
)
|
||||
@ -170,7 +168,6 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters={"rope_type": "default", "rope_theta": base},
|
||||
)
|
||||
|
||||
@ -202,6 +202,27 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workspace_init():
|
||||
"""Initialize the workspace manager for tests that need it.
|
||||
|
||||
This fixture initializes the workspace manager with a CUDA device
|
||||
if available, and resets it after the test completes. Tests that
|
||||
create a full vLLM engine should NOT use this fixture as the engine
|
||||
will initialize the workspace manager itself.
|
||||
"""
|
||||
from vllm.v1.worker.workspace import (
|
||||
init_workspace_manager,
|
||||
reset_workspace_manager,
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
init_workspace_manager(device)
|
||||
yield
|
||||
reset_workspace_manager()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def dynamo_reset():
|
||||
yield
|
||||
@ -681,10 +702,16 @@ class HfRunner:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Encoder-decoder models return decoder_hidden_states instead of
|
||||
# hidden_states
|
||||
hidden_states = (
|
||||
getattr(output, "hidden_states", None) or output.decoder_hidden_states
|
||||
)
|
||||
|
||||
(
|
||||
seq_logprobs_lst,
|
||||
output_len,
|
||||
) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs)
|
||||
) = self._hidden_states_to_logprobs(hidden_states, num_logprobs)
|
||||
|
||||
all_logprobs.append(seq_logprobs_lst)
|
||||
seq_ids = output.sequences[0]
|
||||
@ -741,7 +768,7 @@ class VllmRunner:
|
||||
tokenizer_name: str | None = None,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = True,
|
||||
seed: int | None = 0,
|
||||
seed: int = 0,
|
||||
max_model_len: int | None = 1024,
|
||||
dtype: str = "auto",
|
||||
disable_log_stats: bool = True,
|
||||
|
||||
276 tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py (new file)
@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_dp_group,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptNvFp4Config,
|
||||
ModelOptNvFp4FusedMoE,
|
||||
)
|
||||
|
||||
from .eplb_utils import distributed_run, set_env_vars_and_device
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig:
|
||||
num_layers: int
|
||||
num_experts: int
|
||||
num_local_experts: int
|
||||
num_topk: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
num_tokens: int
|
||||
|
||||
|
||||
def make_fused_moe_layer(
|
||||
rank: int,
|
||||
layer_idx: int,
|
||||
test_config: TestConfig,
|
||||
) -> FusedMoE:
|
||||
quant_config = None
|
||||
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
quant_config = ModelOptNvFp4Config(
|
||||
is_checkpoint_nvfp4_serialized=True,
|
||||
kv_cache_quant_algo=None,
|
||||
exclude_modules=[],
|
||||
)
|
||||
|
||||
fml = FusedMoE(
|
||||
num_experts=test_config.num_experts,
|
||||
top_k=test_config.num_topk,
|
||||
hidden_size=test_config.hidden_size,
|
||||
intermediate_size=test_config.intermediate_size,
|
||||
prefix=f"dummy_layer_{layer_idx}",
|
||||
activation="silu",
|
||||
is_act_and_mul=True,
|
||||
params_dtype=torch.bfloat16,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
|
||||
nvfp4_fused_moe.create_weights(
|
||||
fml,
|
||||
test_config.num_local_experts,
|
||||
test_config.hidden_size,
|
||||
test_config.intermediate_size,
|
||||
params_dtype=torch.uint8,
|
||||
global_num_experts=test_config.num_experts,
|
||||
)
|
||||
|
||||
fml = fml.to(device)
|
||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||
test_config.num_local_experts,
|
||||
test_config.intermediate_size,
|
||||
test_config.hidden_size,
|
||||
in_dtype=torch.bfloat16,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
fml.w13_weight.data = w1_q
|
||||
fml.w2_weight.data = w2_q
|
||||
|
||||
fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
|
||||
fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
|
||||
fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
|
||||
fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
|
||||
fml.w2_weight_scale.data = (
|
||||
torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
|
||||
).to(fml.w2_weight_scale.data.dtype)
|
||||
fml.w13_weight_scale.data = (
|
||||
torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
|
||||
).to(fml.w13_weight_scale.data.dtype)
|
||||
|
||||
nvfp4_fused_moe.process_weights_after_loading(fml)
|
||||
|
||||
fml.maybe_init_modular_kernel()
|
||||
|
||||
return fml
|
||||
|
||||
|
||||
def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
|
||||
set_env_vars_and_device(env)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.data_parallel_size = world_size
|
||||
vllm_config.parallel_config.enable_expert_parallel = True
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=1, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_dp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
fml_layers = [
|
||||
make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
|
||||
for layer_idx in range(test_config.num_layers)
|
||||
]
|
||||
rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]
|
||||
|
||||
hidden_states = []
|
||||
router_logits = []
|
||||
for layer_idx in range(test_config.num_layers):
|
||||
hidden_states.append(
|
||||
torch.randn(
|
||||
(test_config.num_tokens, test_config.hidden_size),
|
||||
dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
router_logits.append(
|
||||
torch.randn(
|
||||
(test_config.num_tokens, test_config.num_experts),
|
||||
dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
out_before_shuffle = []
|
||||
with set_forward_context(
|
||||
{},
|
||||
num_tokens=test_config.num_tokens,
|
||||
num_tokens_across_dp=torch.tensor(
|
||||
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
|
||||
),
|
||||
vllm_config=vllm_config,
|
||||
):
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
out_before_shuffle.append(
|
||||
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
|
||||
)
|
||||
|
||||
indices = torch.zeros(
|
||||
test_config.num_layers, test_config.num_experts, dtype=torch.long
|
||||
)
|
||||
for lidx in range(test_config.num_layers):
|
||||
indices[lidx] = torch.Tensor(range(test_config.num_experts))
|
||||
|
||||
shuffled_indices = torch.zeros_like(indices)
|
||||
for lidx in range(test_config.num_layers):
|
||||
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
|
||||
|
||||
rearrange_expert_weights_inplace(
|
||||
indices,
|
||||
shuffled_indices,
|
||||
rank_expert_weights,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
)
|
||||
|
||||
num_global_experts = test_config.num_experts
|
||||
|
||||
logical_to_physical_map_list = []
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
physical_to_logical_map = shuffled_indices[lidx].to(device)
|
||||
logical_to_physical_map = torch.empty(
|
||||
(num_global_experts,), dtype=torch.int32, device=device
|
||||
)
|
||||
logical_to_physical_map[physical_to_logical_map] = torch.arange(
|
||||
0, num_global_experts, dtype=torch.int32, device=device
|
||||
)
|
||||
logical_to_physical_map_list.append(
|
||||
logical_to_physical_map.reshape(num_global_experts, 1)
|
||||
)
|
||||
|
||||
logical_to_physical_map = torch.stack(logical_to_physical_map_list)
|
||||
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
logical_replica_count = torch.ones(
|
||||
(test_config.num_layers, num_global_experts),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
fml.enable_eplb = True
|
||||
fml.set_eplb_state(
|
||||
lidx,
|
||||
torch.zeros(
|
||||
(test_config.num_layers, num_global_experts),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
),
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
|
||||
out_after_shuffle = []
|
||||
with set_forward_context(
|
||||
{},
|
||||
num_tokens=test_config.num_tokens,
|
||||
num_tokens_across_dp=torch.tensor(
|
||||
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
|
||||
),
|
||||
vllm_config=vllm_config,
|
||||
):
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
out_after_shuffle.append(
|
||||
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
|
||||
)
|
||||
|
||||
for lidx in range(test_config.num_layers):
|
||||
torch.testing.assert_close(
|
||||
out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
@pytest.mark.parametrize("num_layers", [8])
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("hidden_size", [256])
|
||||
@pytest.mark.parametrize("intermediate_size", [256])
|
||||
@pytest.mark.parametrize("num_tokens", [256])
|
||||
@pytest.mark.parametrize("backend", ["latency", "throughput"])
|
||||
def test_eplb_fml(
|
||||
world_size: int,
|
||||
num_layers: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_tokens: int,
|
||||
backend: str,
|
||||
monkeypatch,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
|
||||
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)
|
||||
|
||||
if torch.cuda.device_count() < world_size:
|
||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||
|
||||
num_local_experts = num_experts // world_size
|
||||
num_topk = 4
|
||||
|
||||
test_config = TestConfig(
|
||||
num_layers=num_layers,
|
||||
num_experts=num_experts,
|
||||
num_local_experts=num_local_experts,
|
||||
num_topk=num_topk,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_tokens=num_tokens,
|
||||
)
|
||||
|
||||
distributed_run(
|
||||
_test_eplb_fml,
|
||||
world_size,
|
||||
test_config,
|
||||
)
|
||||
@ -350,21 +350,35 @@ def test_human_readable_model_len():
    assert args.max_model_len == 1_000_000
    args = parser.parse_args(["--max-model-len", "10k"])
    assert args.max_model_len == 10_000
    args = parser.parse_args(["--max-model-len", "2g"])
    assert args.max_model_len == 2_000_000_000
    args = parser.parse_args(["--max-model-len", "2t"])
    assert args.max_model_len == 2_000_000_000_000

    # Capital
    args = parser.parse_args(["--max-model-len", "3K"])
    assert args.max_model_len == 1024 * 3
    assert args.max_model_len == 2**10 * 3
    args = parser.parse_args(["--max-model-len", "10M"])
    assert args.max_model_len == 2**20 * 10
    args = parser.parse_args(["--max-model-len", "4G"])
    assert args.max_model_len == 2**30 * 4
    args = parser.parse_args(["--max-model-len", "4T"])
    assert args.max_model_len == 2**40 * 4

    # Decimal values
    args = parser.parse_args(["--max-model-len", "10.2k"])
    assert args.max_model_len == 10200
    # ..truncated to the nearest int
    args = parser.parse_args(["--max-model-len", "10.212345k"])
    args = parser.parse_args(["--max-model-len", "10.2123451234567k"])
    assert args.max_model_len == 10212
    args = parser.parse_args(["--max-model-len", "10.2123451234567m"])
    assert args.max_model_len == 10212345
    args = parser.parse_args(["--max-model-len", "10.2123451234567g"])
    assert args.max_model_len == 10212345123
    args = parser.parse_args(["--max-model-len", "10.2123451234567t"])
    assert args.max_model_len == 10212345123456

    # Invalid (do not allow decimals with binary multipliers)
    for invalid in ["1a", "pwd", "10.24", "1.23M"]:
    for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
        with pytest.raises(ArgumentError):
            args = parser.parse_args(["--max-model-len", invalid])
            parser.parse_args(["--max-model-len", invalid])
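A minimal sketch of the suffix convention these assertions encode (illustrative only; the real parser lives in vLLM's argument-parsing utilities and may differ in details):

    def parse_human_readable_int(value: str) -> int:
        decimal = {"k": 10**3, "m": 10**6, "g": 10**9, "t": 10**12}
        binary = {"K": 2**10, "M": 2**20, "G": 2**30, "T": 2**40}
        suffix = value[-1]
        if suffix in decimal:
            # lowercase suffixes are decimal and accept fractions, truncated to int
            return int(float(value[:-1]) * decimal[suffix])
        if suffix in binary:
            # uppercase suffixes are binary and reject fractional values such as "1.23M"
            if "." in value:
                raise ValueError(f"decimals not allowed with binary multipliers: {value}")
            return int(value[:-1]) * binary[suffix]
        return int(value)  # plain integers; also rejects "1a", "pwd", "10.24"
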
@ -1,21 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem
|
||||
from openai.types.responses.response_output_item import McpCall
|
||||
from openai_harmony import Author, Message, Role, TextContent
|
||||
|
||||
from tests.entrypoints.openai.utils import verify_harmony_messages
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
auto_drop_analysis_messages,
|
||||
get_encoding,
|
||||
has_custom_tools,
|
||||
parse_chat_input_to_harmony_message,
|
||||
parse_chat_output,
|
||||
parse_input_to_harmony_message,
|
||||
parse_output_message,
|
||||
)
|
||||
|
||||
|
||||
class TestParseInputToHarmonyMessage:
|
||||
"""Tests for parse_input_to_harmony_message function."""
|
||||
class TestCommonParseInputToHarmonyMessage:
|
||||
"""
|
||||
Tests for scenarios that are common to both Chat Completion
|
||||
parse_chat_input_to_harmony_message and Responses API
|
||||
parse_input_to_harmony_message functions.
|
||||
"""
|
||||
|
||||
def test_assistant_message_with_tool_calls(self):
|
||||
@pytest.fixture(
|
||||
params=[parse_chat_input_to_harmony_message, parse_input_to_harmony_message]
|
||||
)
|
||||
def parse_function(self, request):
|
||||
return request.param
|
||||
|
||||
def test_assistant_message_with_tool_calls(self, parse_function):
|
||||
"""Test parsing assistant message with tool calls."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
@ -35,7 +51,7 @@ class TestParseInputToHarmonyMessage:
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 2
|
||||
|
||||
@ -53,7 +69,7 @@ class TestParseInputToHarmonyMessage:
|
||||
assert messages[1].recipient == "functions.search_web"
|
||||
assert messages[1].content_type == "json"
|
||||
|
||||
def test_assistant_message_with_empty_tool_call_arguments(self):
|
||||
def test_assistant_message_with_empty_tool_call_arguments(self, parse_function):
|
||||
"""Test parsing assistant message with tool call having None arguments."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
@ -67,12 +83,152 @@ class TestParseInputToHarmonyMessage:
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content[0].text == ""
|
||||
assert messages[0].recipient == "functions.get_current_time"
|
||||
|
||||
def test_system_message(self, parse_function):
|
||||
"""Test parsing system message."""
|
||||
chat_msg = {
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
# System messages are converted using Message.from_dict
|
||||
# which should preserve the role
|
||||
assert messages[0].author.role == Role.SYSTEM
|
||||
|
||||
def test_developer_message(self, parse_function):
|
||||
"""Test parsing developer message."""
|
||||
chat_msg = {
|
||||
"role": "developer",
|
||||
"content": "Use concise language",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.DEVELOPER
|
||||
|
||||
def test_user_message_with_string_content(self, parse_function):
|
||||
"""Test parsing user message with string content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "What's the weather in San Francisco?",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "What's the weather in San Francisco?"
|
||||
|
||||
def test_user_message_with_array_content(self, parse_function):
|
||||
"""Test parsing user message with array content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "What's in this image? "},
|
||||
{"text": "Please describe it."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert len(messages[0].content) == 2
|
||||
assert messages[0].content[0].text == "What's in this image? "
|
||||
assert messages[0].content[1].text == "Please describe it."
|
||||
|
||||
def test_assistant_message_with_string_content(self, parse_function):
|
||||
"""Test parsing assistant message with string content (no tool calls)."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you today?",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.ASSISTANT
|
||||
assert messages[0].content[0].text == "Hello! How can I help you today?"
|
||||
|
||||
def test_pydantic_model_input(self, parse_function):
|
||||
"""Test parsing Pydantic model input (has model_dump method)."""
|
||||
|
||||
class MockPydanticModel:
|
||||
def model_dump(self, exclude_none=True):
|
||||
return {
|
||||
"role": "user",
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
chat_msg = MockPydanticModel()
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "Test message"
|
||||
|
||||
def test_tool_call_with_missing_function_fields(self, parse_function):
|
||||
"""Test parsing tool call with missing name or arguments."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {} # Missing both name and arguments
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].recipient == "functions."
|
||||
assert messages[0].content[0].text == ""
|
||||
|
||||
def test_array_content_with_missing_text(self, parse_function):
|
||||
"""Test parsing array content where text field is missing."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{}, # Missing text field
|
||||
{"text": "actual text"},
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert len(messages[0].content) == 2
|
||||
assert messages[0].content[0].text == ""
|
||||
assert messages[0].content[1].text == "actual text"
|
||||
|
||||
|
||||
class TestParseInputToHarmonyMessage:
|
||||
"""
|
||||
Tests for scenarios that are specific to the Responses API
|
||||
parse_input_to_harmony_message function.
|
||||
"""
|
||||
|
||||
def test_message_with_empty_content(self):
|
||||
"""Test parsing message with empty string content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content[0].text == ""
|
||||
|
||||
def test_tool_message_with_string_content(self):
|
||||
"""Test parsing tool message with string content."""
|
||||
chat_msg = {
|
||||
@ -111,6 +267,7 @@ class TestParseInputToHarmonyMessage:
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.TOOL
|
||||
assert messages[0].author.name == "functions.search_results"
|
||||
assert messages[0].content[0].text == "Result 1: Result 2: Result 3"
|
||||
|
||||
def test_tool_message_with_empty_content(self):
|
||||
@ -124,140 +281,564 @@ class TestParseInputToHarmonyMessage:
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.TOOL
|
||||
assert messages[0].author.name == "functions.empty_tool"
|
||||
assert messages[0].content[0].text == ""
|
||||
|
||||
def test_system_message(self):
|
||||
"""Test parsing system message."""
|
||||
chat_msg = {
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
class TestParseChatInputToHarmonyMessage:
|
||||
"""
|
||||
Tests for scenarios that are specific to the Chat Completion API
|
||||
parse_chat_input_to_harmony_message function.
|
||||
"""
|
||||
|
||||
assert len(messages) == 1
|
||||
# System messages are converted using Message.from_dict
|
||||
# which should preserve the role
|
||||
assert messages[0].author.role == Role.SYSTEM
|
||||
|
||||
def test_developer_message(self):
|
||||
"""Test parsing developer message."""
|
||||
chat_msg = {
|
||||
"role": "developer",
|
||||
"content": "Use concise language",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.DEVELOPER
|
||||
|
||||
def test_user_message_with_string_content(self):
|
||||
"""Test parsing user message with string content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "What's the weather in San Francisco?",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "What's the weather in San Francisco?"
|
||||
|
||||
def test_user_message_with_array_content(self):
|
||||
"""Test parsing user message with array content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "What's in this image? "},
|
||||
{"text": "Please describe it."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert len(messages[0].content) == 2
|
||||
assert messages[0].content[0].text == "What's in this image? "
|
||||
assert messages[0].content[1].text == "Please describe it."
|
||||
|
||||
def test_assistant_message_with_string_content(self):
|
||||
"""Test parsing assistant message with string content (no tool calls)."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you today?",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.ASSISTANT
|
||||
assert messages[0].content[0].text == "Hello! How can I help you today?"
|
||||
|
||||
def test_pydantic_model_input(self):
|
||||
"""Test parsing Pydantic model input (has model_dump method)."""
|
||||
|
||||
class MockPydanticModel:
|
||||
def model_dump(self, exclude_none=True):
|
||||
return {
|
||||
"role": "user",
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
chat_msg = MockPydanticModel()
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "Test message"
|
||||
|
||||
def test_message_with_empty_content(self):
|
||||
"""Test parsing message with empty string content."""
|
||||
def test_user_message_with_empty_content(self):
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content[0].text == ""
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_tool_call_with_missing_function_fields(self):
|
||||
"""Test parsing tool call with missing name or arguments."""
|
||||
def test_user_message_with_none_content(self):
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_empty_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 0
|
||||
|
||||
def test_assistant_message_with_none_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 0
|
||||
|
||||
def test_assistant_message_with_content_but_empty_reasoning(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "The answer is 4.",
|
||||
"reasoning": "",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "final",
|
||||
"content": "The answer is 4.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_reasoning_but_empty_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"reasoning": "I'm thinking about the user's question.",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I'm thinking about the user's question.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_reasoning_but_none_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"reasoning": "I'm thinking about the user's question.",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I'm thinking about the user's question.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_tool_calls_but_no_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {} # Missing both name and arguments
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].recipient == "functions."
|
||||
assert messages[0].content[0].text == ""
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_array_content_with_missing_text(self):
|
||||
"""Test parsing array content where text field is missing."""
|
||||
def test_assistant_message_with_tool_calls_and_content(self):
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
"content": "I'll call the tool.",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"content": "I'll call the tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_tool_calls_and_reasoning(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
"reasoning": "I should use the get_weather tool.",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I should use the get_weather tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_tool_calls_and_reasoning_and_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
"reasoning": "I should use the get_weather tool.",
|
||||
"content": "I'll call the tool.",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"content": "I'll call the tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I should use the get_weather tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
    def test_tool_message_with_string_content(self):
        tool_id_names = {
            "call_123": "get_weather",
        }
        chat_msg = {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": "The weather in San Francisco is sunny, 72°F",
        }

        messages = parse_chat_input_to_harmony_message(
            chat_msg, tool_id_names=tool_id_names
        )

        verify_harmony_messages(
            messages,
            [
                {
                    "role": "tool",
                    "name": "functions.get_weather",
                    "content": "The weather in San Francisco is sunny, 72°F",
                    "channel": "commentary",
                },
            ],
        )

    def test_tool_message_with_array_content(self):
        tool_id_names = {
            "call_123": "search_results",
        }
        chat_msg = {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": [
                {},  # Missing text field
                {"type": "text", "text": "Result 1: "},
                {"type": "text", "text": "Result 2: "},
                {
                    "type": "image",
                    "url": "http://example.com/img.png",
                },  # Should be ignored
                {"type": "text", "text": "Result 3"},
            ],
        }

        messages = parse_chat_input_to_harmony_message(
            chat_msg, tool_id_names=tool_id_names
        )

        assert len(messages) == 1
        verify_harmony_messages(
            messages,
            [
                {
                    "role": "tool",
                    "name": "functions.search_results",
                    "content": "Result 1: Result 2: Result 3",
                    "channel": "commentary",
                },
            ],
        )

    def test_tool_message_with_empty_content(self):
        tool_id_names = {
            "call_123": "empty_tool",
        }
        chat_msg = {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": "",
        }

        messages = parse_chat_input_to_harmony_message(
            chat_msg, tool_id_names=tool_id_names
        )

        verify_harmony_messages(
            messages,
            [
                {
                    "role": "tool",
                    "name": "functions.empty_tool",
                    "content": "",
                    "channel": "commentary",
                },
            ],
        )

    def test_tool_message_with_none_content(self):
        tool_id_names = {
            "call_123": "empty_tool",
        }
        chat_msg = {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": None,
        }

        messages = parse_chat_input_to_harmony_message(
            chat_msg, tool_id_names=tool_id_names
        )

        verify_harmony_messages(
            messages,
            [
                {
                    "role": "tool",
                    "name": "functions.empty_tool",
                    "content": "",
                    "channel": "commentary",
                },
            ],
        )


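# Per the tests below, auto_drop_analysis_messages drops analysis-channel messages that
# precede a later final-channel message, keeps analysis messages that come after the last
# final message, and leaves message lists with no final message untouched.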
class TestAutoDropAnalysisMessages:
    def test_no_analysis_messages(self) -> None:
        messages = [
            Message.from_role_and_content(
                Role.ASSISTANT, "The answer is 4."
            ).with_channel("final"),
        ]
        cleaned_messages = auto_drop_analysis_messages(messages)
        assert cleaned_messages == messages

    def test_only_analysis_message(self) -> None:
        messages = [
            Message.from_role_and_content(
                Role.ASSISTANT, "I'm thinking about the user's question."
            ).with_channel("analysis"),
        ]
        cleaned_messages = auto_drop_analysis_messages(messages)
        assert cleaned_messages == messages

    def test_multiple_analysis_messages_without_final_message(self) -> None:
        messages = [
            Message.from_role_and_content(
                Role.ASSISTANT, "I'm thinking about the user's question."
            ).with_channel("analysis"),
            Message.from_role_and_content(
                Role.ASSISTANT, "I'm thinking more."
            ).with_channel("analysis"),
            Message.from_role_and_content(
                Role.ASSISTANT, "I'm thinking even more."
            ).with_channel("analysis"),
        ]
        cleaned_messages = auto_drop_analysis_messages(messages)
        assert cleaned_messages == messages

    def test_only_final_message(self) -> None:
        messages = [
            Message.from_role_and_content(
                Role.ASSISTANT, "The answer is 4."
            ).with_channel("final"),
        ]
        cleaned_messages = auto_drop_analysis_messages(messages)
        assert cleaned_messages == messages

    def test_drops_one_analysis_messages_before_final_message(self) -> None:
        messages = [
            Message.from_role_and_content(
                Role.ASSISTANT, "I'm thinking about the user's question."
            ).with_channel("analysis"),
            Message.from_role_and_content(
                Role.ASSISTANT, "The answer is 4."
            ).with_channel("final"),
            Message.from_role_and_content(
                Role.ASSISTANT, "I should think harder."
            ).with_channel("analysis"),
        ]
        cleaned_messages = auto_drop_analysis_messages(messages)
        # Should have dropped the first analysis message
        assert cleaned_messages == messages[1:]

    def test_drops_all_analysis_messages_before_final_message(self) -> None:
        messages = [
            Message.from_role_and_content(
                Role.ASSISTANT, "I'm thinking about the user's question."
            ).with_channel("analysis"),
            Message.from_role_and_content(
                Role.ASSISTANT, "I'm thinking more."
            ).with_channel("analysis"),
            Message.from_role_and_content(
                Role.ASSISTANT, "I'm thinking even more."
            ).with_channel("analysis"),
            Message.from_role_and_content(
                Role.ASSISTANT, "The answer is 4."
            ).with_channel("final"),
            Message.from_role_and_content(
                Role.ASSISTANT, "I should think harder."
            ).with_channel("analysis"),
        ]
        cleaned_messages = auto_drop_analysis_messages(messages)
        # Should have dropped the first 3 analysis messages
        assert cleaned_messages == messages[3:]

    def test_multiple_analysis_messages_with_multiple_final_messages(self) -> None:
        messages = [
            Message.from_role_and_content(
                Role.ASSISTANT, "I'm thinking about the user's question."
            ).with_channel("analysis"),
            Message.from_role_and_content(
                Role.ASSISTANT, "I'm thinking more."
            ).with_channel("analysis"),
            Message.from_role_and_content(
                Role.ASSISTANT, "I'm thinking even more."
            ).with_channel("analysis"),
            Message.from_role_and_content(
                Role.ASSISTANT, "The answer is 4."
            ).with_channel("final"),
            Message.from_role_and_content(
                Role.ASSISTANT, "I should think harder."
            ).with_channel("analysis"),
            Message.from_role_and_content(
                Role.ASSISTANT, "The answer is 5."
            ).with_channel("final"),
        ]
        cleaned_messages = auto_drop_analysis_messages(messages)
        # Should have dropped all analysis messages, keeping only the two final messages
        assert len(cleaned_messages) == 2
        assert cleaned_messages[0].content[0].text == "The answer is 4."
        assert cleaned_messages[1].content[0].text == "The answer is 5."

    def test_drops_non_assistant_analysis_messages(self) -> None:
        messages = [
            Message.from_role_and_content(
                Role.TOOL, "The tool thinks we should think harder."
            ).with_channel("analysis"),
            Message.from_role_and_content(
                Role.ASSISTANT, "The answer is 4."
            ).with_channel("final"),
        ]
        cleaned_messages = auto_drop_analysis_messages(messages)
        # Should have dropped the analysis message
        assert cleaned_messages == messages[1:]


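# Per the tests below, parse_chat_output decodes Harmony-formatted token ids into
# (reasoning, final_content, ...): the analysis channel is surfaced as reasoning, the
# final and commentary channels as content, and messages cut off before <|end|> are
# still recovered.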
class TestParseChatOutput:
    def test_parse_chat_output_interrupted_first_message(self) -> None:
        harmony_str = "<|channel|>final<|message|>I'm in the middle of answering"
        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
        reasoning, final_content, _ = parse_chat_output(token_ids)
        assert reasoning is None
        assert final_content == "I'm in the middle of answering"

    def test_parse_chat_output_interrupted_reasoning_first_message(self) -> None:
        harmony_str = "<|channel|>analysis<|message|>I'm in the middle of thinking"
        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
        reasoning, final_content, _ = parse_chat_output(token_ids)
        assert reasoning == "I'm in the middle of thinking"
        assert final_content is None

    def test_parse_chat_output_complete_reasoning_interrupted_content(self) -> None:
        harmony_str = (
            "<|channel|>analysis<|message|>I'm thinking.<|end|>"
            "<|start|>assistant<|channel|>final"
            "<|message|>I'm in the middle of answering"
        )
        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
        reasoning, final_content, _ = parse_chat_output(token_ids)
        assert reasoning == "I'm thinking."
        assert final_content == "I'm in the middle of answering"

    def test_parse_chat_output_complete_content(self) -> None:
        harmony_str = "<|channel|>final<|message|>The answer is 4.<|end|>"
        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
        reasoning, final_content, _ = parse_chat_output(token_ids)
        assert reasoning is None
        assert final_content == "The answer is 4."

    def test_parse_chat_output_complete_commentary(self) -> None:
        harmony_str = (
            "<|channel|>commentary<|message|>I need to call some tools.<|end|>"
        )
        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
        reasoning, final_content, _ = parse_chat_output(token_ids)
        assert reasoning is None
        assert final_content == "I need to call some tools."

    def test_parse_chat_output_complete_reasoning(self) -> None:
        harmony_str = (
            "<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
        )
        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
        reasoning, final_content, _ = parse_chat_output(token_ids)
        assert reasoning == "I've thought hard about this."
        assert final_content is None

    def test_parse_chat_output_complete_reasoning_and_content(self) -> None:
        harmony_str = (
            "<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
            "<|start|>assistant<|channel|>final<|message|>The answer is 4.<|end|>"
        )
        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
        reasoning, final_content, _ = parse_chat_output(token_ids)
        assert reasoning == "I've thought hard about this."
        assert final_content == "The answer is 4."


class TestParseOutputMessage:

227 tests/entrypoints/openai/test_chat_error.py Normal file
@@ -0,0 +1,227 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM

MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]


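# Lightweight stand-ins for the HF and model configs, exposing just the attributes
# these tests touch.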
@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


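# Builds an OpenAIServingChat wired to the mocked engine, with chat preprocessing and
# input processing replaced by AsyncMocks so the tests exercise only the
# generate/response path.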
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    serving_chat = OpenAIServingChat(
        engine,
        models,
        response_role="assistant",
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        return dict(engine_prompt), {}

    async def _fake_preprocess_chat(*args, **kwargs):
        # return conversation, engine_prompts
        return (
            [{"role": "user", "content": "Test"}],
            [{"prompt_token_ids": [1, 2, 3]}],
        )

    serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
    return serving_chat


@pytest.mark.asyncio
async def test_chat_error_non_stream():
    """test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    completion_output = CompletionOutput(
        index=0,
        text="",
        token_ids=[],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=False,
    )

    response = await serving_chat.create_chat_completion(request)

    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InternalServerError"
    assert response.error.message == "Internal server error"
    assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR


@pytest.mark.asyncio
async def test_chat_error_stream():
    """test finish_reason='error' returns 500 InternalServerError (streaming)"""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    completion_output_1 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason=None,
    )

    request_output_1 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_1],
        finished=False,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    completion_output_2 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output_2 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_2],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output_1
        yield request_output_2

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=True,
    )

    response = await serving_chat.create_chat_completion(request)

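    # Drain the SSE stream: the error should surface as a chunk inside the stream
    # itself, and the stream should still terminate with the "data: [DONE]" sentinel.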
    chunks = []
    async for chunk in response:
        chunks.append(chunk)

    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"