mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-18 19:27:09 +08:00
Merge branch 'main' into v1-block-table-opt
This commit is contained in:
commit
caacd1ddfb
@ -43,7 +43,7 @@ main() {
|
||||
|
||||
|
||||
|
||||
# The figures should be genereated by a separate process outside the CI/CD pipeline
|
||||
# The figures should be generated by a separate process outside the CI/CD pipeline
|
||||
|
||||
# # generate figures
|
||||
# python3 -m pip install tabulate pandas matplotlib
|
||||
|
||||
@ -301,6 +301,104 @@ run_serving_tests() {
|
||||
kill_gpu_processes
|
||||
}
|
||||
|
||||
run_genai_perf_tests() {
|
||||
# run genai-perf tests
|
||||
|
||||
# $1: a json file specifying genai-perf test cases
|
||||
local genai_perf_test_file
|
||||
genai_perf_test_file=$1
|
||||
|
||||
# Iterate over genai-perf tests
|
||||
jq -c '.[]' "$genai_perf_test_file" | while read -r params; do
|
||||
# get the test name, and append the GPU type back to it.
|
||||
test_name=$(echo "$params" | jq -r '.test_name')
|
||||
|
||||
# if TEST_SELECTOR is set, only run the test cases that match the selector
|
||||
if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then
|
||||
echo "Skip test case $test_name."
|
||||
continue
|
||||
fi
|
||||
|
||||
# prepend the current serving engine to the test name
|
||||
test_name=${CURRENT_LLM_SERVING_ENGINE}_${test_name}
|
||||
|
||||
# get common parameters
|
||||
common_params=$(echo "$params" | jq -r '.common_parameters')
|
||||
model=$(echo "$common_params" | jq -r '.model')
|
||||
tp=$(echo "$common_params" | jq -r '.tp')
|
||||
dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
|
||||
dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
|
||||
port=$(echo "$common_params" | jq -r '.port')
|
||||
num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
|
||||
reuse_server=$(echo "$common_params" | jq -r '.reuse_server')
|
||||
|
||||
# get client and server arguments
|
||||
server_params=$(echo "$params" | jq -r ".${CURRENT_LLM_SERVING_ENGINE}_server_parameters")
|
||||
qps_list=$(echo "$params" | jq -r '.qps_list')
|
||||
qps_list=$(echo "$qps_list" | jq -r '.[] | @sh')
|
||||
echo "Running over qps list $qps_list"
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required num-shard $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
|
||||
if [[ $reuse_server == "true" ]]; then
|
||||
echo "Reuse previous server for test case $test_name"
|
||||
else
|
||||
kill_gpu_processes
|
||||
bash "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh" \
|
||||
"$server_params" "$common_params"
|
||||
fi
|
||||
|
||||
if wait_for_server; then
|
||||
echo ""
|
||||
echo "$CURRENT_LLM_SERVING_ENGINE server is up and running."
|
||||
else
|
||||
echo ""
|
||||
echo "$CURRENT_LLM_SERVING_ENGINE failed to start within the timeout period."
|
||||
break
|
||||
fi
|
||||
|
||||
# iterate over different QPS
|
||||
for qps in $qps_list; do
|
||||
# remove the surrounding single quote from qps
|
||||
if [[ "$qps" == *"inf"* ]]; then
|
||||
echo "qps was $qps"
|
||||
qps=$num_prompts
|
||||
echo "now qps is $qps"
|
||||
fi
|
||||
|
||||
new_test_name=$test_name"_qps_"$qps
|
||||
backend=$CURRENT_LLM_SERVING_ENGINE
|
||||
|
||||
if [[ "$backend" == *"vllm"* ]]; then
|
||||
backend="vllm"
|
||||
fi
|
||||
#TODO: add output dir.
|
||||
client_command="genai-perf profile \
|
||||
-m $model \
|
||||
--service-kind openai \
|
||||
--backend vllm \
|
||||
--endpoint-type chat \
|
||||
--streaming \
|
||||
--url localhost:$port \
|
||||
--request-rate $qps \
|
||||
--num-prompts $num_prompts \
|
||||
"
|
||||
|
||||
echo "Client command: $client_command"
|
||||
|
||||
eval "$client_command"
|
||||
|
||||
#TODO: process/record outputs
|
||||
done
|
||||
done
|
||||
|
||||
kill_gpu_processes
|
||||
|
||||
}
|
||||
|
||||
prepare_dataset() {
|
||||
|
||||
@ -328,12 +426,17 @@ main() {
|
||||
|
||||
pip install -U transformers
|
||||
|
||||
pip install -r requirements-dev.txt
|
||||
which genai-perf
|
||||
|
||||
# check storage
|
||||
df -h
|
||||
|
||||
ensure_installed wget
|
||||
ensure_installed curl
|
||||
ensure_installed jq
|
||||
# genai-perf dependency
|
||||
ensure_installed libb64-0d
|
||||
|
||||
prepare_dataset
|
||||
|
||||
@ -345,6 +448,10 @@ main() {
|
||||
# run the test
|
||||
run_serving_tests "$BENCHMARK_ROOT/tests/nightly-tests.json"
|
||||
|
||||
# run genai-perf tests
|
||||
run_genai_perf_tests "$BENCHMARK_ROOT/tests/genai-perf-tests.json"
|
||||
mv artifacts/ $RESULTS_FOLDER/
|
||||
|
||||
# upload benchmark results to buildkite
|
||||
python3 -m pip install tabulate pandas
|
||||
python3 "$BENCHMARK_ROOT/scripts/summary-nightly-results.py"
|
||||
|
||||
23
.buildkite/nightly-benchmarks/tests/genai-perf-tests.json
Normal file
23
.buildkite/nightly-benchmarks/tests/genai-perf-tests.json
Normal file
@ -0,0 +1,23 @@
|
||||
[
|
||||
{
|
||||
"test_name": "llama8B_tp1_genai_perf",
|
||||
"qps_list": [4,8,16,32],
|
||||
"common_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"tp": 1,
|
||||
"port": 8000,
|
||||
"num_prompts": 500,
|
||||
"reuse_server": false
|
||||
},
|
||||
"vllm_server_parameters": {
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"num_scheduler_steps": 10,
|
||||
"max_num_seqs": 512,
|
||||
"dtype": "bfloat16"
|
||||
},
|
||||
"genai_perf_input_parameters": {
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -83,6 +83,6 @@ function cpu_tests() {
|
||||
tests/lora/test_qwen2vl.py"
|
||||
}
|
||||
|
||||
# All of CPU tests are expected to be finished less than 25 mins.
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
export -f cpu_tests
|
||||
timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
|
||||
@ -25,8 +25,11 @@ if [ -f /tmp/neuron-docker-build-timestamp ]; then
|
||||
last_build=$(cat /tmp/neuron-docker-build-timestamp)
|
||||
current_time=$(date +%s)
|
||||
if [ $((current_time - last_build)) -gt 86400 ]; then
|
||||
# Remove dangling images (those that are not tagged and not used by any container)
|
||||
docker image prune -f
|
||||
docker system prune -f
|
||||
# Remove unused volumes / force the system prune for old images as well.
|
||||
docker volume prune -f && docker system prune -f
|
||||
# Remove huggingface model artifacts and compiler cache
|
||||
rm -rf "${HF_MOUNT:?}/*"
|
||||
rm -rf "${NEURON_COMPILE_CACHE_MOUNT:?}/*"
|
||||
echo "$current_time" > /tmp/neuron-docker-build-timestamp
|
||||
|
||||
@ -52,7 +52,6 @@ steps:
|
||||
- tests/worker
|
||||
- tests/standalone_tests/lazy_torch_compile.py
|
||||
commands:
|
||||
- pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimoda processing test
|
||||
- python3 standalone_tests/lazy_torch_compile.py
|
||||
- pytest -v -s mq_llm_engine # MQLLMEngine
|
||||
- pytest -v -s async_engine # AsyncLLMEngine
|
||||
@ -77,7 +76,9 @@ steps:
|
||||
- tests/basic_correctness/test_basic_correctness
|
||||
- tests/basic_correctness/test_cpu_offload
|
||||
- tests/basic_correctness/test_preemption
|
||||
- tests/basic_correctness/test_cumem.py
|
||||
commands:
|
||||
- pytest -v -s basic_correctness/test_cumem.py
|
||||
- pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- pytest -v -s basic_correctness/test_cpu_offload.py
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
@ -107,7 +108,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
commands:
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
||||
@ -126,11 +127,15 @@ steps:
|
||||
- tests/distributed
|
||||
- tests/spec_decode/e2e/test_integration_dist_tp4
|
||||
- tests/compile
|
||||
- examples/offline_inference/rlhf.py
|
||||
commands:
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||
# TODO: create a dedicated test section for multi-GPU example tests
|
||||
# when we have multiple distributed example tests
|
||||
- python3 ../examples/offline_inference/rlhf.py
|
||||
|
||||
- label: Metrics, Tracing Test # 10min
|
||||
num_gpus: 2
|
||||
@ -462,7 +467,10 @@ steps:
|
||||
- vllm/worker/worker_base.py
|
||||
- vllm/worker/worker.py
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
commands:
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
- pytest -v -s ./compile/test_wrapper.py
|
||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
@ -471,7 +479,9 @@ steps:
|
||||
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
# this test fails consistently.
|
||||
# TODO: investigate and fix
|
||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
|
||||
|
||||
@ -509,7 +519,9 @@ steps:
|
||||
- vllm/engine
|
||||
- tests/multi_step
|
||||
commands:
|
||||
- pytest -v -s multi_step/test_correctness_async_llm.py
|
||||
# this test is quite flaky
|
||||
# TODO: investigate and fix.
|
||||
# - pytest -v -s multi_step/test_correctness_async_llm.py
|
||||
- pytest -v -s multi_step/test_correctness_llm.py
|
||||
|
||||
- label: Pipeline Parallelism Test # 45min
|
||||
|
||||
27
.github/CODEOWNERS
vendored
27
.github/CODEOWNERS
vendored
@ -2,32 +2,35 @@
|
||||
# for more info about CODEOWNERS file
|
||||
|
||||
# This lists cover the "core" components of vLLM that require careful review
|
||||
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
|
||||
/vllm/model_executor/guided_decoding @mgoin
|
||||
/vllm/multimodal @DarkLight1337 @ywang96
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
|
||||
# vLLM V1
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||
|
||||
# Test ownership
|
||||
/tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo
|
||||
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
|
||||
/tests/test_inputs.py @DarkLight1337 @ywang96
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-neuralmagic @simon-mo
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo
|
||||
/tests/models @DarkLight1337 @ywang96
|
||||
/tests/multimodal @DarkLight1337 @ywang96
|
||||
/tests/prefix_caching @comaniac @KuntaiDu
|
||||
/tests/spec_decode @njhill @LiuXiaoxuanPKU
|
||||
/tests/kernels @tlrmchlsmth @WoosukKwon
|
||||
/tests/quantization @mgoin @robertgshaw2-neuralmagic
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat
|
||||
/.buildkite/lm-eval-harness @mgoin @simon-mo
|
||||
/tests/distributed/test_multi_node_assignment.py @youkaichao
|
||||
/tests/distributed/test_pipeline_parallel.py @youkaichao
|
||||
/tests/distributed/test_same_node.py @youkaichao
|
||||
/tests/multi_step @alexm-neuralmagic @comaniac
|
||||
/tests/multi_step @alexm-redhat @comaniac
|
||||
/tests/weight_loading @mgoin @youkaichao
|
||||
/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac
|
||||
|
||||
40
.github/workflows/actionlint.yml
vendored
40
.github/workflows/actionlint.yml
vendored
@ -1,40 +0,0 @@
|
||||
name: Lint GitHub Actions workflows
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '.github/workflows/*.ya?ml'
|
||||
- '.github/workflows/actionlint.*'
|
||||
- '.github/workflows/matchers/actionlint.json'
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '.github/workflows/*.ya?ml'
|
||||
- '.github/workflows/actionlint.*'
|
||||
- '.github/workflows/matchers/actionlint.json'
|
||||
|
||||
env:
|
||||
LC_ALL: en_US.UTF-8
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
actionlint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: "Checkout"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: "Run actionlint"
|
||||
run: |
|
||||
echo "::add-matcher::.github/workflows/matchers/actionlint.json"
|
||||
tools/actionlint.sh -color
|
||||
53
.github/workflows/clang-format.yml
vendored
53
.github/workflows/clang-format.yml
vendored
@ -1,53 +0,0 @@
|
||||
name: clang-format
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '**/*.h'
|
||||
- '**/*.cpp'
|
||||
- '**/*.cu'
|
||||
- '**/*.cuh'
|
||||
- '.github/workflows/clang-format.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '**/*.h'
|
||||
- '**/*.cpp'
|
||||
- '**/*.cu'
|
||||
- '**/*.cuh'
|
||||
- '.github/workflows/clang-format.yml'
|
||||
|
||||
jobs:
|
||||
clang-format:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install clang-format==18.1.5
|
||||
- name: Running clang-format
|
||||
run: |
|
||||
EXCLUDES=(
|
||||
'csrc/moe/topk_softmax_kernels.cu'
|
||||
'csrc/quantization/gguf/ggml-common.h'
|
||||
'csrc/quantization/gguf/dequantize.cuh'
|
||||
'csrc/quantization/gguf/vecdotq.cuh'
|
||||
'csrc/quantization/gguf/mmq.cuh'
|
||||
'csrc/quantization/gguf/mmvq.cuh'
|
||||
)
|
||||
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
|
||||
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
|
||||
| xargs clang-format --dry-run --Werror
|
||||
45
.github/workflows/codespell.yml
vendored
45
.github/workflows/codespell.yml
vendored
@ -1,45 +0,0 @@
|
||||
name: codespell
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "**/*.md"
|
||||
- "**/*.rst"
|
||||
- pyproject.toml
|
||||
- requirements-lint.txt
|
||||
- .github/workflows/codespell.yml
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "**/*.md"
|
||||
- "**/*.rst"
|
||||
- pyproject.toml
|
||||
- requirements-lint.txt
|
||||
- .github/workflows/codespell.yml
|
||||
|
||||
jobs:
|
||||
codespell:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-lint.txt
|
||||
- name: Spelling check with codespell
|
||||
run: |
|
||||
codespell --toml pyproject.toml
|
||||
32
.github/workflows/doc-lint.yml
vendored
32
.github/workflows/doc-lint.yml
vendored
@ -1,32 +0,0 @@
|
||||
name: Lint documentation
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "docs/**"
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "docs/**"
|
||||
|
||||
jobs:
|
||||
doc-lint:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-lint.txt
|
||||
- name: Linting docs
|
||||
run: tools/doc-lint.sh
|
||||
17
.github/workflows/matchers/ruff.json
vendored
17
.github/workflows/matchers/ruff.json
vendored
@ -1,17 +0,0 @@
|
||||
{
|
||||
"problemMatcher": [
|
||||
{
|
||||
"owner": "ruff",
|
||||
"pattern": [
|
||||
{
|
||||
"regexp": "^(.+?):(\\d+):(\\d+): (\\w+): (.+)$",
|
||||
"file": 1,
|
||||
"line": 2,
|
||||
"column": 3,
|
||||
"code": 4,
|
||||
"message": 5
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
51
.github/workflows/mypy.yaml
vendored
51
.github/workflows/mypy.yaml
vendored
@ -1,51 +0,0 @@
|
||||
name: mypy
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '**/*.py'
|
||||
- '.github/workflows/mypy.yaml'
|
||||
- 'tools/mypy.sh'
|
||||
- 'pyproject.toml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
# This workflow is only relevant when one of the following files changes.
|
||||
# However, we have github configured to expect and require this workflow
|
||||
# to run and pass before github with auto-merge a pull request. Until github
|
||||
# allows more flexible auto-merge policy, we can just run this on every PR.
|
||||
# It doesn't take that long to run, anyway.
|
||||
#paths:
|
||||
# - '**/*.py'
|
||||
# - '.github/workflows/mypy.yaml'
|
||||
# - 'tools/mypy.sh'
|
||||
# - 'pyproject.toml'
|
||||
|
||||
jobs:
|
||||
mypy:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install mypy==1.11.1
|
||||
pip install types-setuptools
|
||||
pip install types-PyYAML
|
||||
pip install types-requests
|
||||
pip install types-setuptools
|
||||
- name: Mypy
|
||||
run: |
|
||||
echo "::add-matcher::.github/workflows/matchers/mypy.json"
|
||||
tools/mypy.sh 1 ${{ matrix.python-version }}
|
||||
37
.github/workflows/png-lint.yml
vendored
37
.github/workflows/png-lint.yml
vendored
@ -1,37 +0,0 @@
|
||||
name: Lint PNG exports from excalidraw
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '*.excalidraw.png'
|
||||
- '.github/workflows/png-lint.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '*.excalidraw.png'
|
||||
- '.github/workflows/png-lint.yml'
|
||||
|
||||
env:
|
||||
LC_ALL: en_US.UTF-8
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
actionlint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: "Checkout"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: "Run png-lint.sh to check excalidraw exported images"
|
||||
run: |
|
||||
tools/png-lint.sh
|
||||
19
.github/workflows/pre-commit.yml
vendored
Normal file
19
.github/workflows/pre-commit.yml
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
name: pre-commit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
|
||||
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
||||
with:
|
||||
extra_args: --all-files --hook-stage manual
|
||||
52
.github/workflows/ruff.yml
vendored
52
.github/workflows/ruff.yml
vendored
@ -1,52 +0,0 @@
|
||||
name: ruff
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- pyproject.toml
|
||||
- requirements-lint.txt
|
||||
- .github/workflows/matchers/ruff.json
|
||||
- .github/workflows/ruff.yml
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
# This workflow is only relevant when one of the following files changes.
|
||||
# However, we have github configured to expect and require this workflow
|
||||
# to run and pass before github with auto-merge a pull request. Until github
|
||||
# allows more flexible auto-merge policy, we can just run this on every PR.
|
||||
# It doesn't take that long to run, anyway.
|
||||
#paths:
|
||||
# - "**/*.py"
|
||||
# - pyproject.toml
|
||||
# - requirements-lint.txt
|
||||
# - .github/workflows/matchers/ruff.json
|
||||
# - .github/workflows/ruff.yml
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-lint.txt
|
||||
- name: Analysing the code with ruff
|
||||
run: |
|
||||
echo "::add-matcher::.github/workflows/matchers/ruff.json"
|
||||
ruff check --output-format github .
|
||||
- name: Run isort
|
||||
run: |
|
||||
isort . --check-only
|
||||
37
.github/workflows/shellcheck.yml
vendored
37
.github/workflows/shellcheck.yml
vendored
@ -1,37 +0,0 @@
|
||||
name: Lint shell scripts
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '**/*.sh'
|
||||
- '.github/workflows/shellcheck.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '**/*.sh'
|
||||
- '.github/workflows/shellcheck.yml'
|
||||
|
||||
env:
|
||||
LC_ALL: en_US.UTF-8
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
shellcheck:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: "Checkout"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: "Check shell scripts"
|
||||
run: |
|
||||
tools/shellcheck.sh
|
||||
38
.github/workflows/yapf.yml
vendored
38
.github/workflows/yapf.yml
vendored
@ -1,38 +0,0 @@
|
||||
name: yapf
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- .github/workflows/yapf.yml
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- .github/workflows/yapf.yml
|
||||
|
||||
jobs:
|
||||
yapf:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install yapf==0.32.0
|
||||
pip install toml==0.10.2
|
||||
- name: Running yapf
|
||||
run: |
|
||||
yapf --diff --recursive .
|
||||
92
.pre-commit-config.yaml
Normal file
92
.pre-commit-config.yaml
Normal file
@ -0,0 +1,92 @@
|
||||
default_stages:
|
||||
- pre-commit # Run locally
|
||||
- manual # Run in CI
|
||||
repos:
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.32.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.5
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--output-format, github]
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.3.0
|
||||
hooks:
|
||||
- id: codespell
|
||||
exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*'
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.13.2
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v18.1.5
|
||||
hooks:
|
||||
- id: clang-format
|
||||
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))'
|
||||
types_or: [c++, cuda]
|
||||
args: [--style=file, --verbose]
|
||||
- repo: https://github.com/jackdewinter/pymarkdown
|
||||
rev: v0.9.27
|
||||
hooks:
|
||||
- id: pymarkdown
|
||||
files: docs/.*
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: v1.7.6
|
||||
hooks:
|
||||
- id: actionlint
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy-local
|
||||
name: Run mypy for local Python installation
|
||||
entry: tools/mypy.sh 0 "local"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
|
||||
stages: [pre-commit] # Don't run in CI
|
||||
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.9
|
||||
entry: tools/mypy.sh 1 "3.9"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.10
|
||||
entry: tools/mypy.sh 1 "3.10"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.11
|
||||
entry: tools/mypy.sh 1 "3.11"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.12
|
||||
entry: tools/mypy.sh 1 "3.12"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: shellcheck
|
||||
name: Lint shell scripts
|
||||
entry: tools/shellcheck.sh
|
||||
language: script
|
||||
types: [shell]
|
||||
- id: png-lint
|
||||
name: Lint PNG exports from excalidraw
|
||||
entry: tools/png-lint.sh
|
||||
language: script
|
||||
types: [png]
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
|
||||
language: system
|
||||
verbose: true
|
||||
@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
# Suppress potential warnings about unused manually-specified variables
|
||||
set(ignoreMe "${VLLM_PYTHON_PATH}")
|
||||
|
||||
# Prevent installation of dependencies (cutlass) by default.
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
||||
|
||||
#
|
||||
# Supported python versions. These versions will be searched in order, the
|
||||
# first match will be selected. These should be kept in sync with setup.py.
|
||||
@ -181,6 +178,31 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
|
||||
# Define other extension targets
|
||||
#
|
||||
|
||||
#
|
||||
# cumem_allocator extension
|
||||
#
|
||||
|
||||
set(VLLM_CUMEM_EXT_SRC
|
||||
"csrc/cumem_allocator.cpp")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_CUMEM_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Enabling cumem allocator extension.")
|
||||
# link against cuda driver library
|
||||
list(APPEND CUMEM_LIBS cuda)
|
||||
define_gpu_extension_target(
|
||||
cumem_allocator
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_CUMEM_EXT_SRC}
|
||||
LIBRARIES ${CUMEM_LIBS}
|
||||
USE_SABI 3.8
|
||||
WITH_SOABI)
|
||||
endif()
|
||||
|
||||
#
|
||||
# _C extension
|
||||
#
|
||||
@ -511,7 +533,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
endif()
|
||||
|
||||
# vllm-flash-attn currently only supported on CUDA
|
||||
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
|
||||
if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
return()
|
||||
endif ()
|
||||
|
||||
@ -534,7 +556,7 @@ endif()
|
||||
# They should be identical but if they aren't, this is a massive footgun.
|
||||
#
|
||||
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
|
||||
# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
|
||||
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
|
||||
# If no component is specified, vllm-flash-attn is still installed.
|
||||
|
||||
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
|
||||
@ -546,43 +568,41 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(VLLM_FLASH_ATTN_SRC_DIR)
|
||||
FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn SOURCE_DIR
|
||||
${VLLM_FLASH_ATTN_SRC_DIR}
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
|
||||
GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
endif()
|
||||
|
||||
# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
|
||||
set(VLLM_PARENT_BUILD ON)
|
||||
|
||||
# Ensure the vllm/vllm_flash_attn directory exists before installation
|
||||
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)
|
||||
|
||||
# Make sure vllm-flash-attn install rules are nested under vllm/
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
|
||||
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)
|
||||
|
||||
# Fetch the vllm-flash-attn library
|
||||
FetchContent_MakeAvailable(vllm-flash-attn)
|
||||
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
|
||||
|
||||
# Restore the install prefix
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
|
||||
|
||||
# Copy over the vllm-flash-attn python files
|
||||
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
|
||||
# case only one is built, in the case both are built redundant work is done)
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm/vllm_flash_attn
|
||||
COMPONENT vllm_flash_attn_c
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa2_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa3_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
# Nothing after vllm-flash-attn, see comment about macros above
|
||||
|
||||
@ -52,7 +52,7 @@ WORKDIR /workspace
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
|
||||
python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \
|
||||
fi
|
||||
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
|
||||
@ -26,10 +26,10 @@ RUN pip install intel_extension_for_pytorch==2.5.0
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
|
||||
ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
|
||||
pip install --upgrade pip && \
|
||||
pip install -r requirements-build.txt
|
||||
|
||||
@ -37,9 +37,9 @@ FROM cpu-test-1 AS build
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
COPY requirements-cpu.txt requirements-cpu.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
|
||||
--mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
|
||||
pip install -v -r requirements-cpu.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
261
Dockerfile.rocm
261
Dockerfile.rocm
@ -1,174 +1,119 @@
|
||||
# Default ROCm 6.2 base image
|
||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0"
|
||||
# default base image
|
||||
ARG REMOTE_VLLM="0"
|
||||
ARG USE_CYTHON="0"
|
||||
ARG BUILD_RPD="1"
|
||||
ARG COMMON_WORKDIR=/app
|
||||
ARG BASE_IMAGE=rocm/vllm-dev:base
|
||||
|
||||
# Default ROCm ARCHes to build vLLM for.
|
||||
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
# Whether to install CK-based flash-attention
|
||||
# If 0, will not install flash-attention
|
||||
ARG BUILD_FA="1"
|
||||
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
||||
ARG FA_BRANCH="3cea2fb"
|
||||
|
||||
# Whether to build triton on rocm
|
||||
ARG BUILD_TRITON="1"
|
||||
ARG TRITON_BRANCH="e192dba"
|
||||
|
||||
### Base image build stage
|
||||
FROM $BASE_IMAGE AS base
|
||||
|
||||
# Import arg(s) defined before this build stage
|
||||
ARG PYTORCH_ROCM_ARCH
|
||||
ARG ARG_PYTORCH_ROCM_ARCH
|
||||
ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}}
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
ca-certificates \
|
||||
sudo \
|
||||
git \
|
||||
bzip2 \
|
||||
libx11-6 \
|
||||
build-essential \
|
||||
wget \
|
||||
unzip \
|
||||
tmux \
|
||||
ccache \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# When launching the container, mount the code directory to /vllm-workspace
|
||||
ARG APP_MOUNT=/vllm-workspace
|
||||
WORKDIR ${APP_MOUNT}
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
# Remove sccache so it doesn't interfere with ccache
|
||||
# TODO: implement sccache support across components
|
||||
RUN apt-get update -q -y && apt-get install -q -y \
|
||||
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev
|
||||
# Remove sccache
|
||||
RUN python3 -m pip install --upgrade pip && pip install setuptools_scm
|
||||
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
|
||||
|
||||
# Install torch == 2.6.0 on ROCm
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
|
||||
*"rocm-6.2"*) \
|
||||
python3 -m pip uninstall -y torch torchvision \
|
||||
&& python3 -m pip install --pre \
|
||||
torch==2.6.0.dev20241113+rocm6.2 \
|
||||
'setuptools-scm>=8' \
|
||||
torchvision==0.20.0.dev20241113+rocm6.2 \
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \
|
||||
*) ;; esac
|
||||
|
||||
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
|
||||
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
|
||||
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
|
||||
|
||||
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
ARG COMMON_WORKDIR
|
||||
WORKDIR ${COMMON_WORKDIR}
|
||||
|
||||
|
||||
### AMD-SMI build stage
|
||||
FROM base AS build_amdsmi
|
||||
# Build amdsmi wheel always
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
&& python3 -m pip wheel . --wheel-dir=/install
|
||||
# -----------------------
|
||||
# vLLM fetch stages
|
||||
FROM base AS fetch_vllm_0
|
||||
ONBUILD COPY ./ vllm/
|
||||
FROM base AS fetch_vllm_1
|
||||
ARG VLLM_REPO="https://github.com/vllm-project/vllm.git"
|
||||
ARG VLLM_BRANCH="main"
|
||||
ONBUILD RUN git clone ${VLLM_REPO} \
|
||||
&& cd vllm \
|
||||
&& git checkout ${VLLM_BRANCH}
|
||||
FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm
|
||||
|
||||
# -----------------------
|
||||
# vLLM build stages
|
||||
FROM fetch_vllm AS build_vllm
|
||||
ARG USE_CYTHON
|
||||
# Build vLLM
|
||||
RUN cd vllm \
|
||||
&& python3 -m pip install -r requirements-rocm.txt \
|
||||
&& python3 setup.py clean --all \
|
||||
&& if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
FROM scratch AS export_vllm
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl /
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt /
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
|
||||
|
||||
### Flash-Attention wheel build stage
|
||||
FROM base AS build_fa
|
||||
ARG BUILD_FA
|
||||
ARG FA_GFX_ARCHS
|
||||
ARG FA_BRANCH
|
||||
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
|
||||
RUN --mount=type=cache,target=${CCACHE_DIR} \
|
||||
if [ "$BUILD_FA" = "1" ]; then \
|
||||
mkdir -p libs \
|
||||
&& cd libs \
|
||||
&& git clone https://github.com/ROCm/flash-attention.git \
|
||||
&& cd flash-attention \
|
||||
&& git checkout "${FA_BRANCH}" \
|
||||
&& git submodule update --init \
|
||||
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
|
||||
# Create an empty directory otherwise as later build stages expect one
|
||||
else mkdir -p /install; \
|
||||
fi
|
||||
# -----------------------
|
||||
# Test vLLM image
|
||||
FROM base AS test
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
### Triton wheel build stage
|
||||
FROM base AS build_triton
|
||||
ARG BUILD_TRITON
|
||||
ARG TRITON_BRANCH
|
||||
# Build triton wheel if `BUILD_TRITON = 1`
|
||||
RUN --mount=type=cache,target=${CCACHE_DIR} \
|
||||
if [ "$BUILD_TRITON" = "1" ]; then \
|
||||
mkdir -p libs \
|
||||
&& cd libs \
|
||||
&& python3 -m pip install ninja cmake wheel pybind11 \
|
||||
&& git clone https://github.com/OpenAI/triton.git \
|
||||
&& cd triton \
|
||||
&& git checkout "${TRITON_BRANCH}" \
|
||||
&& cd python \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=/install; \
|
||||
# Create an empty directory otherwise as later build stages expect one
|
||||
else mkdir -p /install; \
|
||||
fi
|
||||
# Install vLLM
|
||||
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
cd /install \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& pip uninstall -y vllm \
|
||||
&& pip install *.whl
|
||||
|
||||
|
||||
### Final vLLM build stage
|
||||
FROM base AS final
|
||||
# Import the vLLM development directory from the build context
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
|
||||
# Package upgrades for useful functionality or to avoid dependency issues
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install --upgrade numba scipy huggingface-hub[cli] pytest-shard
|
||||
|
||||
|
||||
# Workaround for ray >= 2.10.0
|
||||
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
|
||||
# Silences the HF Tokenizers warning
|
||||
ENV TOKENIZERS_PARALLELISM=false
|
||||
|
||||
RUN --mount=type=cache,target=${CCACHE_DIR} \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -Ur requirements-rocm.txt \
|
||||
&& python3 setup.py clean --all \
|
||||
&& python3 setup.py develop
|
||||
|
||||
# Copy amdsmi wheel into final image
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
|
||||
mkdir -p libs \
|
||||
&& cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& python3 -m pip uninstall -y amdsmi;
|
||||
|
||||
# Copy triton wheel(s) into final image if they were built
|
||||
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
|
||||
mkdir -p libs \
|
||||
&& if ls /install/*.whl; then \
|
||||
cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& python3 -m pip uninstall -y triton; fi
|
||||
|
||||
# Copy flash-attn wheel(s) into final image if they were built
|
||||
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
|
||||
mkdir -p libs \
|
||||
&& if ls /install/*.whl; then \
|
||||
cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& python3 -m pip uninstall -y flash-attn; fi
|
||||
|
||||
# Install wheels that were built to the final image
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
if ls libs/*.whl; then \
|
||||
python3 -m pip install libs/*.whl; fi
|
||||
WORKDIR /vllm-workspace
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
RUN cd /vllm-workspace \
|
||||
&& rm -rf vllm \
|
||||
&& python3 -m pip install -e tests/vllm_test_utils \
|
||||
&& python3 -m pip install lm-eval[api]==0.4.4 \
|
||||
&& python3 -m pip install pytest-shard
|
||||
|
||||
# -----------------------
|
||||
# Final vLLM image
|
||||
FROM base AS final
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
|
||||
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
|
||||
# Manually remove it so that later steps of numpy upgrade can continue
|
||||
RUN case "$(which python3)" in \
|
||||
*"/opt/conda/envs/py_3.9"*) \
|
||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
|
||||
*) ;; esac
|
||||
|
||||
RUN python3 -m pip install --upgrade huggingface-hub[cli]
|
||||
ARG BUILD_RPD
|
||||
RUN if [ ${BUILD_RPD} -eq "1" ]; then \
|
||||
git clone -b nvtx_enabled https://github.com/ROCm/rocmProfileData.git \
|
||||
&& cd rocmProfileData/rpd_tracer \
|
||||
&& pip install -r requirements.txt && cd ../ \
|
||||
&& make && make install \
|
||||
&& cd hipMarker && python3 setup.py install ; fi
|
||||
|
||||
# Install vLLM
|
||||
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
cd /install \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& pip uninstall -y vllm \
|
||||
&& pip install *.whl
|
||||
|
||||
ARG COMMON_WORKDIR
|
||||
|
||||
# Copy over the benchmark scripts as well
|
||||
COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks
|
||||
COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples
|
||||
|
||||
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
|
||||
ENV TOKENIZERS_PARALLELISM=false
|
||||
|
||||
# Performance environment variable.
|
||||
ENV HIP_FORCE_DEV_KERNARG=1
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
|
||||
158
Dockerfile.rocm_base
Normal file
158
Dockerfile.rocm_base
Normal file
@ -0,0 +1,158 @@
|
||||
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete
|
||||
ARG HIPBLASLT_BRANCH="4d40e36"
|
||||
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
|
||||
ARG LEGACY_HIPBLASLT_OPTION=
|
||||
ARG RCCL_BRANCH="648a58d"
|
||||
ARG RCCL_REPO="https://github.com/ROCm/rccl"
|
||||
ARG TRITON_BRANCH="e5be006"
|
||||
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
|
||||
ARG PYTORCH_BRANCH="8d4926e"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.19.1"
|
||||
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="b7d29fb"
|
||||
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
ENV PATH=/opt/rocm/llvm/bin:$PATH
|
||||
ENV ROCM_PATH=/opt/rocm
|
||||
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
|
||||
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
|
||||
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
RUN mkdir -p /app
|
||||
WORKDIR /app
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y software-properties-common git curl sudo vim less \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-lib2to3 python-is-python3 \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
RUN pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
|
||||
|
||||
FROM base AS build_hipblaslt
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG HIPBLAS_COMMON_BRANCH
|
||||
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
RUN git clone https://github.com/ROCm/hipBLAS-common.git
|
||||
RUN cd hipBLAS-common \
|
||||
&& git checkout ${HIPBLAS_COMMON_BRANCH} \
|
||||
&& mkdir build \
|
||||
&& cd build \
|
||||
&& cmake .. \
|
||||
&& make package \
|
||||
&& dpkg -i ./*.deb
|
||||
RUN git clone https://github.com/ROCm/hipBLASLt
|
||||
RUN cd hipBLASLt \
|
||||
&& git checkout ${HIPBLASLT_BRANCH} \
|
||||
&& ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
|
||||
&& cd build/release \
|
||||
&& make package
|
||||
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
|
||||
|
||||
FROM base AS build_rccl
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
RUN git clone ${RCCL_REPO}
|
||||
RUN cd rccl \
|
||||
&& git checkout ${RCCL_BRANCH} \
|
||||
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
|
||||
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
|
||||
|
||||
FROM base AS build_triton
|
||||
ARG TRITON_BRANCH
|
||||
ARG TRITON_REPO
|
||||
RUN git clone ${TRITON_REPO}
|
||||
RUN cd triton \
|
||||
&& git checkout ${TRITON_BRANCH} \
|
||||
&& cd python \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
|
||||
|
||||
FROM base AS build_amdsmi
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
&& pip wheel . --wheel-dir=dist
|
||||
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
|
||||
|
||||
FROM base AS build_pytorch
|
||||
ARG PYTORCH_BRANCH
|
||||
ARG PYTORCH_VISION_BRANCH
|
||||
ARG PYTORCH_REPO
|
||||
ARG PYTORCH_VISION_REPO
|
||||
ARG FA_BRANCH
|
||||
ARG FA_REPO
|
||||
RUN git clone ${PYTORCH_REPO} pytorch
|
||||
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
|
||||
pip install -r requirements.txt && git submodule update --init --recursive \
|
||||
&& python3 tools/amd_build/build_amd.py \
|
||||
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
|
||||
&& pip install dist/*.whl
|
||||
RUN git clone ${PYTORCH_VISION_REPO} vision
|
||||
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist \
|
||||
&& pip install dist/*.whl
|
||||
RUN git clone ${FA_REPO}
|
||||
RUN cd flash-attention \
|
||||
&& git checkout ${FA_BRANCH} \
|
||||
&& git submodule update --init \
|
||||
&& MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
|
||||
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
|
||||
&& cp /app/vision/dist/*.whl /app/install \
|
||||
&& cp /app/flash-attention/dist/*.whl /app/install
|
||||
|
||||
FROM base AS final
|
||||
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
|
||||
dpkg -i /install/*deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
|
||||
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
|
||||
dpkg -i /install/*deb \
|
||||
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
|
||||
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
|
||||
ARG BASE_IMAGE
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
ARG TRITON_BRANCH
|
||||
ARG TRITON_REPO
|
||||
ARG PYTORCH_BRANCH
|
||||
ARG PYTORCH_VISION_BRANCH
|
||||
ARG PYTORCH_REPO
|
||||
ARG PYTORCH_VISION_REPO
|
||||
ARG FA_BRANCH
|
||||
ARG FA_REPO
|
||||
RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \
|
||||
&& echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \
|
||||
&& echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
|
||||
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
|
||||
@ -1,4 +1,4 @@
|
||||
ARG NIGHTLY_DATE="20241017"
|
||||
ARG NIGHTLY_DATE="20250122"
|
||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
||||
@ -15,11 +15,8 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
|
||||
---
|
||||
|
||||
The first vLLM meetup in 2025 is happening on January 22nd, Wednesday, with Google Cloud in San Francisco! We will talk about vLLM's performant V1 architecture, Q1 roadmap, Google Cloud's innovation around vLLM: networking, Cloud Run, Vertex, and TPU! [Register Now](https://lu.ma/zep56hui)
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing).
|
||||
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
|
||||
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
|
||||
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
If you believe you have found a security vulnerability in vLLM, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
|
||||
|
||||
Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new). Reports will then be triaged by the [vulnerability management team](https://docs.vllm.ai/contributing/vulnerability_management/).
|
||||
Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new). Reports will then be triaged by the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html).
|
||||
|
||||
---
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ class RequestFuncInput:
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
model: str
|
||||
model_name: Optional[str] = None
|
||||
best_of: int = 1
|
||||
logprobs: Optional[int] = None
|
||||
extra_body: Optional[dict] = None
|
||||
@ -34,6 +35,7 @@ class RequestFuncOutput:
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0.0
|
||||
output_tokens: int = 0
|
||||
ttft: float = 0.0 # Time to first token
|
||||
itl: List[float] = field(
|
||||
default_factory=list) # List of inter-token latencies
|
||||
@ -78,7 +80,7 @@ async def async_request_tgi(
|
||||
continue
|
||||
chunk_bytes = chunk_bytes.decode("utf-8")
|
||||
|
||||
#NOTE: Sometimes TGI returns a ping response without
|
||||
# NOTE: Sometimes TGI returns a ping response without
|
||||
# any data, we should skip it.
|
||||
if chunk_bytes.startswith(":"):
|
||||
continue
|
||||
@ -155,7 +157,7 @@ async def async_request_trt_llm(
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
@ -235,15 +237,20 @@ async def async_request_openai_completions(
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"logprobs": request_func_input.logprobs,
|
||||
"stream": True,
|
||||
"ignore_eos": request_func_input.ignore_eos,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
@ -254,7 +261,6 @@ async def async_request_openai_completions(
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
@ -269,15 +275,16 @@ async def async_request_openai_completions(
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
if chunk != "[DONE]":
|
||||
data = json.loads(chunk)
|
||||
|
||||
# NOTE: Some completion API might have a last
|
||||
# usage summary response without a token so we
|
||||
# want to check a token was generated
|
||||
if data["choices"][0]["text"]:
|
||||
if choices := data.get("choices"):
|
||||
# Note that text could be empty here
|
||||
# e.g. for special tokens
|
||||
text = choices[0].get("text")
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if not first_chunk_received:
|
||||
@ -291,7 +298,10 @@ async def async_request_openai_completions(
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += data["choices"][0]["text"]
|
||||
generated_text += text or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
if first_chunk_received:
|
||||
output.success = True
|
||||
else:
|
||||
@ -300,7 +310,7 @@ async def async_request_openai_completions(
|
||||
"Never received a valid chunk to calculate TTFT."
|
||||
"This response will be marked as failed!")
|
||||
output.generated_text = generated_text
|
||||
output.latency = latency
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
@ -328,7 +338,8 @@ async def async_request_openai_chat_completions(
|
||||
if request_func_input.multi_modal_content:
|
||||
content.append(request_func_input.multi_modal_content)
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@ -338,8 +349,12 @@ async def async_request_openai_chat_completions(
|
||||
"temperature": 0.0,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"ignore_eos": request_func_input.ignore_eos,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
@ -365,17 +380,15 @@ async def async_request_openai_chat_completions(
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
delta = data["choices"][0]["delta"]
|
||||
if delta.get("content", None):
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
@ -383,13 +396,16 @@ async def async_request_openai_chat_completions(
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
generated_text += delta["content"]
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
|
||||
@ -25,6 +25,7 @@ On the client side, run:
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
@ -423,7 +424,7 @@ def calculate_metrics(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[float],
|
||||
gootput_config_dict: Dict[str, float],
|
||||
goodput_config_dict: Dict[str, float],
|
||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
||||
actual_output_lens: List[int] = []
|
||||
total_input = 0
|
||||
@ -436,19 +437,23 @@ def calculate_metrics(
|
||||
e2els: List[float] = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
# We use the tokenizer to count the number of output tokens for all
|
||||
# serving backends instead of looking at len(outputs[i].itl) since
|
||||
# multiple output tokens may be bundled together
|
||||
# Note : this may inflate the output token count slightly
|
||||
output_len = len(
|
||||
tokenizer(outputs[i].generated_text,
|
||||
add_special_tokens=False).input_ids)
|
||||
output_len = outputs[i].output_tokens
|
||||
|
||||
if output_len is None:
|
||||
# We use the tokenizer to count the number of output tokens
|
||||
# for some serving backends instead of looking at
|
||||
# len(outputs[i].itl) since multiple output tokens may be
|
||||
# bundled together
|
||||
# Note : this may inflate the output token count slightly
|
||||
output_len = len(
|
||||
tokenizer(outputs[i].generated_text,
|
||||
add_special_tokens=False).input_ids)
|
||||
actual_output_lens.append(output_len)
|
||||
total_input += input_requests[i][1]
|
||||
tpot = 0
|
||||
if output_len > 1:
|
||||
tpot = (outputs[i].latency - outputs[i].ttft) / (output_len -
|
||||
1)
|
||||
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
|
||||
tpot = latency_minus_ttft / (output_len - 1)
|
||||
tpots.append(tpot)
|
||||
# Note: if output_len <= 1, we regard tpot as 0 for goodput
|
||||
all_tpots.append(tpot)
|
||||
@ -459,21 +464,21 @@ def calculate_metrics(
|
||||
else:
|
||||
actual_output_lens.append(0)
|
||||
|
||||
if gootput_config_dict:
|
||||
if goodput_config_dict:
|
||||
valid_metrics = []
|
||||
slo_values = []
|
||||
|
||||
if "ttft" in gootput_config_dict:
|
||||
if "ttft" in goodput_config_dict:
|
||||
valid_metrics.append(ttfts)
|
||||
slo_values.append(gootput_config_dict["ttft"] /
|
||||
slo_values.append(goodput_config_dict["ttft"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
if "tpot" in gootput_config_dict:
|
||||
if "tpot" in goodput_config_dict:
|
||||
valid_metrics.append(all_tpots)
|
||||
slo_values.append(gootput_config_dict["tpot"] /
|
||||
slo_values.append(goodput_config_dict["tpot"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
if "e2el" in gootput_config_dict:
|
||||
if "e2el" in goodput_config_dict:
|
||||
valid_metrics.append(e2els)
|
||||
slo_values.append(gootput_config_dict["e2el"] /
|
||||
slo_values.append(goodput_config_dict["e2el"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
|
||||
for req_metric in zip(*valid_metrics):
|
||||
@ -525,6 +530,7 @@ async def benchmark(
|
||||
api_url: str,
|
||||
base_url: str,
|
||||
model_id: str,
|
||||
model_name: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
logprobs: Optional[int],
|
||||
@ -536,7 +542,7 @@ async def benchmark(
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[str],
|
||||
ignore_eos: bool,
|
||||
gootput_config_dict: Dict[str, float],
|
||||
goodput_config_dict: Dict[str, float],
|
||||
max_concurrency: Optional[int],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
@ -553,6 +559,7 @@ async def benchmark(
|
||||
"Multi-modal content is only supported on 'openai-chat' backend.")
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
model_name=model_name,
|
||||
prompt=test_prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=test_prompt_len,
|
||||
@ -573,6 +580,7 @@ async def benchmark(
|
||||
if profile:
|
||||
print("Starting profiler...")
|
||||
profile_input = RequestFuncInput(model=model_id,
|
||||
model_name=model_name,
|
||||
prompt=test_prompt,
|
||||
api_url=base_url + "/start_profile",
|
||||
prompt_len=test_prompt_len,
|
||||
@ -616,6 +624,7 @@ async def benchmark(
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
prompt, prompt_len, output_len, mm_content = request
|
||||
request_func_input = RequestFuncInput(model=model_id,
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
@ -657,7 +666,7 @@ async def benchmark(
|
||||
tokenizer=tokenizer,
|
||||
selected_percentile_metrics=selected_percentile_metrics,
|
||||
selected_percentiles=selected_percentiles,
|
||||
gootput_config_dict=gootput_config_dict,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
)
|
||||
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
@ -669,7 +678,7 @@ async def benchmark(
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
if gootput_config_dict:
|
||||
if goodput_config_dict:
|
||||
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
|
||||
metrics.request_goodput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
@ -684,7 +693,7 @@ async def benchmark(
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"request_goodput:":
|
||||
metrics.request_goodput if gootput_config_dict else None,
|
||||
metrics.request_goodput if goodput_config_dict else None,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
@ -740,11 +749,11 @@ async def benchmark(
|
||||
|
||||
def check_goodput_args(args):
|
||||
# Check and parse goodput arguments
|
||||
gootput_config_dict = {}
|
||||
goodput_config_dict = {}
|
||||
VALID_NAMES = ["ttft", "tpot", "e2el"]
|
||||
if args.goodput:
|
||||
gootput_config_dict = parse_goodput(args.goodput)
|
||||
for slo_name, slo_val in gootput_config_dict.items():
|
||||
goodput_config_dict = parse_goodput(args.goodput)
|
||||
for slo_name, slo_val in goodput_config_dict.items():
|
||||
if slo_name not in VALID_NAMES:
|
||||
raise ValueError(
|
||||
f"Invalid metric name found, {slo_name}: {slo_val}. "
|
||||
@ -755,22 +764,22 @@ def check_goodput_args(args):
|
||||
f"Invalid value found, {slo_name}: {slo_val}. "
|
||||
"The service level objective value should be "
|
||||
"non-negative.")
|
||||
return gootput_config_dict
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
def parse_goodput(slo_pairs):
|
||||
gootput_config_dict = {}
|
||||
goodput_config_dict = {}
|
||||
try:
|
||||
for slo_pair in slo_pairs:
|
||||
slo_name, slo_val = slo_pair.split(":")
|
||||
gootput_config_dict[slo_name] = float(slo_val)
|
||||
goodput_config_dict[slo_name] = float(slo_val)
|
||||
except ValueError as err:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format found for service level objectives. "
|
||||
"Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||
"pairs, where the key is a metric name, and the value is a "
|
||||
"number in milliseconds.") from err
|
||||
return gootput_config_dict
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
@ -780,6 +789,7 @@ def main(args: argparse.Namespace):
|
||||
|
||||
backend = args.backend
|
||||
model_id = args.model
|
||||
model_name = args.served_model_name
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
tokenizer_mode = args.tokenizer_mode
|
||||
|
||||
@ -869,7 +879,11 @@ def main(args: argparse.Namespace):
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
|
||||
gootput_config_dict = check_goodput_args(args)
|
||||
goodput_config_dict = check_goodput_args(args)
|
||||
|
||||
# Avoid GC processing "static" data - reduce pause times.
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
|
||||
benchmark_result = asyncio.run(
|
||||
benchmark(
|
||||
@ -877,6 +891,7 @@ def main(args: argparse.Namespace):
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
model_name=model_name,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
logprobs=args.logprobs,
|
||||
@ -890,7 +905,7 @@ def main(args: argparse.Namespace):
|
||||
float(p) for p in args.metric_percentiles.split(",")
|
||||
],
|
||||
ignore_eos=args.ignore_eos,
|
||||
gootput_config_dict=gootput_config_dict,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
max_concurrency=args.max_concurrency,
|
||||
))
|
||||
|
||||
@ -1222,5 +1237,12 @@ if __name__ == "__main__":
|
||||
'always use the slow tokenizer. \n* '
|
||||
'"mistral" will always use the `mistral_common` tokenizer.')
|
||||
|
||||
parser.add_argument("--served-model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name used in the API. "
|
||||
"If not specified, the model name will be the "
|
||||
"same as the ``--model`` argument. ")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
1147
benchmarks/kernels/benchmark_lora.py
Normal file
1147
benchmarks/kernels/benchmark_lora.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import time
|
||||
from datetime import datetime
|
||||
from itertools import product
|
||||
from typing import Any, Dict, List, Tuple, TypedDict
|
||||
|
||||
import ray
|
||||
@ -13,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
|
||||
) else torch.float8_e4m3fn
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
BLOCK_SIZE_M: int
|
||||
@ -80,8 +84,8 @@ def benchmark_config(
|
||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||
|
||||
w1 = w1.to(torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fn)
|
||||
w1 = w1.to(FP8_DTYPE)
|
||||
w2 = w2.to(FP8_DTYPE)
|
||||
|
||||
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
@ -141,28 +145,172 @@ def benchmark_config(
|
||||
return avg
|
||||
|
||||
|
||||
def get_configs_compute_bound() -> List[Dict[str, int]]:
|
||||
# Reduced search space for faster tuning.
|
||||
# TODO(woosuk): Increase the search space and use a performance model to
|
||||
# prune the search space.
|
||||
def get_rocm_tuning_space(use_fp16):
|
||||
block_mn_range = [16, 32, 64, 128, 256]
|
||||
block_k_range = [16, 32, 64, 128, 256]
|
||||
if not use_fp16:
|
||||
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
|
||||
num_warps_range = [1, 2, 4, 8]
|
||||
group_m_range = [1, 4, 8, 16, 32]
|
||||
num_stage_range = [2]
|
||||
waves_per_eu_range = [0]
|
||||
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
|
||||
kpack_range = [1, 2] if use_fp16 else []
|
||||
|
||||
param_ranges = {
|
||||
"BLOCK_SIZE_M": block_mn_range,
|
||||
"BLOCK_SIZE_N": block_mn_range,
|
||||
"BLOCK_SIZE_K": block_k_range,
|
||||
"GROUP_SIZE_M": group_m_range,
|
||||
"num_warps": num_warps_range,
|
||||
"num_stages": num_stage_range,
|
||||
"waves_per_eu": waves_per_eu_range,
|
||||
}
|
||||
if use_fp16:
|
||||
param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
|
||||
param_ranges["kpack"] = kpack_range
|
||||
|
||||
return param_ranges
|
||||
|
||||
|
||||
def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
|
||||
configs: List[BenchmarkConfig] = []
|
||||
for num_stages in [2, 3, 4, 5]:
|
||||
for block_m in [16, 32, 64, 128, 256]:
|
||||
for block_k in [64, 128, 256]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 16, 32, 64]:
|
||||
configs.append({
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
})
|
||||
|
||||
if current_platform.is_rocm():
|
||||
param_ranges = get_rocm_tuning_space(use_fp16)
|
||||
else:
|
||||
# Reduced search space for faster tuning.
|
||||
# TODO(woosuk): Increase the search space and use a performance model to
|
||||
# prune the search space.
|
||||
block_m_range = [16, 32, 64, 128, 256]
|
||||
block_n_range = [32, 64, 128, 256]
|
||||
block_k_range = [64, 128, 256]
|
||||
num_warps_range = [4, 8]
|
||||
group_m_range = [1, 16, 32, 64]
|
||||
num_stage_range = [2, 3, 4, 5]
|
||||
|
||||
param_ranges = {
|
||||
"BLOCK_SIZE_M": block_m_range,
|
||||
"BLOCK_SIZE_N": block_n_range,
|
||||
"BLOCK_SIZE_K": block_k_range,
|
||||
"GROUP_SIZE_M": group_m_range,
|
||||
"num_warps": num_warps_range,
|
||||
"num_stages": num_stage_range,
|
||||
}
|
||||
|
||||
keys, values = zip(*param_ranges.items())
|
||||
for config_values in product(*values):
|
||||
config = dict(zip(keys, config_values))
|
||||
configs.append(config)
|
||||
return configs
|
||||
|
||||
|
||||
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
|
||||
search_space, is_fp16):
|
||||
N1, K1 = shard_intermediate_size, hidden_size
|
||||
N2, K2 = hidden_size, shard_intermediate_size // 2
|
||||
pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
|
||||
is_fp16)
|
||||
pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
|
||||
is_fp16)
|
||||
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
|
||||
return search_space
|
||||
|
||||
|
||||
# The following code is inspired by ROCm/Triton GEMM tuning script:
|
||||
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
|
||||
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
|
||||
pruned_configs = []
|
||||
elemBytes_a = 2 if is_fp16 else 1
|
||||
elemBytes_b = 2 if is_fp16 else 1
|
||||
|
||||
mfma = 16 if M < 32 or N < 32 else 32
|
||||
|
||||
# TODO (zhanglx): figure out the boundary between large and small gemms
|
||||
large_gemm = False
|
||||
if M >= 2048 and N >= 2048:
|
||||
large_gemm = True
|
||||
|
||||
for config in configs:
|
||||
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
|
||||
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
|
||||
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
|
||||
num_warps = config.get("num_warps")
|
||||
|
||||
if is_fp16:
|
||||
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
|
||||
if matrix_instr_nonkdim > mfma:
|
||||
continue
|
||||
if mfma == 4 and BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
# some layouts could not work properly in case
|
||||
# number elements per thread is less 1
|
||||
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
|
||||
continue
|
||||
SPLIT_K = config.get("SPLIT_K", 1)
|
||||
GROUP_M = config.get("GROUP_SIZE_M")
|
||||
if is_fp16:
|
||||
if (matrix_instr_nonkdim > BLOCK_SIZE_M
|
||||
or matrix_instr_nonkdim > BLOCK_SIZE_N):
|
||||
continue
|
||||
if (matrix_instr_nonkdim >= M
|
||||
and matrix_instr_nonkdim != BLOCK_SIZE_M):
|
||||
continue
|
||||
if (matrix_instr_nonkdim >= N
|
||||
and matrix_instr_nonkdim != BLOCK_SIZE_N):
|
||||
continue
|
||||
# Skip BLOCK_SIZE that is too large compare to M/N
|
||||
# unless BLOCK_SIZE is already small enough
|
||||
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
|
||||
continue
|
||||
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
|
||||
continue
|
||||
# skip large split_k when not necessary
|
||||
if SPLIT_K != 1 and not need_split_k(M, N, K):
|
||||
continue
|
||||
# skip split_k that leads to EVEN_K = false
|
||||
leap = SPLIT_K * BLOCK_SIZE_K
|
||||
modv = K % leap
|
||||
if modv != 0:
|
||||
continue
|
||||
# skip large GROUP_M
|
||||
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
|
||||
continue
|
||||
# out of shared memory resource
|
||||
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
||||
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
|
||||
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
|
||||
if LDS > 65536:
|
||||
continue
|
||||
# Skip small block sizes and num_warps for large gemm
|
||||
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
|
||||
if large_gemm:
|
||||
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
|
||||
continue
|
||||
if BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
if num_warps < 4:
|
||||
continue
|
||||
|
||||
pruned_configs.append(config)
|
||||
|
||||
return pruned_configs
|
||||
|
||||
|
||||
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
|
||||
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
|
||||
|
||||
|
||||
def merge_unique_dicts(list1, list2):
|
||||
result = []
|
||||
combined_list = list1.copy()
|
||||
combined_list.extend(list2)
|
||||
for dictionary in combined_list:
|
||||
if dictionary not in result:
|
||||
result.append(dictionary)
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class BenchmarkWorker:
|
||||
|
||||
@ -170,6 +318,10 @@ class BenchmarkWorker:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(seed)
|
||||
self.seed = seed
|
||||
# Get the device ID to allocate tensors and kernels
|
||||
# on the respective GPU. This is required for Ray to work
|
||||
# correctly with multi-GPU tuning on the ROCm platform.
|
||||
self.device_id = int(ray.get_gpu_ids()[0])
|
||||
|
||||
def benchmark(
|
||||
self,
|
||||
@ -217,25 +369,33 @@ class BenchmarkWorker:
|
||||
) -> Dict[str, int]:
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=10)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
if current_platform.is_rocm():
|
||||
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||
search_space = prune_rocm_search_space(num_tokens,
|
||||
shard_intermediate_size,
|
||||
hidden_size, search_space,
|
||||
is_fp16)
|
||||
|
||||
if kernel_time < best_time:
|
||||
best_time = kernel_time
|
||||
best_config = config
|
||||
with torch.cuda.device(self.device_id):
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=20)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
|
||||
if kernel_time < best_time:
|
||||
best_time = kernel_time
|
||||
best_config = config
|
||||
now = datetime.now()
|
||||
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
|
||||
assert best_config is not None
|
||||
@ -244,12 +404,27 @@ class BenchmarkWorker:
|
||||
|
||||
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||
return {
|
||||
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
|
||||
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
|
||||
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
|
||||
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
|
||||
"num_warps": config["num_warps"],
|
||||
"num_stages": config["num_stages"],
|
||||
"BLOCK_SIZE_M":
|
||||
config["BLOCK_SIZE_M"],
|
||||
"BLOCK_SIZE_N":
|
||||
config["BLOCK_SIZE_N"],
|
||||
"BLOCK_SIZE_K":
|
||||
config["BLOCK_SIZE_K"],
|
||||
"GROUP_SIZE_M":
|
||||
config["GROUP_SIZE_M"],
|
||||
"num_warps":
|
||||
config["num_warps"],
|
||||
"num_stages":
|
||||
config["num_stages"],
|
||||
**({
|
||||
"waves_per_eu": config["waves_per_eu"]
|
||||
} if "waves_per_eu" in config else {}),
|
||||
**({
|
||||
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
|
||||
} if "matrix_instr_nonkdim" in config else {}),
|
||||
**({
|
||||
"kpack": config["kpack"]
|
||||
} if "kpack" in config else {}),
|
||||
}
|
||||
|
||||
|
||||
@ -294,7 +469,7 @@ def main(args: argparse.Namespace):
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
dtype = config.torch_dtype
|
||||
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
|
||||
@ -322,7 +497,8 @@ def main(args: argparse.Namespace):
|
||||
return ray.get(outputs)
|
||||
|
||||
if args.tune:
|
||||
search_space = get_configs_compute_bound()
|
||||
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||
search_space = get_configs_compute_bound(is_fp16)
|
||||
print(f"Start tuning over {len(search_space)} configurations...")
|
||||
|
||||
start = time.time()
|
||||
|
||||
@ -98,7 +98,9 @@ def main(
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = 1.0
|
||||
k_scale = v_scale = torch.tensor(1.0,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
|
||||
210
benchmarks/kernels/utils.py
Normal file
210
benchmarks/kernels/utils.py
Normal file
@ -0,0 +1,210 @@
|
||||
import dataclasses
|
||||
from typing import Any, Callable, Iterable, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CudaGraphBenchParams:
|
||||
num_ops_in_cuda_graph: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ArgPool:
|
||||
"""
|
||||
When some argument of the benchmarking function is annotated with this type,
|
||||
the benchmarking class (BenchMM) will collapse the argument to a pick a
|
||||
single value from the given list of values, during function invocation.
|
||||
For every invocation during a benchmarking run, it will choose a
|
||||
different value from the list.
|
||||
"""
|
||||
values: Iterable[Any]
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.values[index]
|
||||
|
||||
|
||||
class Bench:
|
||||
|
||||
class ArgsIterator:
|
||||
|
||||
def __init__(self, args_list, kwargs_list):
|
||||
assert len(args_list) == len(kwargs_list)
|
||||
self.args_list = args_list
|
||||
self.kwargs_list = kwargs_list
|
||||
self.n = len(self.args_list)
|
||||
self.idx = 0
|
||||
|
||||
def __next__(self):
|
||||
while True:
|
||||
yield (self.args_list[self.idx], self.kwargs_list[self.idx])
|
||||
self.idx += 1
|
||||
self.idx = self.idx % self.n
|
||||
|
||||
def reset(self):
|
||||
self.idx = 0
|
||||
|
||||
@property
|
||||
def n_args(self):
|
||||
return self.n
|
||||
|
||||
def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams],
|
||||
label: str, sub_label: str, description: str, fn: Callable,
|
||||
*args, **kwargs):
|
||||
|
||||
self.cuda_graph_params = cuda_graph_params
|
||||
self.use_cuda_graph = self.cuda_graph_params is not None
|
||||
self.label = label
|
||||
self.sub_label = sub_label
|
||||
self.description = description
|
||||
self.fn = fn
|
||||
|
||||
# Process args
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
self.args_list, self.kwargs_list = self.collapse_argpool(
|
||||
*args, **kwargs)
|
||||
self.args_iterator = self.ArgsIterator(self.args_list,
|
||||
self.kwargs_list)
|
||||
|
||||
# Cudagraph runner
|
||||
self.g = None
|
||||
if self.use_cuda_graph:
|
||||
self.g = self.get_cuda_graph_runner()
|
||||
|
||||
# benchmark run params
|
||||
self.min_run_time = 1
|
||||
|
||||
def collapse_argpool(self, *args, **kwargs):
|
||||
argpool_args = [arg for arg in args if isinstance(arg, ArgPool)] + [
|
||||
arg for arg in kwargs.values() if isinstance(arg, ArgPool)
|
||||
]
|
||||
if len(argpool_args) == 0:
|
||||
return [args], [kwargs]
|
||||
|
||||
# Make sure all argpools are of the same size
|
||||
argpool_size = len(argpool_args[0].values)
|
||||
assert all([argpool_size == len(arg.values) for arg in argpool_args])
|
||||
|
||||
# create copies of the args
|
||||
args_list = []
|
||||
kwargs_list = []
|
||||
for _ in range(argpool_size):
|
||||
args_list.append(args)
|
||||
kwargs_list.append(kwargs.copy())
|
||||
|
||||
for i in range(argpool_size):
|
||||
# collapse args; Just pick the ith value
|
||||
args_list[i] = tuple([
|
||||
arg[i] if isinstance(arg, ArgPool) else arg
|
||||
for arg in args_list[i]
|
||||
])
|
||||
|
||||
# collapse kwargs
|
||||
kwargs_i = kwargs_list[i]
|
||||
arg_pool_keys = [
|
||||
k for k, v in kwargs_i.items() if isinstance(v, ArgPool)
|
||||
]
|
||||
for k in arg_pool_keys:
|
||||
# again just pick the ith value
|
||||
kwargs_i[k] = kwargs_i[k][i]
|
||||
kwargs_list[i] = kwargs_i
|
||||
|
||||
return args_list, kwargs_list
|
||||
|
||||
def get_cuda_graph_runner(self):
|
||||
assert self.use_cuda_graph
|
||||
assert self.args_iterator is not None
|
||||
|
||||
num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph
|
||||
|
||||
# warmup
|
||||
args_it = self.args_iterator.__next__()
|
||||
for _ in range(2):
|
||||
args, kwargs = next(args_it)
|
||||
self.fn(*args, **kwargs)
|
||||
|
||||
self.args_iterator.reset()
|
||||
args_it = self.args_iterator.__next__()
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
for _ in range(num_graph_ops):
|
||||
args, kwargs = next(args_it)
|
||||
self.fn(*args, **kwargs)
|
||||
return g
|
||||
|
||||
def run_cudagrah(self) -> TMeasurement:
|
||||
assert self.use_cuda_graph
|
||||
globals = {'g': self.g}
|
||||
|
||||
return TBenchmark.Timer(
|
||||
stmt="g.replay()",
|
||||
globals=globals,
|
||||
label=(
|
||||
f"{self.label}"
|
||||
f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops"
|
||||
),
|
||||
sub_label=self.sub_label,
|
||||
description=self.description,
|
||||
).blocked_autorange(min_run_time=self.min_run_time)
|
||||
|
||||
def run_eager(self) -> TMeasurement:
|
||||
setup = None
|
||||
stmt = None
|
||||
globals = None
|
||||
|
||||
has_arg_pool = self.args_iterator.n_args > 1
|
||||
if has_arg_pool:
|
||||
setup = '''
|
||||
args_iterator.reset()
|
||||
args_it = args_iterator.__next__()
|
||||
'''
|
||||
stmt = '''
|
||||
args, kwargs = next(args_it)
|
||||
fn(*args, **kwargs)
|
||||
'''
|
||||
globals = {'fn': self.fn, 'args_iterator': self.args_iterator}
|
||||
else:
|
||||
# no arg pool. Just use the args and kwargs directly
|
||||
self.args_iterator.reset()
|
||||
args_it = self.args_iterator.__next__()
|
||||
args, kwargs = next(args_it)
|
||||
|
||||
setup = ""
|
||||
stmt = '''
|
||||
fn(*args, **kwargs)
|
||||
'''
|
||||
globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs}
|
||||
|
||||
return TBenchmark.Timer(
|
||||
stmt=stmt,
|
||||
setup=setup,
|
||||
globals=globals,
|
||||
label=self.label,
|
||||
sub_label=self.sub_label,
|
||||
description=self.description,
|
||||
).blocked_autorange(min_run_time=self.min_run_time)
|
||||
|
||||
def run(self) -> TMeasurement:
|
||||
timer = None
|
||||
if self.use_cuda_graph: # noqa SIM108
|
||||
timer = self.run_cudagrah()
|
||||
else:
|
||||
timer = self.run_eager()
|
||||
if not timer.meets_confidence() or timer.has_warnings:
|
||||
print("Doesn't meet confidence - re-running bench ...")
|
||||
return self.run()
|
||||
return timer
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if exc_type:
|
||||
print(f"exc type {exc_type}")
|
||||
print(f"exc value {exc_value}")
|
||||
print(f"exc traceback {traceback}")
|
||||
@ -58,8 +58,8 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
|
||||
#
|
||||
set(SRCS ${ORIG_SRCS})
|
||||
set(CXX_SRCS ${ORIG_SRCS})
|
||||
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
|
||||
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
|
||||
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)|(hip)$")
|
||||
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)|(hip)$")
|
||||
|
||||
#
|
||||
# Generate ROCm/HIP source file names from CUDA file names.
|
||||
|
||||
@ -105,7 +105,7 @@ __device__ void paged_attention_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const float* k_scale, const float* v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
const int seq_idx = blockIdx.y;
|
||||
@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
|
||||
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
|
||||
k_vec_quant, k_scale);
|
||||
k_vec_quant, *k_scale);
|
||||
}
|
||||
}
|
||||
|
||||
@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
|
||||
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec.
|
||||
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
|
||||
v_scale);
|
||||
*v_scale);
|
||||
}
|
||||
if (block_idx == num_seq_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the
|
||||
@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const float* k_scale, const float* v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||
@ -549,7 +549,7 @@ __global__ void paged_attention_v2_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const float* k_scale, const float* v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||
|
||||
@ -41,7 +41,7 @@
|
||||
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
|
||||
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
||||
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
|
||||
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
|
||||
blocksparse_vert_stride, blocksparse_block_size, \
|
||||
blocksparse_head_sliding_step);
|
||||
|
||||
@ -53,10 +53,10 @@ void paged_attention_v1_launcher(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -80,6 +80,8 @@ void paged_attention_v1_launcher(
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int padded_max_seq_len =
|
||||
@ -177,8 +179,9 @@ void paged_attention_v1(
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||
|
||||
@ -37,7 +37,7 @@
|
||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
|
||||
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
||||
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
||||
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
|
||||
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
|
||||
blocksparse_local_blocks, blocksparse_vert_stride, \
|
||||
blocksparse_block_size, blocksparse_head_sliding_step); \
|
||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
|
||||
@ -54,10 +54,10 @@ void paged_attention_v2_launcher(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -84,6 +84,8 @@ void paged_attention_v2_launcher(
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
|
||||
@ -188,8 +190,9 @@ void paged_attention_v2(
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||
|
||||
@ -18,15 +18,15 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale);
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale);
|
||||
|
||||
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype,
|
||||
const double k_scale, const double v_scale);
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
|
||||
@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
// block_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride, const int value_stride, const int num_heads,
|
||||
const int head_size, const int block_size, const int x, const float k_scale,
|
||||
const float v_scale) {
|
||||
const int head_size, const int block_size, const int x,
|
||||
const float* k_scale, const float* v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
if (slot_idx < 0) {
|
||||
@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
|
||||
value_cache[tgt_value_idx] = tgt_value;
|
||||
} else {
|
||||
key_cache[tgt_key_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
|
||||
value_cache[tgt_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -214,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, const int key_stride, const int value_stride,
|
||||
const int num_heads, const int head_size, const int block_size,
|
||||
const float k_scale, const float v_scale) {
|
||||
const float* k_scale, const float* v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
@ -239,9 +239,9 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
value_cache[tgt_key_value_idx] = tgt_value;
|
||||
} else {
|
||||
key_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
|
||||
value_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -258,7 +258,9 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
|
||||
num_heads, head_size, block_size, x, k_scale, v_scale);
|
||||
num_heads, head_size, block_size, x, \
|
||||
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
@ -268,8 +270,8 @@ void reshape_and_cache(
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale) {
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
@ -299,7 +301,9 @@ void reshape_and_cache(
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
|
||||
value_stride, num_heads, head_size, block_size, k_scale, v_scale);
|
||||
value_stride, num_heads, head_size, block_size, \
|
||||
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
@ -308,8 +312,8 @@ void reshape_and_cache_flash(
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale) {
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
|
||||
// slot_mapping.size(0) because of padding for CUDA graphs.
|
||||
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
|
||||
|
||||
@ -32,7 +32,7 @@ class ScalarType {
|
||||
signed_(signed_),
|
||||
bias(bias),
|
||||
finite_values_only(finite_values_only),
|
||||
nan_repr(nan_repr){};
|
||||
nan_repr(nan_repr) {};
|
||||
|
||||
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits - 1, true, bias);
|
||||
|
||||
@ -460,11 +460,11 @@ void paged_attention_v1(
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
||||
@ -782,11 +782,11 @@ void paged_attention_v2(
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
||||
|
||||
@ -107,10 +107,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype, double k_scale,
|
||||
double v_scale) {
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
|
||||
@ -2,13 +2,13 @@
|
||||
#define CPU_TYPES_HPP
|
||||
|
||||
#if defined(__x86_64__)
|
||||
//x86 implementation
|
||||
// x86 implementation
|
||||
#include "cpu_types_x86.hpp"
|
||||
#elif defined(__POWER9_VECTOR__)
|
||||
//ppc implementation
|
||||
// ppc implementation
|
||||
#include "cpu_types_vsx.hpp"
|
||||
#elif defined(__aarch64__)
|
||||
//arm implementation
|
||||
// arm implementation
|
||||
#include "cpu_types_arm.hpp"
|
||||
#else
|
||||
#warning "unsupported vLLM cpu implementation"
|
||||
|
||||
@ -1,48 +1,50 @@
|
||||
#include <arm_neon.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/all.h>
|
||||
#include <cmath>
|
||||
|
||||
namespace vec_op {
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
#else
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||
#endif
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) \
|
||||
std::cout << #NAME << " exit." << std::endl;
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
};
|
||||
};
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
};
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F &&f) {
|
||||
constexpr void unroll_loop(F&& f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T> struct Vec {
|
||||
template <typename T>
|
||||
struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
|
||||
};
|
||||
|
||||
@ -54,127 +56,124 @@ struct FP16Vec8 : public Vec<FP16Vec8> {
|
||||
|
||||
float16x8_t reg;
|
||||
|
||||
explicit FP16Vec8(const void *ptr)
|
||||
: reg(vld1q_f16(static_cast<const __fp16 *>(ptr))) {};
|
||||
explicit FP16Vec8(const void* ptr)
|
||||
: reg(vld1q_f16(static_cast<const __fp16*>(ptr))) {};
|
||||
|
||||
explicit FP16Vec8(const FP32Vec8 &);
|
||||
explicit FP16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const {
|
||||
vst1q_f16(static_cast<__fp16 *>(ptr), reg);
|
||||
}
|
||||
void save(void* ptr) const { vst1q_f16(static_cast<__fp16*>(ptr), reg); }
|
||||
};
|
||||
|
||||
struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
|
||||
float16x8x2_t reg;
|
||||
|
||||
explicit FP16Vec16(const void *ptr) {
|
||||
reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
|
||||
reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
|
||||
}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16& vec);
|
||||
|
||||
void save(void *ptr) const {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||
}
|
||||
|
||||
void save(void *ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / 8;
|
||||
int remainder = elem_num % 8;
|
||||
|
||||
if (full_blocks > 0) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
if (full_blocks > 1) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||
}
|
||||
}
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
|
||||
// Note: below is the unrolled version of the following code:
|
||||
//
|
||||
// for (int i = 0; i < remainder; ++i) {
|
||||
// reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] =
|
||||
// vgetq_lane_f16(temp, i);
|
||||
// }
|
||||
//
|
||||
// For macOS build (Clang), the arm/neon intrinsics function
|
||||
// `vgetq_lane_f16` needs the parameter `i` to be constant at compile
|
||||
// time.
|
||||
|
||||
if (remainder > 0) {
|
||||
float16x8_t temp = reg.val[full_blocks];
|
||||
__fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr);
|
||||
switch (remainder)
|
||||
{
|
||||
case 1:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
break;
|
||||
case 2:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
break;
|
||||
case 3:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
break;
|
||||
case 4:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
break;
|
||||
case 5:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
break;
|
||||
case 6:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
|
||||
break;
|
||||
case 7:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
|
||||
fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
float16x8x2_t reg;
|
||||
|
||||
explicit FP16Vec16(const void* ptr) {
|
||||
reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
|
||||
reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
|
||||
}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16& vec);
|
||||
|
||||
void save(void* ptr) const {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||
}
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / 8;
|
||||
int remainder = elem_num % 8;
|
||||
|
||||
if (full_blocks > 0) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
if (full_blocks > 1) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||
}
|
||||
}
|
||||
|
||||
// Note: below is the unrolled version of the following code:
|
||||
//
|
||||
// for (int i = 0; i < remainder; ++i) {
|
||||
// reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] =
|
||||
// vgetq_lane_f16(temp, i);
|
||||
// }
|
||||
//
|
||||
// For macOS build (Clang), the arm/neon intrinsics function
|
||||
// `vgetq_lane_f16` needs the parameter `i` to be constant at compile
|
||||
// time.
|
||||
|
||||
if (remainder > 0) {
|
||||
float16x8_t temp = reg.val[full_blocks];
|
||||
__fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr);
|
||||
switch (remainder) {
|
||||
case 1:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
break;
|
||||
case 2:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
break;
|
||||
case 3:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
break;
|
||||
case 4:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
break;
|
||||
case 5:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
break;
|
||||
case 6:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
|
||||
break;
|
||||
case 7:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
|
||||
fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
constexpr static int VEC_ELEM_NUM = 8;
|
||||
|
||||
bfloat16x8_t reg;
|
||||
|
||||
explicit BF16Vec8(const void *ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8_t *>(ptr)) {};
|
||||
explicit BF16Vec8(const void* ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8_t*>(ptr)) {};
|
||||
|
||||
explicit BF16Vec8(bfloat16x8_t data) : reg(data) {};
|
||||
|
||||
explicit BF16Vec8(const FP32Vec8 &);
|
||||
explicit BF16Vec8(const FP32Vec8&);
|
||||
|
||||
explicit BF16Vec8(float32x4x2_t v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
|
||||
explicit BF16Vec8(float32x4x2_t v)
|
||||
: reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<bfloat16x8_t *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8_t*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
@ -182,19 +181,18 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
bfloat16x8x2_t reg;
|
||||
|
||||
explicit BF16Vec16(const void *ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x2_t *>(ptr)) {};
|
||||
explicit BF16Vec16(const void* ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x2_t*>(ptr)) {};
|
||||
|
||||
explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16 &);
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
explicit BF16Vec16(float32x4x4_t v) : reg({
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])
|
||||
}){};
|
||||
explicit BF16Vec16(float32x4x4_t v)
|
||||
: reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<bfloat16x8x2_t *>(ptr) = reg; };
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
|
||||
};
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
@ -202,19 +200,15 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
|
||||
bfloat16x8x4_t reg;
|
||||
|
||||
explicit BF16Vec32(const void *ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x4_t *>(ptr)) {};
|
||||
explicit BF16Vec32(const void* ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x4_t*>(ptr)) {};
|
||||
|
||||
explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {};
|
||||
|
||||
explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg
|
||||
}) {};
|
||||
explicit BF16Vec32(const BF16Vec8& vec8_data)
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<bfloat16x8x4_t *>(ptr) = reg; };
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
|
||||
};
|
||||
#endif
|
||||
|
||||
@ -232,11 +226,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
|
||||
explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {};
|
||||
|
||||
explicit FP32Vec4(const float *ptr) : reg(vld1q_f32(ptr)) {};
|
||||
explicit FP32Vec4(const float* ptr) : reg(vld1q_f32(ptr)) {};
|
||||
|
||||
explicit FP32Vec4(float32x4_t data) : reg(data) {};
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {};
|
||||
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
@ -252,32 +246,37 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
|
||||
explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {};
|
||||
|
||||
explicit FP32Vec8(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
|
||||
explicit FP32Vec8(const float* ptr)
|
||||
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
|
||||
|
||||
explicit FP32Vec8(float32x4x2_t data) : reg(data) {};
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {};
|
||||
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
|
||||
|
||||
explicit FP32Vec8(const FP16Vec8 &v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
|
||||
};
|
||||
explicit FP32Vec8(const FP16Vec8& v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
|
||||
};
|
||||
|
||||
explicit FP32Vec8(float16x8_t v) : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
|
||||
explicit FP32Vec8(float16x8_t v)
|
||||
: reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
|
||||
explicit FP32Vec8(bfloat16x8_t v) : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
|
||||
explicit FP32Vec8(bfloat16x8_t v)
|
||||
: reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8 &v) : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
|
||||
explicit FP32Vec8(const BF16Vec8& v)
|
||||
: reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float answer = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
|
||||
return answer;
|
||||
}
|
||||
@ -324,10 +323,14 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
|
||||
float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])), static_cast<float32_t>(erf(ar.values[1]))};
|
||||
float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])), static_cast<float32_t>(erf(ar.values[3]))};
|
||||
float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])), static_cast<float32_t>(erf(ar.values[5]))};
|
||||
float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])), static_cast<float32_t>(erf(ar.values[7]))};
|
||||
float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])),
|
||||
static_cast<float32_t>(erf(ar.values[1]))};
|
||||
float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])),
|
||||
static_cast<float32_t>(erf(ar.values[3]))};
|
||||
float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])),
|
||||
static_cast<float32_t>(erf(ar.values[5]))};
|
||||
float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])),
|
||||
static_cast<float32_t>(erf(ar.values[7]))};
|
||||
|
||||
float32x4_t result0 = vcombine_f32(er_vec0, er_vec1);
|
||||
float32x4_t result1 = vcombine_f32(er_vec2, er_vec3);
|
||||
@ -337,25 +340,29 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
result.val[1] = result1;
|
||||
|
||||
return FP32Vec8(result);
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), vmulq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), vaddq_f32(reg.val[1], b.reg.val[1])}));
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]),
|
||||
vmulq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), vsubq_f32(reg.val[1], b.reg.val[1])}));
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]), vdivq_f32(reg.val[1], b.reg.val[1])}));
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]),
|
||||
vsubq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]),
|
||||
vdivq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
void save(float* ptr) const {
|
||||
vst1q_f32(ptr, reg.val[0]);
|
||||
vst1q_f32(ptr + 4, reg.val[1]);
|
||||
}
|
||||
@ -370,103 +377,100 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
float32x4x4_t reg;
|
||||
|
||||
explicit FP32Vec16(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
|
||||
explicit FP32Vec16(float v)
|
||||
: reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
|
||||
|
||||
explicit FP32Vec16() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {}
|
||||
explicit FP32Vec16()
|
||||
: reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0),
|
||||
vmovq_n_f32(0.0)}) {}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8), vld1q_f32(ptr + 12)}) {}
|
||||
explicit FP32Vec16(const float* ptr)
|
||||
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8),
|
||||
vld1q_f32(ptr + 12)}) {}
|
||||
|
||||
explicit FP32Vec16(float32x4x4_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[0];
|
||||
reg.val[3] = data.reg.val[1];
|
||||
explicit FP32Vec16(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[0];
|
||||
reg.val[3] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v.reg)) {}
|
||||
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v.reg)) {}
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(bfloat16x8x2_t v) : reg({
|
||||
vcvtq_low_f32_bf16(v.val[0]),
|
||||
vcvtq_high_f32_bf16(v.val[0]),
|
||||
vcvtq_low_f32_bf16(v.val[1]),
|
||||
vcvtq_high_f32_bf16(v.val[1])
|
||||
}) {};
|
||||
#endif
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(bfloat16x8x2_t v)
|
||||
: reg({vcvtq_low_f32_bf16(v.val[0]), vcvtq_high_f32_bf16(v.val[0]),
|
||||
vcvtq_low_f32_bf16(v.val[1]), vcvtq_high_f32_bf16(v.val[1])}) {};
|
||||
#endif
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data) {
|
||||
explicit FP32Vec16(const FP32Vec4& data) {
|
||||
reg.val[0] = data.reg;
|
||||
reg.val[1] = data.reg;
|
||||
reg.val[2] = data.reg;
|
||||
reg.val[3] = data.reg;
|
||||
};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(const BF16Vec16 &v) : reg({
|
||||
vcvtq_low_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_low_f32_bf16(v.reg.val[1]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[1])
|
||||
}) {};
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(const BF16Vec16& v)
|
||||
: reg({vcvtq_low_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_low_f32_bf16(v.reg.val[1]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[1])}) {};
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {};
|
||||
#endif
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
|
||||
#endif
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16 &v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
|
||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||
explicit FP32Vec16(const FP16Vec16& v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
|
||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||
};
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||
vaddq_f32(reg.val[2], b.reg.val[2]),
|
||||
vaddq_f32(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||
vaddq_f32(reg.val[2], b.reg.val[2]),
|
||||
vaddq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vmulq_f32(reg.val[0], b.reg.val[0]),
|
||||
vmulq_f32(reg.val[1], b.reg.val[1]),
|
||||
vmulq_f32(reg.val[2], b.reg.val[2]),
|
||||
vmulq_f32(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vmulq_f32(reg.val[0], b.reg.val[0]),
|
||||
vmulq_f32(reg.val[1], b.reg.val[1]),
|
||||
vmulq_f32(reg.val[2], b.reg.val[2]),
|
||||
vmulq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vsubq_f32(reg.val[0], b.reg.val[0]),
|
||||
vsubq_f32(reg.val[1], b.reg.val[1]),
|
||||
vsubq_f32(reg.val[2], b.reg.val[2]),
|
||||
vsubq_f32(reg.val[3], b.reg.val[3])
|
||||
}));
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vsubq_f32(reg.val[0], b.reg.val[0]),
|
||||
vsubq_f32(reg.val[1], b.reg.val[1]),
|
||||
vsubq_f32(reg.val[2], b.reg.val[2]),
|
||||
vsubq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vdivq_f32(reg.val[0], b.reg.val[0]),
|
||||
vdivq_f32(reg.val[1], b.reg.val[1]),
|
||||
vdivq_f32(reg.val[2], b.reg.val[2]),
|
||||
vdivq_f32(reg.val[3], b.reg.val[3])
|
||||
}));
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vdivq_f32(reg.val[0], b.reg.val[0]),
|
||||
vdivq_f32(reg.val[1], b.reg.val[1]),
|
||||
vdivq_f32(reg.val[2], b.reg.val[2]),
|
||||
vdivq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float answer = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
|
||||
return answer;
|
||||
};
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
|
||||
AliasReg ar;
|
||||
@ -479,7 +483,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return answer;
|
||||
};
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
vst1q_f32(ptr, reg.val[0]);
|
||||
vst1q_f32(ptr + 4, reg.val[1]);
|
||||
vst1q_f32(ptr + 8, reg.val[2]);
|
||||
@ -487,43 +491,59 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
template <typename T>
|
||||
struct VecType {
|
||||
using vec_type = void;
|
||||
};
|
||||
|
||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||
template <typename T>
|
||||
using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||
template <>
|
||||
struct VecType<float> {
|
||||
using vec_type = FP32Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::Half> {
|
||||
using vec_type = FP16Vec8;
|
||||
};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
||||
|
||||
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
|
||||
*reinterpret_cast<__fp16 *>(ptr) = v;
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
*ptr = v;
|
||||
}
|
||||
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) {
|
||||
float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
|
||||
float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
|
||||
float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]);
|
||||
template <>
|
||||
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
|
||||
*reinterpret_cast<__fp16*>(ptr) = v;
|
||||
}
|
||||
|
||||
reg.val[0] = vcombine_f16(low_0, high_0);
|
||||
reg.val[1] = vcombine_f16(low_1, high_1);
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
|
||||
float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
|
||||
float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
|
||||
float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]);
|
||||
|
||||
reg.val[0] = vcombine_f16(low_0, high_0);
|
||||
reg.val[1] = vcombine_f16(low_1, high_1);
|
||||
};
|
||||
|
||||
inline FP16Vec8 :: FP16Vec8(const FP32Vec8 &v) {
|
||||
float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);
|
||||
inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
|
||||
float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);
|
||||
|
||||
reg = vcombine_f16(lower_half, upper_half);
|
||||
reg = vcombine_f16(lower_half, upper_half);
|
||||
};
|
||||
|
||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]);
|
||||
acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]);
|
||||
acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]);
|
||||
@ -531,8 +551,7 @@ inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
|
||||
|
||||
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
|
||||
float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0]));
|
||||
float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0]));
|
||||
float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1]));
|
||||
@ -551,22 +570,22 @@ inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
|
||||
#endif
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {};
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {
|
||||
};
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) : reg({
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]), v.reg.val[3])
|
||||
}){};
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
|
||||
: reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]),
|
||||
v.reg.val[3])}) {};
|
||||
#endif
|
||||
|
||||
inline void prefetch(const void *addr) {
|
||||
__builtin_prefetch(addr, 0, 1);
|
||||
};
|
||||
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); };
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
*reinterpret_cast<__bf16 *>(ptr) = vcvth_bf16_f32(v);
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
*reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v);
|
||||
};
|
||||
#endif
|
||||
};
|
||||
}; // namespace vec_op
|
||||
@ -9,38 +9,40 @@
|
||||
namespace vec_op {
|
||||
|
||||
// FIXME: FP16 is not fully supported in Torch-CPU
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) \
|
||||
std::cout << #NAME << " exit." << std::endl;
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
}
|
||||
}; // namespace
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F &&f) {
|
||||
constexpr void unroll_loop(F&& f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T> struct Vec {
|
||||
template <typename T>
|
||||
struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
|
||||
};
|
||||
|
||||
@ -68,12 +70,14 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
|
||||
__vector signed short reg;
|
||||
|
||||
explicit BF16Vec8(const void *ptr)
|
||||
: reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {}
|
||||
explicit BF16Vec8(const void* ptr)
|
||||
: reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {}
|
||||
|
||||
explicit BF16Vec8(const FP32Vec8 &);
|
||||
explicit BF16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; }
|
||||
void save(void* ptr) const {
|
||||
*reinterpret_cast<__vector signed short*>(ptr) = reg;
|
||||
}
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
@ -81,18 +85,18 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
ss16x8x2_t reg;
|
||||
|
||||
explicit BF16Vec16(const void *ptr) {
|
||||
explicit BF16Vec16(const void* ptr) {
|
||||
// Load 256 bits in two parts
|
||||
reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr);
|
||||
reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr);
|
||||
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
|
||||
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
|
||||
}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16 &);
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void *ptr) const {
|
||||
void save(void* ptr) const {
|
||||
// Save 256 bits in two parts
|
||||
vec_xst(reg.val[0], 0, (signed short *)ptr);
|
||||
vec_xst(reg.val[1], 16, (signed short *)ptr);
|
||||
vec_xst(reg.val[0], 0, (signed short*)ptr);
|
||||
vec_xst(reg.val[1], 16, (signed short*)ptr);
|
||||
}
|
||||
};
|
||||
|
||||
@ -102,19 +106,15 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
constexpr static int VEC_ELEM_NUM = 32;
|
||||
|
||||
ss16x8x4_t reg;
|
||||
explicit BF16Vec32(const void *ptr)
|
||||
: reg(*reinterpret_cast<const ss16x8x4_t *>(ptr)) {}
|
||||
explicit BF16Vec32(const void* ptr)
|
||||
: reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
|
||||
|
||||
explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
|
||||
|
||||
explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg
|
||||
}) {}
|
||||
explicit BF16Vec32(const BF16Vec8& vec8_data)
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<ss16x8x4_t *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
@ -130,11 +130,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
|
||||
explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
|
||||
|
||||
explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {}
|
||||
explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
|
||||
|
||||
explicit FP32Vec4(__vector float data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
@ -156,19 +156,19 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
reg.val[1] = vec_splats(0.0f);
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const float *ptr) {
|
||||
explicit FP32Vec8(const float* ptr) {
|
||||
reg.val[0] = vec_xl(0, ptr);
|
||||
reg.val[1] = vec_xl(16, ptr);
|
||||
}
|
||||
|
||||
explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8 &data) {
|
||||
explicit FP32Vec8(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8 &v) {
|
||||
explicit FP32Vec8(const BF16Vec8& v) {
|
||||
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
|
||||
reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
|
||||
}
|
||||
@ -177,7 +177,8 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -230,23 +231,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
vec_xst(reg.val[0], 0, ptr);
|
||||
vec_xst(reg.val[1], 16, ptr);
|
||||
}
|
||||
@ -275,7 +280,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg.val[3] = vec_splats(0.0f);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) {
|
||||
explicit FP32Vec16(const float* ptr) {
|
||||
reg.val[0] = vec_xl(0, ptr);
|
||||
reg.val[1] = vec_xl(16, ptr);
|
||||
reg.val[2] = vec_xl(32, ptr);
|
||||
@ -284,78 +289,76 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16 &data) {
|
||||
explicit FP32Vec16(const FP32Vec16& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[2];
|
||||
reg.val[3] = data.reg.val[3];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data) {
|
||||
explicit FP32Vec16(const FP32Vec4& data) {
|
||||
reg.val[0] = data.reg;
|
||||
reg.val[1] = data.reg;
|
||||
reg.val[2] = data.reg;
|
||||
reg.val[3] = data.reg;
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data) {
|
||||
explicit FP32Vec16(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[0];
|
||||
reg.val[3] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16 &v) {
|
||||
explicit FP32Vec16(const BF16Vec16& v) {
|
||||
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
|
||||
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
|
||||
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
|
||||
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_mul(reg.val[0], b.reg.val[0]),
|
||||
vec_mul(reg.val[1], b.reg.val[1]),
|
||||
vec_mul(reg.val[2], b.reg.val[2]),
|
||||
vec_mul(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
|
||||
vec_mul(reg.val[1], b.reg.val[1]),
|
||||
vec_mul(reg.val[2], b.reg.val[2]),
|
||||
vec_mul(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_add(reg.val[0], b.reg.val[0]),
|
||||
vec_add(reg.val[1], b.reg.val[1]),
|
||||
vec_add(reg.val[2], b.reg.val[2]),
|
||||
vec_add(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
|
||||
vec_add(reg.val[1], b.reg.val[1]),
|
||||
vec_add(reg.val[2], b.reg.val[2]),
|
||||
vec_add(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_sub(reg.val[0], b.reg.val[0]),
|
||||
vec_sub(reg.val[1], b.reg.val[1]),
|
||||
vec_sub(reg.val[2], b.reg.val[2]),
|
||||
vec_sub(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
|
||||
vec_sub(reg.val[1], b.reg.val[1]),
|
||||
vec_sub(reg.val[2], b.reg.val[2]),
|
||||
vec_sub(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_div(reg.val[0], b.reg.val[0]),
|
||||
vec_div(reg.val[1], b.reg.val[1]),
|
||||
vec_div(reg.val[2], b.reg.val[2]),
|
||||
vec_div(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
|
||||
vec_div(reg.val[1], b.reg.val[1]),
|
||||
vec_div(reg.val[2], b.reg.val[2]),
|
||||
vec_div(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
|
||||
AliasReg ar;
|
||||
@ -368,7 +371,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return result;
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
vec_xst(reg.val[0], 0, ptr);
|
||||
vec_xst(reg.val[1], 16, ptr);
|
||||
vec_xst(reg.val[2], 32, ptr);
|
||||
@ -376,43 +379,62 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
template <typename T>
|
||||
struct VecType {
|
||||
using vec_type = void;
|
||||
};
|
||||
|
||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||
template <typename T>
|
||||
using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||
template <>
|
||||
struct VecType<float> {
|
||||
using vec_type = FP32Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
|
||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
*ptr = v;
|
||||
}
|
||||
|
||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
acc = acc + a * b;
|
||||
}
|
||||
|
||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
|
||||
reinterpret_cast<c10::BFloat16 *>(&v);
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
|
||||
reinterpret_cast<c10::BFloat16*>(&v);
|
||||
*ptr = *(v_ptr + 1);
|
||||
}
|
||||
|
||||
#ifndef __VEC_CLASS_FP_NAN
|
||||
#define __VEC_CLASS_FP_NAN (1 << 6)
|
||||
#define __VEC_CLASS_FP_NAN (1 << 6)
|
||||
#endif
|
||||
|
||||
const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 };
|
||||
const static __vector unsigned char omask = {0, 1, 4, 5, 8, 9, 12, 13,
|
||||
16, 17, 20, 21, 24, 25, 28, 29};
|
||||
#ifndef _ARCH_PWR10
|
||||
const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff };
|
||||
const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 };
|
||||
const static __vector unsigned int sh16 = { 16, 16, 16, 16 };
|
||||
const static __vector unsigned int one = { 1, 1, 1, 1 };
|
||||
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
|
||||
0x00007fff};
|
||||
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
|
||||
0x7fc00000};
|
||||
const static __vector unsigned int sh16 = {16, 16, 16, 16};
|
||||
const static __vector unsigned int one = {1, 1, 1, 1};
|
||||
#endif
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
|
||||
#ifdef _ARCH_PWR10
|
||||
__vector signed short ret[2];
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[1]);
|
||||
reg = vec_perm(ret[0], ret[1], omask);
|
||||
#elif defined(_ARCH_PWR9)
|
||||
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
|
||||
@ -425,8 +447,10 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
||||
__vector unsigned int rnd1 = vec_add(lsb1, bias);
|
||||
inp0 = vec_add(inp0, rnd0);
|
||||
inp1 = vec_add(inp1, rnd1);
|
||||
__vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel0 =
|
||||
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 =
|
||||
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
inp0 = vec_sel(inp0, nan, sel0);
|
||||
inp1 = vec_sel(inp1, nan, sel1);
|
||||
inp0 = vec_sr(inp0, sh16);
|
||||
@ -435,13 +459,17 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
#ifdef _ARCH_PWR10
|
||||
__vector signed short ret[4];
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
|
||||
ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]);
|
||||
ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]);
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[1]);
|
||||
ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[2]);
|
||||
ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[3]);
|
||||
reg.val[0] = vec_perm(ret[0], ret[1], omask);
|
||||
reg.val[1] = vec_perm(ret[2], ret[3], omask);
|
||||
#elif defined(_ARCH_PWR9)
|
||||
@ -465,10 +493,14 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
inp1 = vec_add(inp1, rnd1);
|
||||
inp2 = vec_add(inp2, rnd2);
|
||||
inp3 = vec_add(inp3, rnd3);
|
||||
__vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel0 =
|
||||
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 =
|
||||
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel2 =
|
||||
vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel3 =
|
||||
vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
|
||||
inp0 = vec_sel(inp0, nan, sel0);
|
||||
inp1 = vec_sel(inp1, nan, sel1);
|
||||
inp2 = vec_sel(inp2, nan, sel2);
|
||||
@ -482,10 +514,10 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void prefetch(const void *addr) {
|
||||
inline void prefetch(const void* addr) {
|
||||
__asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
|
||||
}
|
||||
|
||||
}; // namespace vec_op
|
||||
}; // namespace vec_op
|
||||
|
||||
#endif
|
||||
|
||||
@ -11,39 +11,40 @@ static_assert(false, "AVX2 must be supported for the current implementation.");
|
||||
|
||||
namespace vec_op {
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
}
|
||||
}; // namespace
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F &&f) {
|
||||
constexpr void unroll_loop(F&& f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T> struct Vec {
|
||||
template <typename T>
|
||||
struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
|
||||
};
|
||||
|
||||
@ -55,12 +56,12 @@ struct FP16Vec8 : public Vec<FP16Vec8> {
|
||||
|
||||
__m128i reg;
|
||||
|
||||
explicit FP16Vec8(const void *ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
|
||||
explicit FP16Vec8(const void* ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}
|
||||
|
||||
explicit FP16Vec8(const FP32Vec8 &);
|
||||
explicit FP16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
@ -68,12 +69,12 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
|
||||
__m256i reg;
|
||||
|
||||
explicit FP16Vec16(const void *ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
|
||||
explicit FP16Vec16(const void* ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16 &);
|
||||
explicit FP16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -87,12 +88,12 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
|
||||
__m128i reg;
|
||||
|
||||
explicit BF16Vec8(const void *ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
|
||||
explicit BF16Vec8(const void* ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}
|
||||
|
||||
explicit BF16Vec8(const FP32Vec8 &);
|
||||
explicit BF16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
@ -100,12 +101,12 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
__m256i reg;
|
||||
|
||||
explicit BF16Vec16(const void *ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
|
||||
explicit BF16Vec16(const void* ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16 &);
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -120,11 +121,11 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
|
||||
__m512i reg;
|
||||
|
||||
explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
|
||||
explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
|
||||
|
||||
explicit BF16Vec32(__m512i data) : reg(data) {}
|
||||
|
||||
explicit BF16Vec32(BF16Vec8 &vec8_data)
|
||||
explicit BF16Vec32(BF16Vec8& vec8_data)
|
||||
: reg((__m512i)_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
|
||||
(__m128i)vec8_data.reg),
|
||||
@ -132,7 +133,7 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
(__m128i)vec8_data.reg, 2),
|
||||
(__m128i)vec8_data.reg, 3)) {}
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m512i*>(ptr) = reg; }
|
||||
};
|
||||
#else
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
@ -141,24 +142,24 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
__m256i reg_low;
|
||||
__m256i reg_high;
|
||||
|
||||
explicit BF16Vec32(const void *ptr)
|
||||
: reg_low(_mm256_loadu_si256((__m256i const *)ptr)),
|
||||
reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {}
|
||||
explicit BF16Vec32(const void* ptr)
|
||||
: reg_low(_mm256_loadu_si256((__m256i const*)ptr)),
|
||||
reg_high(_mm256_loadu_si256((__m256i const*)ptr + 1)) {}
|
||||
|
||||
explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low),
|
||||
reg_high(high) {}
|
||||
explicit BF16Vec32(__m256i low, __m256i high)
|
||||
: reg_low(low), reg_high(high) {}
|
||||
|
||||
explicit BF16Vec32(BF16Vec8 &vec8_data)
|
||||
explicit BF16Vec32(BF16Vec8& vec8_data)
|
||||
: reg_low((__m256i)_mm256_inserti32x4(
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)),
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)),
|
||||
reg_high((__m256i)_mm256_inserti32x4(
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)) {}
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)) {}
|
||||
|
||||
void save(void *ptr) const {
|
||||
*reinterpret_cast<__m256i *>(ptr) = reg_low;
|
||||
*reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high;
|
||||
void save(void* ptr) const {
|
||||
*reinterpret_cast<__m256i*>(ptr) = reg_low;
|
||||
*reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@ -176,11 +177,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
|
||||
explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}
|
||||
explicit FP32Vec4(const float* ptr) : reg(_mm_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec4(__m128 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
@ -196,15 +197,15 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
|
||||
explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}
|
||||
explicit FP32Vec8(const float* ptr) : reg(_mm256_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec8(__m256 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}
|
||||
|
||||
explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {}
|
||||
explicit FP32Vec8(const FP16Vec8& v) : reg(_mm256_cvtph_ps(v.reg)) {}
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8 &v)
|
||||
explicit FP32Vec8(const BF16Vec8& v)
|
||||
: reg(_mm256_castsi256_ps(
|
||||
_mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
|
||||
|
||||
@ -212,7 +213,8 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -244,27 +246,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
erf(ar.values[1]), erf(ar.values[0])));
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_mul_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_add_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_sub_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_div_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
|
||||
void save(float* ptr) const { _mm256_storeu_ps(ptr, reg); }
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
struct INT32Vec16: public Vec<INT32Vec16> {
|
||||
struct INT32Vec16 : public Vec<INT32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
__m512i reg;
|
||||
@ -272,12 +274,11 @@ struct INT32Vec16: public Vec<INT32Vec16> {
|
||||
};
|
||||
|
||||
__m512i reg;
|
||||
|
||||
explicit INT32Vec16(const void* data_ptr) : reg(_mm512_loadu_epi32(data_ptr)) {}
|
||||
|
||||
void save(int32_t* ptr) const {
|
||||
_mm512_storeu_epi32(ptr, reg);
|
||||
}
|
||||
explicit INT32Vec16(const void* data_ptr)
|
||||
: reg(_mm512_loadu_epi32(data_ptr)) {}
|
||||
|
||||
void save(int32_t* ptr) const { _mm512_storeu_epi32(ptr, reg); }
|
||||
|
||||
void save(int32_t* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -301,11 +302,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}
|
||||
explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec16(__m512 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data)
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
: reg((__m512)_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
|
||||
@ -313,36 +314,37 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
(__m128i)data.reg, 2),
|
||||
(__m128i)data.reg, 3)) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data)
|
||||
explicit FP32Vec16(const FP32Vec8& data)
|
||||
: reg((__m512)_mm512_inserti32x8(
|
||||
_mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16 &v)
|
||||
explicit FP32Vec16(const BF16Vec16& v)
|
||||
: reg(_mm512_castsi512_ps(
|
||||
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16 &v) : reg(_mm512_cvtph_ps(v.reg)) {}
|
||||
explicit FP32Vec16(const FP16Vec16& v) : reg(_mm512_cvtph_ps(v.reg)) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
explicit FP32Vec16(const INT32Vec16 &v)
|
||||
: reg(_mm512_cvt_roundepi32_ps(v.reg, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)) {}
|
||||
explicit FP32Vec16(const INT32Vec16& v)
|
||||
: reg(_mm512_cvt_roundepi32_ps(
|
||||
v.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_mul_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_add_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_div_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
@ -370,9 +372,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 abs() const {
|
||||
return FP32Vec16(_mm512_abs_ps(reg));
|
||||
}
|
||||
FP32Vec16 abs() const { return FP32Vec16(_mm512_abs_ps(reg)); }
|
||||
|
||||
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
|
||||
|
||||
@ -380,14 +380,15 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
float reduce_min() const { return _mm512_reduce_min_ps(reg); }
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
|
||||
return _mm512_mask_reduce_add_ps(mask, reg);
|
||||
}
|
||||
|
||||
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||
void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||
|
||||
void save(float* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -407,32 +408,30 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
__m256 reg_low;
|
||||
__m256 reg_high;
|
||||
|
||||
explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)),
|
||||
reg_high(_mm256_set1_ps(v)) {}
|
||||
explicit FP32Vec16(float v)
|
||||
: reg_low(_mm256_set1_ps(v)), reg_high(_mm256_set1_ps(v)) {}
|
||||
|
||||
explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)),
|
||||
reg_high(_mm256_set1_ps(0.0)) {}
|
||||
explicit FP32Vec16()
|
||||
: reg_low(_mm256_set1_ps(0.0)), reg_high(_mm256_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)),
|
||||
reg_high(_mm256_loadu_ps(ptr + 8)) {}
|
||||
explicit FP32Vec16(const float* ptr)
|
||||
: reg_low(_mm256_loadu_ps(ptr)), reg_high(_mm256_loadu_ps(ptr + 8)) {}
|
||||
|
||||
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low),
|
||||
reg_high(data.reg_high) {}
|
||||
explicit FP32Vec16(const FP32Vec16& data)
|
||||
: reg_low(data.reg_low), reg_high(data.reg_high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data)
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
: reg_low((__m256)_mm256_inserti128_si256(
|
||||
_mm256_castsi128_si256((__m128i)data.reg),
|
||||
(__m128i)data.reg, 1)),
|
||||
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)),
|
||||
reg_high((__m256)_mm256_inserti128_si256(
|
||||
_mm256_castsi128_si256((__m128i)data.reg),
|
||||
(__m128i)data.reg, 1)) {}
|
||||
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data)
|
||||
explicit FP32Vec16(const FP32Vec8& data)
|
||||
: reg_low(data.reg), reg_high(data.reg) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16 &v) {
|
||||
explicit FP32Vec16(const FP16Vec16& v) {
|
||||
__m128i low = _mm256_extractf128_si256(v.reg, 0);
|
||||
__m128i high = _mm256_extractf128_si256(v.reg, 1);
|
||||
|
||||
@ -440,9 +439,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg_high = _mm256_cvtph_ps(high);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16 &v) {
|
||||
explicit FP32Vec16(const BF16Vec16& v) {
|
||||
__m128i low = _mm256_extractf128_si256(v.reg, 0);
|
||||
__m128i high = _mm256_extractf128_si256(v.reg, 1);
|
||||
|
||||
@ -456,24 +455,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg_high = _mm256_castsi256_ps(v_high_shifted);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
|
||||
_mm256_mul_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
|
||||
_mm256_add_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
|
||||
_mm256_sub_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
|
||||
_mm256_div_ps(reg_high, b.reg_high));
|
||||
}
|
||||
@ -484,7 +483,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return low.reduce_sum() + high.reduce_sum();
|
||||
}
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
float sum = 0.0;
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||
@ -507,7 +507,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return sum;
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
_mm256_storeu_ps(ptr, reg_low);
|
||||
_mm256_storeu_ps(ptr + 8, reg_high);
|
||||
}
|
||||
@ -515,7 +515,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
#endif
|
||||
|
||||
#ifdef __AVX512F__
|
||||
struct INT8Vec16: public Vec<INT8Vec16> {
|
||||
struct INT8Vec16 : public Vec<INT8Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
__m128i reg;
|
||||
@ -523,14 +523,12 @@ struct INT8Vec16: public Vec<INT8Vec16> {
|
||||
};
|
||||
|
||||
__m128i reg;
|
||||
|
||||
explicit INT8Vec16(const FP32Vec16& vec) : reg(
|
||||
_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
||||
) {}
|
||||
|
||||
void save(int8_t* ptr) const {
|
||||
_mm_storeu_epi8(ptr, reg);
|
||||
}
|
||||
explicit INT8Vec16(const FP32Vec16& vec)
|
||||
: reg(_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(
|
||||
vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))) {}
|
||||
|
||||
void save(int8_t* ptr) const { _mm_storeu_epi8(ptr, reg); }
|
||||
|
||||
void save(int8_t* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -540,71 +538,92 @@ struct INT8Vec16: public Vec<INT8Vec16> {
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
template <typename T>
|
||||
struct VecType {
|
||||
using vec_type = void;
|
||||
};
|
||||
|
||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||
template <typename T>
|
||||
using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||
template <>
|
||||
struct VecType<float> {
|
||||
using vec_type = FP32Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::Half> {
|
||||
using vec_type = FP16Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
|
||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
*ptr = v;
|
||||
}
|
||||
|
||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
acc = acc + a * b;
|
||||
}
|
||||
|
||||
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
|
||||
*reinterpret_cast<unsigned short *>(ptr) =
|
||||
template <>
|
||||
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
|
||||
*reinterpret_cast<unsigned short*>(ptr) =
|
||||
_cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
||||
}
|
||||
|
||||
inline FP16Vec8::FP16Vec8(const FP32Vec8 &v)
|
||||
inline FP16Vec8::FP16Vec8(const FP32Vec8& v)
|
||||
: reg(_mm256_cvtps_ph(v.reg,
|
||||
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
|
||||
: reg(_mm512_cvtps_ph(v.reg,
|
||||
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
|
||||
#else
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
|
||||
: reg(_mm256_insertf128_si256(_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {}
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
|
||||
: reg(_mm256_insertf128_si256(
|
||||
_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
|
||||
FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {}
|
||||
#endif
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
*reinterpret_cast<__bfloat16*>(ptr) = _mm_cvtness_sbh(v);
|
||||
}
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
|
||||
: reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
|
||||
|
||||
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
|
||||
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
|
||||
acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
|
||||
}
|
||||
#else
|
||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
|
||||
reinterpret_cast<c10::BFloat16 *>(&v);
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
|
||||
reinterpret_cast<c10::BFloat16*>(&v);
|
||||
*ptr = *(v_ptr + 1);
|
||||
}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
#ifdef __AVX512F__
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg(_mm256_cvtepi32_epi16(
|
||||
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
|
||||
: reg(_mm512_cvtepi32_epi16(
|
||||
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
|
||||
#else
|
||||
namespace{
|
||||
#else
|
||||
namespace {
|
||||
__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
|
||||
__m256i ai = _mm256_castps_si256(a);
|
||||
ai = _mm256_srli_epi32(ai, 16);
|
||||
@ -612,21 +631,21 @@ __m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
|
||||
ai = _mm256_permute4x64_epi64(ai, 0b00111001);
|
||||
return _mm256_extracti128_si256(ai, 0);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
|
||||
BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
|
||||
reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
|
||||
}
|
||||
#endif // __AVX512F__
|
||||
#endif // __AVX512BF16__
|
||||
#endif // __AVX512F__
|
||||
#endif // __AVX512BF16__
|
||||
|
||||
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||
inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||
|
||||
}; // namespace vec_op
|
||||
}; // namespace vec_op
|
||||
|
||||
#endif
|
||||
|
||||
@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
@ -148,7 +148,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" Tensor! key_cache, Tensor! value_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
||||
}
|
||||
|
||||
|
||||
310
csrc/cumem_allocator.cpp
Normal file
310
csrc/cumem_allocator.cpp
Normal file
@ -0,0 +1,310 @@
|
||||
// A CUDAPluggableAllocator based on cumem* APIs.
|
||||
// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle*
|
||||
// need to be unsigned long long
|
||||
#include <iostream>
|
||||
|
||||
extern "C" {
|
||||
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda.h>
|
||||
|
||||
#define CUDA_CHECK(condition) \
|
||||
do { \
|
||||
CUresult error = condition; \
|
||||
if (error != 0) { \
|
||||
char* error_string; \
|
||||
cuGetErrorString(error, (const char**)&error_string); \
|
||||
std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \
|
||||
<< __LINE__ << std::endl; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Global references to Python callables
|
||||
// NOTE: this is borrowed reference, so we don't need to DECREF them.
|
||||
// This brings the limitation that the allocator needs to be singleton.
|
||||
static PyObject* g_python_malloc_callback = nullptr;
|
||||
static PyObject* g_python_free_callback = nullptr;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper functions:
|
||||
|
||||
void ensure_context(unsigned long long device) {
|
||||
CUcontext pctx;
|
||||
CUDA_CHECK(cuCtxGetCurrent(&pctx));
|
||||
if (!pctx) {
|
||||
// Ensure device context.
|
||||
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
CUDA_CHECK(cuCtxSetCurrent(pctx));
|
||||
}
|
||||
}
|
||||
|
||||
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
|
||||
CUmemGenericAllocationHandle* p_memHandle) {
|
||||
ensure_context(device);
|
||||
// Define memory allocation properties
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = device;
|
||||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
|
||||
|
||||
// Allocate memory using cuMemCreate
|
||||
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
|
||||
CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0));
|
||||
|
||||
CUmemAccessDesc accessDesc = {};
|
||||
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
accessDesc.location.id = device;
|
||||
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
|
||||
|
||||
CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1));
|
||||
// std::cout << "create_and_map: device=" << device << ", size=" << size << ",
|
||||
// d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
|
||||
}
|
||||
|
||||
void unmap_and_release(unsigned long long device, ssize_t size,
|
||||
CUdeviceptr d_mem,
|
||||
CUmemGenericAllocationHandle* p_memHandle) {
|
||||
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
|
||||
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
|
||||
ensure_context(device);
|
||||
CUDA_CHECK(cuMemUnmap(d_mem, size));
|
||||
CUDA_CHECK(cuMemRelease(*p_memHandle));
|
||||
}
|
||||
|
||||
PyObject* create_tuple_from_c_integers(unsigned long long a,
|
||||
unsigned long long b,
|
||||
unsigned long long c,
|
||||
unsigned long long d) {
|
||||
// Create a new tuple of size 4
|
||||
PyObject* tuple = PyTuple_New(4);
|
||||
if (!tuple) {
|
||||
return NULL; // Return NULL on failure
|
||||
}
|
||||
|
||||
// Convert integers to Python objects and set them in the tuple
|
||||
PyTuple_SetItem(
|
||||
tuple, 0,
|
||||
PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong
|
||||
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
|
||||
PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
|
||||
PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d));
|
||||
|
||||
// Note: PyTuple_SetItem "steals" a reference to each object,
|
||||
// so we do not need to Py_DECREF the PyLong objects explicitly.
|
||||
|
||||
return tuple; // Return the created tuple
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Our exported C functions that call Python:
|
||||
|
||||
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
|
||||
void* my_malloc(ssize_t size, int device, CUstream stream) {
|
||||
ensure_context(device);
|
||||
|
||||
// first allocation, align the size, and reserve an address, and also allocate
|
||||
// a CUmemGenericAllocationHandle
|
||||
|
||||
// Define memory allocation properties
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = device;
|
||||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
|
||||
|
||||
// Check if the allocation is supported
|
||||
size_t granularity;
|
||||
CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop,
|
||||
CU_MEM_ALLOC_GRANULARITY_MINIMUM));
|
||||
|
||||
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
|
||||
|
||||
CUdeviceptr d_mem;
|
||||
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));
|
||||
|
||||
// allocate the CUmemGenericAllocationHandle
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)malloc(
|
||||
sizeof(CUmemGenericAllocationHandle));
|
||||
|
||||
if (!g_python_malloc_callback) {
|
||||
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Acquire GIL (not in stable ABI officially, but often works)
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
|
||||
PyObject* arg_tuple = create_tuple_from_c_integers(
|
||||
(unsigned long long)device, (unsigned long long)alignedSize,
|
||||
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
|
||||
|
||||
// Call g_python_malloc_callback
|
||||
PyObject* py_result =
|
||||
PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
|
||||
Py_DECREF(arg_tuple);
|
||||
|
||||
if (!py_result) {
|
||||
PyErr_Print();
|
||||
PyGILState_Release(gstate);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
// do the final mapping
|
||||
create_and_map(device, alignedSize, d_mem, p_memHandle);
|
||||
|
||||
return (void*)d_mem;
|
||||
}
|
||||
|
||||
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
|
||||
void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
|
||||
// get memory handle from the pointer
|
||||
if (!g_python_free_callback) {
|
||||
std::cerr << "ERROR: g_python_free_callback not set.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
// Acquire GIL (not in stable ABI officially, but often works)
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
|
||||
PyObject* py_ptr =
|
||||
PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));
|
||||
|
||||
PyObject* py_result =
|
||||
PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
|
||||
|
||||
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
|
||||
&recv_d_mem, &recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return;
|
||||
}
|
||||
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
// recv_size == size
|
||||
// recv_device == device
|
||||
|
||||
// Free memory
|
||||
|
||||
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
unmap_and_release(device, size, d_mem, p_memHandle);
|
||||
|
||||
// free address and the handle
|
||||
CUDA_CHECK(cuMemAddressFree(d_mem, size));
|
||||
free(p_memHandle);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Python extension boilerplate:
|
||||
|
||||
// Python-exposed function: init_module(python_malloc, python_free)
|
||||
static PyObject* py_init_module(PyObject* self, PyObject* args) {
|
||||
PyObject* malloc_callback = nullptr;
|
||||
PyObject* free_callback = nullptr;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) {
|
||||
PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Save the Python callables
|
||||
// This module does not handle GC of these objects, so they must be kept alive
|
||||
// outside of this module.
|
||||
g_python_malloc_callback = malloc_callback;
|
||||
g_python_free_callback = free_callback;
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
|
||||
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
|
||||
&recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
|
||||
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
|
||||
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
|
||||
&recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
|
||||
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyMethodDef module_methods[] = {
|
||||
{"init_module", (PyCFunction)py_init_module, METH_VARARGS,
|
||||
"Initialize module with python_malloc and python_free callables."},
|
||||
{"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS,
|
||||
"Create and map memory on the device."},
|
||||
{"python_unmap_and_release", (PyCFunction)python_unmap_and_release,
|
||||
METH_VARARGS, "Unmap and release memory on the device."},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
static struct PyModuleDef cumem_allocator_module = {
|
||||
PyModuleDef_HEAD_INIT, "cumem_allocator",
|
||||
"cumem-based allocator for CUDAPluggableAllocator", -1, module_methods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_cumem_allocator(void) {
|
||||
// Initialize the module
|
||||
PyObject* module = PyModule_Create(&cumem_allocator_module);
|
||||
if (!module) {
|
||||
return NULL;
|
||||
}
|
||||
return module;
|
||||
}
|
||||
} // extern "C"
|
||||
@ -27,8 +27,7 @@
|
||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
int max_shared_mem_per_block_opt_in = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin,
|
||||
device);
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
return max_shared_mem_per_block_opt_in;
|
||||
}
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename token_cnts_t>
|
||||
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
int32_t* sorted_token_ids,
|
||||
int32_t* expert_ids,
|
||||
@ -32,12 +32,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
|
||||
int32_t* tokens_cnts =
|
||||
shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
|
||||
int32_t* cumsum =
|
||||
shared_mem +
|
||||
(blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
|
||||
int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1)
|
||||
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
|
||||
@ -74,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
block_size) *
|
||||
block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
*total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@ -224,26 +220,46 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// If we have very large number of experts, we can no longer use shared
|
||||
// memory.
|
||||
// TODO(simon): the right solution should be calculating the exact right
|
||||
// amount of shared memory and use that. The num_experts >= 256 is just a
|
||||
// temporary solution to unblock Deepseek V3.
|
||||
if (num_experts >= 256) {
|
||||
int device_max_shared_mem;
|
||||
auto dev = topk_ids.get_device();
|
||||
cudaDeviceGetAttribute(&device_max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
const int32_t shared_mem_i32 =
|
||||
((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
|
||||
const int32_t shared_mem_i16 =
|
||||
((num_thread + 1) * num_experts) * sizeof(uint16_t) +
|
||||
(num_experts + 1) * sizeof(int32_t);
|
||||
|
||||
bool use_global_memory = false;
|
||||
bool use_i16 = false; // Use uint16_t for shared memory token counts
|
||||
if (shared_mem_i32 < device_max_shared_mem) {
|
||||
// Do nothing in this case. We're all set to use int32_t token counts
|
||||
} else if (shared_mem_i16 < device_max_shared_mem &&
|
||||
topk_ids.numel() <= 65535) {
|
||||
// when nelements of topk_ids is smaller than 65535 (max value of uint16),
|
||||
// element value of token_cnts would also smaller than 65535,
|
||||
// so we can use uint16 as dtype of token_cnts
|
||||
use_i16 = true;
|
||||
} else {
|
||||
use_global_memory = true;
|
||||
}
|
||||
|
||||
if (use_global_memory) {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
|
||||
const int32_t mem_tokens_cnts =
|
||||
((num_experts + 1) * num_experts) * sizeof(int32_t);
|
||||
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
|
||||
// allocate global memory
|
||||
int32_t* tokens_cnts;
|
||||
int32_t* cumsum;
|
||||
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
|
||||
cudaMalloc(&cumsum, mem_cumsum);
|
||||
auto options_int = torch::TensorOptions()
|
||||
.dtype(torch::kInt)
|
||||
.device(topk_ids.device());
|
||||
torch::Tensor token_cnts_buffer =
|
||||
torch::empty({(num_experts + 1) * num_experts}, options_int);
|
||||
torch::Tensor cumsum_buffer =
|
||||
torch::empty({num_experts + 1}, options_int);
|
||||
|
||||
auto kernel =
|
||||
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
|
||||
@ -252,25 +268,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel(), tokens_cnts, cumsum);
|
||||
cudaFree(tokens_cnts);
|
||||
cudaFree(cumsum);
|
||||
topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
|
||||
cumsum_buffer.data_ptr<int32_t>());
|
||||
});
|
||||
} else if (use_i16) {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// set dynamic shared mem
|
||||
auto kernel =
|
||||
vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem_i16));
|
||||
kernel<<<1, num_thread, shared_mem_i16, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel());
|
||||
});
|
||||
} else {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
const int32_t shared_mem =
|
||||
((num_thread + 1) * num_experts + (num_experts + 1)) *
|
||||
sizeof(int32_t);
|
||||
|
||||
// set dynamic shared mem
|
||||
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
|
||||
auto kernel =
|
||||
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem));
|
||||
kernel<<<1, num_thread, shared_mem, stream>>>(
|
||||
(void*)kernel, shared_mem_i32));
|
||||
kernel<<<1, num_thread, shared_mem_i32, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
|
||||
10
csrc/ops.h
10
csrc/ops.h
@ -34,8 +34,9 @@ void paged_attention_v1(
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
@ -45,8 +46,9 @@ void paged_attention_v2(
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
|
||||
@ -218,7 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
|
||||
int max_ctx_blocks, float k_scale, float v_scale) {
|
||||
int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) {
|
||||
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
@ -406,7 +406,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
// Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
|
||||
const _B8x8 Vlocalb8 = v_ptrh8be[d];
|
||||
Vlocal[h][b * BLOCK_SIZE / 8 + d] =
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, *v_scale_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -416,7 +416,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
#pragma unroll
|
||||
for (int d = 0; d < KHELOOP; d++) {
|
||||
Klocal[d] =
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], *k_scale_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
@ -890,7 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
|
||||
int max_ctx_blocks, float k_scale, float v_scale) {
|
||||
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
|
||||
@ -919,7 +919,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
|
||||
k_scale, v_scale);
|
||||
k_scale_ptr, v_scale_ptr);
|
||||
|
||||
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
|
||||
@ -929,7 +929,7 @@ void paged_attention_custom_launcher(
|
||||
torch::Tensor& value_cache, const int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& context_lens,
|
||||
int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
float k_scale, float v_scale) {
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -953,6 +953,8 @@ void paged_attention_custom_launcher(
|
||||
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
|
||||
const int max_num_partitions =
|
||||
@ -1087,7 +1089,8 @@ void paged_attention(
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_context_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale) {
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
const int head_size = query.size(2);
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Half) {
|
||||
|
||||
@ -10,5 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
|
||||
torch::Tensor& context_lens, int64_t block_size,
|
||||
int64_t max_context_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale,
|
||||
double v_scale);
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale);
|
||||
|
||||
@ -27,7 +27,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
|
||||
" int max_context_len,"
|
||||
" Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype,"
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
|
||||
}
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
@ -462,7 +462,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" Tensor! key_cache, Tensor! value_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
@ -472,7 +472,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" Tensor! value_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
|
||||
&reshape_and_cache_flash);
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.multimodal.inputs.MultiModalInputsV2
|
||||
.. autoclass:: vllm.multimodal.inputs.MultiModalInputs
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
|
||||
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
|
||||
|
||||
- [The eighth vLLM meetup](https://lu.ma/zep56hui), with Google Cloud, January 22nd 2025. [[Slides]](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing)
|
||||
- [The seventh vLLM meetup](https://lu.ma/h0qvrajz), with Snowflake, November 14th 2024. [[Slides]](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing)
|
||||
- [The sixth vLLM meetup](https://lu.ma/87q3nvnh), with NVIDIA, September 9th 2024. [[Slides]](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing)
|
||||
- [The fifth vLLM meetup](https://lu.ma/lp0gyjqr), with AWS, July 24th 2024. [[Slides]](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing)
|
||||
|
||||
@ -25,10 +25,12 @@ Check out the [building from source](#build-from-source) documentation for detai
|
||||
```bash
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
# linting and formatting
|
||||
bash format.sh
|
||||
# Static type checking
|
||||
mypy
|
||||
# Linting, formatting and static type checking
|
||||
pre-commit install
|
||||
|
||||
# You can manually run pre-commit with
|
||||
pre-commit run --all-files
|
||||
|
||||
# Unit tests
|
||||
pytest tests/
|
||||
```
|
||||
@ -88,7 +90,8 @@ If the PR spans more than one category, please include all relevant prefixes.
|
||||
The PR needs to meet the following code quality standards:
|
||||
|
||||
- We adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
|
||||
- Pass all linter checks. Please use <gh-file:format.sh> to format your code.
|
||||
- Pass all linter checks. Please use `pre-commit` to format your code. See
|
||||
<https://pre-commit.com/#usage> if `pre-commit` is new to you.
|
||||
- The code needs to be well-documented to ensure future contributors can easily
|
||||
understand the code.
|
||||
- Include sufficient tests to ensure the project stays correct and robust. This
|
||||
|
||||
@ -41,3 +41,20 @@ You may use the `#security` channel in the [VLLM Slack](https://slack.vllm.ai)
|
||||
to discuss security-related topics. However, please do not disclose any
|
||||
vulnerabilities in this channel. If you need to report a vulnerability, please
|
||||
use the GitHub security advisory system or contact a VMT member privately.
|
||||
|
||||
## Vulnerability Disclosure
|
||||
|
||||
The process for disclosing vulnerabilities is the following:
|
||||
|
||||
- The VMT will work with the project maintainers to develop a fix for the
|
||||
vulnerability.
|
||||
- The VMT will coordinate with the reporter and project maintainers to prepare a
|
||||
security advisory that adequately describes the vulnerability and its impact.
|
||||
- The VMT will coordinate with the project maintainers to publish a fix and
|
||||
release an update that includes that fix.
|
||||
- The VMT will publish the security advisory on GitHub. Release notes will be
|
||||
updated to include a reference to the security advisory.
|
||||
|
||||
The VMT and project maintainers will work to minimize the amount of time in
|
||||
between disclosing any public information about the vulnerability and making a
|
||||
release and advisory available.
|
||||
|
||||
@ -42,6 +42,9 @@ DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm/vllm-openai
|
||||
By default vLLM will build for all GPU types for widest distribution. If you are just building for the
|
||||
current GPU type the machine is running on, you can add the argument `--build-arg torch_cuda_arch_list=""`
|
||||
for vLLM to find the current GPU type and build for that.
|
||||
|
||||
If you are using Podman instead of Docker, you might need to disable SELinux labeling by
|
||||
adding `--security-opt label=disable` when running `podman build` command to avoid certain [existing issues](https://github.com/containers/buildah/discussions/4184).
|
||||
```
|
||||
|
||||
## Building for Arm64/aarch64
|
||||
|
||||
@ -307,7 +307,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar
|
||||
- ✅
|
||||
- ?
|
||||
- ?
|
||||
- ✅
|
||||
- [✗](gh-issue:11484)
|
||||
- ✅
|
||||
- ✗
|
||||
- ?
|
||||
|
||||
@ -1,44 +0,0 @@
|
||||
(fp8-e4m3-kvcache)=
|
||||
|
||||
# FP8 E4M3 KV Cache
|
||||
|
||||
Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache,
|
||||
improving throughput. OCP (Open Compute Project www.opencompute.org) specifies two common 8-bit floating point data formats: E5M2
|
||||
(5 exponent bits and 2 mantissa bits) and E4M3FN (4 exponent bits and 3 mantissa bits), often shortened as E4M3. One benefit of
|
||||
the E4M3 format over E5M2 is that floating point numbers are represented in higher precision. However, the small dynamic range of
|
||||
FP8 E4M3 (±240.0 can be represented) typically necessitates the use of a higher-precision (typically FP32) scaling factor alongside
|
||||
each quantized tensor. For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling
|
||||
factors of a finer granularity (e.g. per-channel).
|
||||
|
||||
These scaling factors can be specified by passing an optional quantization param JSON to the LLM engine at load time. If
|
||||
this JSON is not specified, scaling factors default to 1.0. These scaling factors are typically obtained when running an
|
||||
unquantized model through a quantizer tool (e.g. AMD quantizer or NVIDIA AMMO).
|
||||
|
||||
To install AMMO (AlgorithMic Model Optimization):
|
||||
|
||||
```console
|
||||
pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo
|
||||
```
|
||||
|
||||
Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy. The most recent silicon
|
||||
offerings e.g. AMD MI300, NVIDIA Hopper or later support native hardware conversion to and from fp32, fp16, bf16, etc.
|
||||
Thus, LLM inference is greatly accelerated with minimal accuracy loss.
|
||||
|
||||
Here is an example of how to enable this feature:
|
||||
|
||||
```python
|
||||
# two float8_e4m3fn kv cache scaling factor files are provided under tests/fp8_kv, please refer to
|
||||
# https://github.com/vllm-project/vllm/blob/main/examples/other/fp8/README.md to generate kv_cache_scales.json of your own.
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
sampling_params = SamplingParams(temperature=1.3, top_p=0.8)
|
||||
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
|
||||
kv_cache_dtype="fp8",
|
||||
quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
|
||||
prompt = "London is the capital of"
|
||||
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
|
||||
print(out)
|
||||
|
||||
# output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial,
|
||||
# output w/o scaling factors: England, located in the southeastern part of the country. It is known
|
||||
```
|
||||
@ -1,31 +0,0 @@
|
||||
(fp8-kv-cache)=
|
||||
|
||||
# FP8 E5M2 KV Cache
|
||||
|
||||
The int8/int4 quantization scheme requires additional scale GPU memory storage, which reduces the expected GPU memory benefits.
|
||||
The FP8 data format retains 2~3 mantissa bits and can convert float/fp16/bfloat16 and fp8 to each other.
|
||||
|
||||
Here is an example of how to enable this feature:
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
```
|
||||
@ -14,6 +14,5 @@ bnb
|
||||
gguf
|
||||
int8
|
||||
fp8
|
||||
fp8_e5m2_kvcache
|
||||
fp8_e4m3_kvcache
|
||||
quantized_kvcache
|
||||
```
|
||||
|
||||
147
docs/source/features/quantization/quantized_kvcache.md
Normal file
147
docs/source/features/quantization/quantized_kvcache.md
Normal file
@ -0,0 +1,147 @@
|
||||
(quantized-kvcache)=
|
||||
|
||||
# Quantized KV Cache
|
||||
|
||||
## FP8 KV Cache
|
||||
|
||||
Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache, improving throughput.
|
||||
|
||||
### FP8 Formats
|
||||
|
||||
[OCP (Open Compute Project)](https://www.opencompute.org) specifies two common 8-bit floating point data formats:
|
||||
|
||||
- E5M2 (5 exponent bits and 2 mantissa bits)
|
||||
- E4M3FN (4 exponent bits and 3 mantissa bits, often shortened as E4M3)
|
||||
|
||||
The E4M3 format offers higher precision compared to E5M2. However, due to its small dynamic range (±240.0), E4M3 typically requires a higher-precision (FP32) scaling factor alongside each quantized tensor.
|
||||
|
||||
### Current Limitations
|
||||
|
||||
For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling factors of a finer granularity (e.g. per-channel).
|
||||
|
||||
### Performance Impact
|
||||
|
||||
The current FP8 KV cache implementation primarily benefits throughput by allowing approximately double the amount of space for KV cache allocation. This enables either:
|
||||
|
||||
- Processing longer context lengths for individual requests, or
|
||||
- Handling more concurrent request batches
|
||||
|
||||
However, there are currently no latency improvements as the implementation does not yet include fused dequantization and attention operations. Future releases will support quantized attention with hardware acceleration, which should provide additional performance benefits. While the most recent silicon offerings (e.g. AMD MI300, NVIDIA Hopper or later) support native hardware conversion between FP8 and other formats (fp32, fp16, bf16), this benefit is not yet fully realized.
|
||||
|
||||
Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy, making it a practical choice for throughput optimization.
|
||||
|
||||
## Usage Example
|
||||
|
||||
Here is an example of how to enable FP8 quantization:
|
||||
|
||||
```python
|
||||
# To calculate kv cache scales on the fly enable the calculate_kv_scales
|
||||
# parameter
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
|
||||
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
|
||||
kv_cache_dtype="fp8",
|
||||
calculate_kv_scales=True)
|
||||
prompt = "London is the capital of"
|
||||
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
|
||||
print(out)
|
||||
```
|
||||
|
||||
The `kv_cache_dtype` argument specifies the data type for KV cache storage:
|
||||
- `"auto"`: Uses the model's default "unquantized" data type
|
||||
- `"fp8"` or `"fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPU)
|
||||
- `"fp8_e5m2"`: Supported on CUDA 11.8+
|
||||
|
||||
## Calibrated Scales for Better Accuracy
|
||||
|
||||
For optimal model quality when using FP8 KV Cache, we recommend using calibrated scales tuned to representative inference data. [LLM Compressor](https://github.com/vllm-project/llm-compressor/) is the recommended tool for this process.
|
||||
|
||||
### Installation
|
||||
|
||||
First, install the required dependencies:
|
||||
|
||||
```console
|
||||
pip install llmcompressor
|
||||
```
|
||||
|
||||
### Example Usage
|
||||
|
||||
Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models can use this same pattern):
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from llmcompressor.transformers import oneshot
|
||||
|
||||
# Select model and load it
|
||||
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
|
||||
# Select calibration dataset
|
||||
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
|
||||
DATASET_SPLIT = "train_sft"
|
||||
|
||||
# Configure calibration parameters
|
||||
NUM_CALIBRATION_SAMPLES = 512 # 512 samples is a good starting point
|
||||
MAX_SEQUENCE_LENGTH = 2048
|
||||
|
||||
# Load and preprocess dataset
|
||||
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
|
||||
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
|
||||
|
||||
def process_and_tokenize(example):
|
||||
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
|
||||
return tokenizer(
|
||||
text,
|
||||
padding=False,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
|
||||
ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)
|
||||
|
||||
# Configure quantization settings
|
||||
recipe = """
|
||||
quant_stage:
|
||||
quant_modifiers:
|
||||
QuantizationModifier:
|
||||
kv_cache_scheme:
|
||||
num_bits: 8
|
||||
type: float
|
||||
strategy: tensor
|
||||
dynamic: false
|
||||
symmetric: true
|
||||
"""
|
||||
|
||||
# Apply quantization
|
||||
oneshot(
|
||||
model=model,
|
||||
dataset=ds,
|
||||
recipe=recipe,
|
||||
max_seq_length=MAX_SEQUENCE_LENGTH,
|
||||
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
||||
)
|
||||
|
||||
# Save quantized model
|
||||
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV"
|
||||
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
||||
tokenizer.save_pretrained(SAVE_DIR)
|
||||
```
|
||||
|
||||
The above script will create a folder in your current directory containing your quantized model (e.g., `Llama-3.1-8B-Instruct-FP8-KV`) with calibrated scales.
|
||||
|
||||
When running the model you must specify `kv_cache_dtype="fp8"` in order to enable the kv cache quantization and use the scales.
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
|
||||
llm = LLM(model="Llama-3.1-8B-Instruct-FP8-KV", kv_cache_dtype="fp8")
|
||||
prompt = "London is the capital of"
|
||||
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
|
||||
print(out)
|
||||
```
|
||||
@ -13,6 +13,14 @@ vLLM supports AMD GPUs with ROCm 6.2.
|
||||
|
||||
Currently, there are no pre-built ROCm wheels.
|
||||
|
||||
However, the [AMD Infinity hub for vLLM](https://hub.docker.com/r/rocm/vllm/tags) offers a prebuilt, optimized
|
||||
docker image designed for validating inference performance on the AMD Instinct™ MI300X accelerator.
|
||||
|
||||
```{tip}
|
||||
Please check [LLM inference performance validation on AMD Instinct MI300X](https://rocm.docs.amd.com/en/latest/how-to/performance-validation/mi300x/vllm-benchmark.html)
|
||||
for instructions on how to use this prebuilt docker image.
|
||||
```
|
||||
|
||||
### Build wheel from source
|
||||
|
||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||
@ -70,7 +78,7 @@ Currently, there are no pre-built ROCm wheels.
|
||||
|
||||
# Install PyTorch
|
||||
$ pip uninstall torch -y
|
||||
$ pip install --no-cache-dir --pre torch==2.6.0.dev20241024 --index-url https://download.pytorch.org/whl/nightly/rocm6.2
|
||||
$ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.2
|
||||
|
||||
# Build & install AMD SMI
|
||||
$ pip install /opt/rocm/share/amd_smi
|
||||
@ -123,11 +131,10 @@ It is important that the user kicks off the docker build using buildkit. Either
|
||||
<gh-file:Dockerfile.rocm> uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches.
|
||||
It provides flexibility to customize the build of docker image using the following arguments:
|
||||
|
||||
- `BASE_IMAGE`: specifies the base image used when running `docker build`, specifically the PyTorch on ROCm base image.
|
||||
- `BUILD_FA`: specifies whether to build CK flash-attention. The default is 1. For [Radeon RX 7900 series (gfx1100)](https://rocm.docs.amd.com/projects/radeon/en/latest/index.html), this should be set to 0 before flash-attention supports this target.
|
||||
- `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build CK flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
|
||||
- `FA_BRANCH`: specifies the branch used to build the CK flash-attention in [ROCm's flash-attention repo](https://github.com/ROCmSoftwarePlatform/flash-attention). The default is `ae7928c`
|
||||
- `BUILD_TRITON`: specifies whether to build triton flash-attention. The default value is 1.
|
||||
- `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using <gh-file:Dockerfile.rocm_base>
|
||||
- `USE_CYTHON`: An option to run cython compilation on a subset of python files upon docker build
|
||||
- `BUILD_RPD`: Include RocmProfileData profiling tool in the image
|
||||
- `ARG_PYTORCH_ROCM_ARCH`: Allows to override the gfx architecture values from the base docker image
|
||||
|
||||
Their values can be passed in when running `docker build` with `--build-arg` options.
|
||||
|
||||
@ -137,10 +144,10 @@ To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default:
|
||||
DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm .
|
||||
```
|
||||
|
||||
To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should specify `BUILD_FA` as below:
|
||||
To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should pick the alternative base image:
|
||||
|
||||
```console
|
||||
DOCKER_BUILDKIT=1 docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm .
|
||||
DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" -f Dockerfile.rocm -t vllm-rocm .
|
||||
```
|
||||
|
||||
To run the above docker image `vllm-rocm`, use the below command:
|
||||
|
||||
@ -22,9 +22,9 @@ It'd be better to store the model in a local disk. Additionally, have a look at
|
||||
To isolate the model downloading and loading issue, you can use the `--load-format dummy` argument to skip loading the model weights. This way, you can check if the model downloading and loading is the bottleneck.
|
||||
```
|
||||
|
||||
## Model is too large
|
||||
## Out of memory
|
||||
|
||||
If the model is too large to fit in a single GPU, you might want to [consider tensor parallelism](#distributed-serving) to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
|
||||
If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider [using tensor parallelism](#distributed-serving) to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
|
||||
|
||||
## Enable more logging
|
||||
|
||||
@ -197,6 +197,63 @@ if __name__ == '__main__':
|
||||
llm = vllm.LLM(...)
|
||||
```
|
||||
|
||||
## `torch.compile` Error
|
||||
|
||||
vLLM heavily depends on `torch.compile` to optimize the model for better performance, which introduces the dependency on the `torch.compile` functionality and the `triton` library. By default, we use `torch.compile` to [optimize some functions](https://github.com/vllm-project/vllm/pull/10406) in the model. Before running vLLM, you can check if `torch.compile` is working as expected by running the following script:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
@torch.compile
|
||||
def f(x):
|
||||
# a simple function to test torch.compile
|
||||
x = x + 1
|
||||
x = x * 2
|
||||
x = x.sin()
|
||||
return x
|
||||
|
||||
x = torch.randn(4, 4).cuda()
|
||||
print(f(x))
|
||||
```
|
||||
|
||||
If it raises errors from `torch/_inductor` directory, usually it means you have a custom `triton` library that is not compatible with the version of PyTorch you are using. See [this issue](https://github.com/vllm-project/vllm/issues/12219) for example.
|
||||
|
||||
## Model failed to be inspected
|
||||
|
||||
If you see an error like:
|
||||
|
||||
```text
|
||||
File "vllm/model_executor/models/registry.py", line xxx, in _raise_for_unsupported
|
||||
raise ValueError(
|
||||
ValueError: Model architectures ['<arch>'] failed to be inspected. Please check the logs for more details.
|
||||
```
|
||||
|
||||
It means that vLLM failed to import the model file.
|
||||
Usually, it is related to missing dependencies or outdated binaries in the vLLM build.
|
||||
Please read the logs carefully to determine the root cause of the error.
|
||||
|
||||
## Model not supported
|
||||
|
||||
If you see an error like:
|
||||
|
||||
```text
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
File "vllm/model_executor/models/registry.py", line xxx, in inspect_model_cls
|
||||
for arch in architectures:
|
||||
TypeError: 'NoneType' object is not iterable
|
||||
```
|
||||
|
||||
or:
|
||||
|
||||
```text
|
||||
File "vllm/model_executor/models/registry.py", line xxx, in _raise_for_unsupported
|
||||
raise ValueError(
|
||||
ValueError: Model architectures ['<arch>'] are not supported for now. Supported architectures: [...]
|
||||
```
|
||||
|
||||
But you are sure that the model is in the [list of supported models](#supported-models), there may be some issue with vLLM's model resolution. In that case, please follow [these steps](#model-resolution) to explicitly specify the vLLM implementation for the model.
|
||||
|
||||
## Known Issues
|
||||
|
||||
- In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759).
|
||||
|
||||
@ -216,6 +216,11 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
- `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
* - `InternLM3ForCausalLM`
|
||||
- InternLM3
|
||||
- `internlm/internlm3-8b-instruct`, etc.
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
* - `JAISLMHeadModel`
|
||||
- Jais
|
||||
- `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc.
|
||||
@ -297,8 +302,8 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
* - `Phi3ForCausalLM`
|
||||
- Phi-3
|
||||
- `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc.
|
||||
- Phi-4, Phi-3
|
||||
- `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc.
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
* - `Phi3SmallForCausalLM`
|
||||
@ -465,6 +470,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding
|
||||
- `Qwen/Qwen2.5-Math-RM-72B`, etc.
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
* - `Qwen2ForProcessRewardModel`
|
||||
- Qwen2-based
|
||||
- `Qwen/Qwen2.5-Math-PRM-7B`, `Qwen/Qwen2.5-Math-PRM-72B`, etc.
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
```
|
||||
|
||||
If your model is not in the above list, we will try to automatically convert the model using
|
||||
@ -613,7 +623,7 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* - `DeepseekVLV2ForCausalLM`
|
||||
- DeepSeek-VL2
|
||||
- T + I<sup>+</sup>
|
||||
- `deepseek-ai/deepseek-vl2-tiny`(WIP), `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. (see note)
|
||||
- `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. (see note)
|
||||
-
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
@ -749,7 +759,7 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
- `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc.
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
-
|
||||
- ✅︎
|
||||
* - `UltravoxModel`
|
||||
- Ultravox
|
||||
- T + A<sup>E+</sup>
|
||||
@ -762,17 +772,10 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
|
||||
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
|
||||
|
||||
````{note}
|
||||
The `deepseek-ai/deepseek-vl2-tiny` is not supported yet.
|
||||
|
||||
To use `DeepSeek-VL2` series models, you need to install a fork version `deepseek_vl2` package:
|
||||
```shell
|
||||
pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git
|
||||
```{note}
|
||||
To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
|
||||
```
|
||||
|
||||
Besides, to run `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
|
||||
````
|
||||
|
||||
```{note}
|
||||
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
|
||||
```
|
||||
|
||||
@ -31,6 +31,30 @@ Please refer to the above pages for more details about each API.
|
||||
This section lists the most common options for running the vLLM engine.
|
||||
For a full list, refer to the [Engine Arguments](#engine-args) page.
|
||||
|
||||
(model-resolution)=
|
||||
|
||||
### Model resolution
|
||||
|
||||
vLLM loads HuggingFace-compatible models by inspecting the `architectures` field in `config.json` of the model repository
|
||||
and finding the corresponding implementation that is registered to vLLM.
|
||||
Nevertheless, our model resolution may fail for the following reasons:
|
||||
|
||||
- The `config.json` of the model repository lacks the `architectures` field.
|
||||
- Unofficial repositories refer to a model using alternative names which are not recorded in vLLM.
|
||||
- The same architecture name is used for multiple models, creating ambiguity as to which model should be loaded.
|
||||
|
||||
To fix this, explicitly specify the model architecture by passing `config.json` overrides to the `hf_overrides` option.
|
||||
For example:
|
||||
|
||||
```python
|
||||
model = LLM(
|
||||
model="cerebras/Cerebras-GPT-1.3B",
|
||||
hf_overrides={"architectures": ["GPT2LMHeadModel"]}, # GPT-2
|
||||
)
|
||||
```
|
||||
|
||||
Our [list of supported models](#supported-models) shows the model architectures that are recognized by vLLM.
|
||||
|
||||
### Reducing memory usage
|
||||
|
||||
Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem.
|
||||
|
||||
186
examples/offline_inference/rlhf.py
Normal file
186
examples/offline_inference/rlhf.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""
|
||||
a simple demonstration of RLHF with vLLM, inspired by
|
||||
the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
|
||||
It follows the design that, training processes and inference processes
|
||||
are different, and they live on different GPUs.
|
||||
Training processes send prompts to inference processes to generate data,
|
||||
and also synchronize the weights of the model by broadcasting the weights
|
||||
from the training process to the inference process.
|
||||
Note that this is a simple demonstration of one training instance and one
|
||||
inference instance. In practice, there could be multiple training instances
|
||||
and multiple inference instances. For the full implementation, please refer
|
||||
to the OpenRLHF framework.
|
||||
"""
|
||||
import os
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import get_ip, get_open_port
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
def stateless_init_process_group(master_address, master_port, rank, world_size,
|
||||
device):
|
||||
"""
|
||||
vLLM provides `StatelessProcessGroup` to create a process group
|
||||
without considering the global process group in torch.distributed.
|
||||
It is recommended to create `StatelessProcessGroup`, and then initialize
|
||||
the data-plane communication (NCCL) between external (train processes)
|
||||
and vLLM workers.
|
||||
"""
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
pg = StatelessProcessGroup.create(host=master_address,
|
||||
port=master_port,
|
||||
rank=rank,
|
||||
world_size=world_size)
|
||||
pynccl = PyNcclCommunicator(pg, device=device)
|
||||
return pynccl
|
||||
|
||||
|
||||
class MyWorker(Worker):
|
||||
"""
|
||||
The `MyWorker` class inherits from `Worker` to provide custom functions.
|
||||
For simplicity, we define the `MyWorker` class in this self-contained
|
||||
script. Normally, we should define the `MyWorker` class in a separate
|
||||
file and pass the qualified name of the class to the `worker_cls`
|
||||
parameter.
|
||||
"""
|
||||
|
||||
def init_weight_update_group(self, master_address, master_port,
|
||||
rank_offset, world_size):
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
rank = get_world_group().rank + rank_offset
|
||||
self.model_update_group = stateless_init_process_group(
|
||||
master_address,
|
||||
master_port,
|
||||
rank,
|
||||
world_size,
|
||||
self.device,
|
||||
)
|
||||
|
||||
def update_weight(self, name, dtype, shape):
|
||||
weight = torch.empty(shape, dtype=dtype, device="cuda")
|
||||
self.model_update_group.broadcast(weight,
|
||||
src=0,
|
||||
stream=torch.cuda.current_stream())
|
||||
|
||||
self.model_runner.model.load_weights(weights=[(name, weight)])
|
||||
|
||||
del weight
|
||||
|
||||
def check_weights_changed(self):
|
||||
"""
|
||||
Check if the weights are updated to 0.
|
||||
"""
|
||||
weights_updated = True
|
||||
for name, p in self.model_runner.model.named_parameters():
|
||||
weights_updated = weights_updated and torch.allclose(
|
||||
p, torch.zeros_like(p))
|
||||
return weights_updated
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# a hack to make the script work.
|
||||
# stop ray from manipulating CUDA_VISIBLE_DEVICES
|
||||
# at the top-level
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
"""
|
||||
Start the training process, here we use huggingface transformers
|
||||
as an example to hold a model on GPU 0.
|
||||
"""
|
||||
|
||||
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||
train_model.to("cuda:0")
|
||||
"""
|
||||
Start the inference process, here we use vLLM to hold a model on GPU 1 and
|
||||
GPU 2. For the details on how to use ray, please refer to the ray
|
||||
documentation https://docs.ray.io/en/latest/ .
|
||||
"""
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
|
||||
ray.init()
|
||||
|
||||
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
|
||||
ray.get(pg_inference.ready())
|
||||
scheduling_inference = PlacementGroupSchedulingStrategy(
|
||||
placement_group=pg_inference,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=0,
|
||||
)
|
||||
"""
|
||||
launch the vLLM inference engine.
|
||||
here we use `enforce_eager` to reduce the start time.
|
||||
"""
|
||||
llm = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
scheduling_strategy=scheduling_inference,
|
||||
)(MyLLM).remote(
|
||||
model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
worker_cls=MyWorker,
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="ray",
|
||||
)
|
||||
|
||||
# Generate texts from the prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
|
||||
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}")
|
||||
|
||||
# set up the communication between the training process
|
||||
# and the inference engine.
|
||||
master_address = get_ip()
|
||||
master_port = get_open_port()
|
||||
|
||||
handle = llm.collective_rpc.remote("init_weight_update_group",
|
||||
args=(master_address, master_port, 1, 3))
|
||||
model_update_group = stateless_init_process_group(master_address, master_port,
|
||||
0, 3, torch.device("cuda:0"))
|
||||
ray.get(handle)
|
||||
|
||||
# simulate training, modify the weights of the model.
|
||||
for name, p in train_model.named_parameters():
|
||||
p.data.zero_()
|
||||
|
||||
# sync weight from the training process to the inference engine.
|
||||
for name, p in train_model.named_parameters():
|
||||
handle = llm.collective_rpc.remote("update_weight",
|
||||
args=(name, p.dtype, p.shape))
|
||||
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
|
||||
ray.get(handle)
|
||||
|
||||
# check if the weights are updated.
|
||||
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
|
||||
|
||||
# use the updated model to generate texts, they will be nonsense
|
||||
# because the weights are all zeros.
|
||||
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||
for output in outputs_updated:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}")
|
||||
64
examples/offline_inference/torchrun_example.py
Normal file
64
examples/offline_inference/torchrun_example.py
Normal file
@ -0,0 +1,64 @@
|
||||
"""
|
||||
experimental support for tensor-parallel inference with torchrun,
|
||||
see https://github.com/vllm-project/vllm/issues/11400 for
|
||||
the motivation and use case for this example.
|
||||
run the script with `torchrun --nproc-per-node=2 torchrun_example.py`,
|
||||
the argument 2 should match the `tensor_parallel_size` below.
|
||||
see `tests/distributed/test_torchrun_example.py` for the unit test.
|
||||
"""
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Create prompts, the same across all ranks
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create sampling parameters, the same across all ranks
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Use `distributed_executor_backend="external_launcher"` so that
|
||||
# this llm engine/instance only creates one worker.
|
||||
llm = LLM(
|
||||
model="facebook/opt-125m",
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="external_launcher",
|
||||
)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# all ranks will have the same outputs
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}")
|
||||
"""
|
||||
Further tips:
|
||||
|
||||
1. to communicate control messages across all ranks, use the cpu group,
|
||||
a PyTorch ProcessGroup with GLOO backend.
|
||||
|
||||
```python
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
cpu_group = get_world_group().cpu_group
|
||||
torch_rank = dist.get_rank(group=cpu_group)
|
||||
if torch_rank == 0:
|
||||
# do something for rank 0, e.g. saving the results to disk.
|
||||
```
|
||||
|
||||
2. to communicate data across all ranks, use the model's device group,
|
||||
a PyTorch ProcessGroup with NCCL backend.
|
||||
```python
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
device_group = get_world_group().device_group
|
||||
```
|
||||
|
||||
3. to access the model directly in every rank, use the following code:
|
||||
```python
|
||||
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||
```
|
||||
"""
|
||||
@ -26,14 +26,12 @@ def run_aria(question: str, modality: str):
|
||||
|
||||
# NOTE: Need L40 (or equivalent) to avoid OOM
|
||||
llm = LLM(model=model_name,
|
||||
tokenizer_mode="slow",
|
||||
dtype="bfloat16",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
|
||||
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
|
||||
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
|
||||
"<|im_end|>\n<|im_start|>assistant\n")
|
||||
|
||||
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
|
||||
@ -70,7 +68,7 @@ def run_chameleon(question: str, modality: str):
|
||||
def run_deepseek_vl2(question: str, modality: str):
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "deepseek-ai/deepseek-vl2-small"
|
||||
model_name = "deepseek-ai/deepseek-vl2-tiny"
|
||||
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=4096,
|
||||
@ -325,7 +323,6 @@ def run_mllama(question: str, modality: str):
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=16,
|
||||
enforce_eager=True,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
|
||||
|
||||
|
||||
def load_deepseek_vl2(question: str, image_urls: List[str]):
|
||||
model_name = "deepseek-ai/deepseek-vl2-small"
|
||||
model_name = "deepseek-ai/deepseek-vl2-tiny"
|
||||
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=4096,
|
||||
@ -186,7 +186,6 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=16,
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
@ -394,7 +393,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
|
||||
|
||||
model_example_map = {
|
||||
"aria": load_aria,
|
||||
"deepseek_vl2": load_deepseek_vl2,
|
||||
"deepseek_vl_v2": load_deepseek_vl2,
|
||||
"h2ovl_chat": load_h2onvl,
|
||||
"idefics3": load_idefics3,
|
||||
"internvl_chat": load_internvl,
|
||||
|
||||
@ -3,6 +3,8 @@
|
||||
# We will launch 2 vllm instances (1 for prefill and 1 for decode),
|
||||
# and then transfer the KV cache between them.
|
||||
|
||||
set -xe
|
||||
|
||||
echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
|
||||
sleep 1
|
||||
|
||||
@ -69,7 +71,7 @@ wait_for_server 8200
|
||||
# instance
|
||||
# NOTE: the usage of this API is subject to change --- in the future we will
|
||||
# introduce "vllm connect" to connect between prefill and decode instances
|
||||
python3 ../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py &
|
||||
python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py &
|
||||
sleep 1
|
||||
|
||||
# serve two example requests
|
||||
|
||||
@ -1,96 +0,0 @@
|
||||
# FP8 KV Cache
|
||||
|
||||
This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM (variable-length language model) during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Python 3.x
|
||||
- PyTorch
|
||||
- NumPy
|
||||
- Hugging Face Transformers
|
||||
- Hugging Face Hub
|
||||
- AMMO
|
||||
|
||||
Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps:
|
||||
1. Install all necessary prerequisites and dependencies.
|
||||
2. Convert HF model into a quantized HF model.
|
||||
3. Extract KV Cache Scaling Factors from quantized HF model.
|
||||
4. Load KV Cache Scaling Factors into VLLM.
|
||||
|
||||
### 2. Convert HF model into a quantized HF model.
|
||||
Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md).
|
||||
|
||||
`quantize.py` (examples/other/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format).
|
||||
|
||||
The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/other/fp8/quantizer/README.md`.
|
||||
|
||||
### 3. Extract KV Cache Scaling Factors from quantized HF model.
|
||||
`extract_scales.py` (examples/other/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model, however at the moment, this tool exclusively supports Llama 2 models. It is also important to note the following:
|
||||
1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename.
|
||||
|
||||
2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM.
|
||||
|
||||
3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks.
|
||||
|
||||
```python
|
||||
# prerequisites:
|
||||
# - Quantized HF LLaMa 2 model
|
||||
python3 examples/other/fp8/extract_scales.py --help
|
||||
Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE]
|
||||
|
||||
KV Scale Extraction Example
|
||||
|
||||
optional arguments:
|
||||
--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU).
|
||||
Optional arguments:
|
||||
--cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None)
|
||||
--load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto)
|
||||
--revision: Specify the model's revision number. (Default: None)
|
||||
--output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None)
|
||||
--output_name: Specify the output filename. (Default: kv_cache_scales.json)
|
||||
--tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None)
|
||||
```
|
||||
```python
|
||||
Example:
|
||||
python3 examples/other/fp8/extract_scales.py --quantized_model <QUANTIZED_MODEL_DIR> --tp_size <TENSOR_PARALLEL_SIZE> --output_dir <PATH_TO_OUTPUT_DIR>
|
||||
```
|
||||
### 4. Load KV Cache Scaling Factors into VLLM.
|
||||
This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The recently generated KV cache scaling factors are now integrated into the benchmarking process and allow for KV cache scaling factors to be utilized for FP8.
|
||||
```
|
||||
# prerequisites:
|
||||
# - LLaMa 2 kv_cache_scales.json file
|
||||
|
||||
python3 benchmarks/benchmark_throughput.py --help
|
||||
usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
|
||||
[--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
|
||||
[--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
|
||||
[--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
|
||||
[--quantization-param-path KV_CACHE_quantization_param_path]
|
||||
|
||||
Benchmark Throughput Example
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
--backend {vllm,hf,mii}
|
||||
--dataset DATASET Path to the dataset.
|
||||
--input-len INPUT_LEN Input prompt length for each request
|
||||
--output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
|
||||
--model MODEL
|
||||
--tokenizer TOKENIZER
|
||||
--quantization {awq,gptq,None}, -q {awq,gptq,None}
|
||||
--tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
|
||||
--n N Number of generated sequences per prompt.
|
||||
--use-beam-search
|
||||
--num-prompts NUM_PROMPTS Number of prompts to process.
|
||||
--seed SEED
|
||||
--hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend.
|
||||
--trust-remote-code trust remote code from huggingface
|
||||
--max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model.
|
||||
--dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
|
||||
--enforce-eager enforce eager execution
|
||||
--kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported ```for common inference criteria.
|
||||
--quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
|
||||
```
|
||||
Example:
|
||||
```console
|
||||
python3 benchmarks/benchmark_throughput.py --input-len <INPUT_LEN> --output-len <OUTPUT_LEN> -tp <TENSOR_PARALLEL_SIZE> --kv-cache-dtype fp8 --quantization-param-path <path/to/kv_cache_scales.json> --model <path-to-llama2>
|
||||
```
|
||||
@ -1,367 +0,0 @@
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
||||
|
||||
|
||||
# Adapted from vllm/model_executor/model_loader/weight_utils.py
|
||||
# The main differences are that we add the NPZ format and simplify
|
||||
# its functionality drastically for our purposes (e.g. we assume that
|
||||
# the quantized model exists locally and there is no need to download it)
|
||||
def _prepare_hf_weights(
|
||||
quantized_model_dir: str,
|
||||
load_format: str = "auto",
|
||||
fall_back_to_pt: bool = True,
|
||||
) -> Tuple[List[str], bool]:
|
||||
if not os.path.isdir(quantized_model_dir):
|
||||
raise FileNotFoundError(
|
||||
f"The quantized model directory `{quantized_model_dir}` "
|
||||
"does not exist.")
|
||||
use_safetensors = False
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == "auto":
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif load_format == "safetensors":
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == "pt":
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == "npz":
|
||||
allow_patterns = ["*.npz"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += ["*.pt"]
|
||||
|
||||
hf_weights_files: List[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(
|
||||
os.path.join(quantized_model_dir, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
|
||||
if not use_safetensors:
|
||||
# Exclude files that are not needed for inference.
|
||||
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||
blacklist = [
|
||||
"training_args.bin",
|
||||
"optimizer.bin",
|
||||
"optimizer.pt",
|
||||
"scheduler.pt",
|
||||
"scaler.pt",
|
||||
]
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files
|
||||
if not any(f.endswith(x) for x in blacklist)
|
||||
]
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{quantized_model_dir}`")
|
||||
|
||||
return hf_weights_files, use_safetensors
|
||||
|
||||
|
||||
# Adapted from vllm/model_executor/model_loader/weight_utils.py
|
||||
def _hf_tensorfile_iterator(filename: str, load_format: str,
|
||||
use_safetensors: bool):
|
||||
if load_format == "npz":
|
||||
assert not use_safetensors
|
||||
with np.load(filename) as data:
|
||||
for name in data.files:
|
||||
param = torch.from_numpy(data[name])
|
||||
yield name, param
|
||||
elif use_safetensors:
|
||||
with safe_open(filename, framework="pt") as f:
|
||||
for name in f.keys(): # NOQA: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
else:
|
||||
state = torch.load(filename, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
yield name, param
|
||||
del state
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def _kv_scales_extractor(
|
||||
hf_tensor_files: List[str],
|
||||
use_safetensors: bool,
|
||||
rank_keyword: str = "rank",
|
||||
expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]:
|
||||
"""
|
||||
Given a list of files containing tensor data, attempt to extract KV cache
|
||||
scales from these files. Intended as a helper function taking in the output
|
||||
from _prepare_hf_weights.
|
||||
Args:
|
||||
rank_keyword Matches the number immediately after this keyword in the
|
||||
tensor filename to determine the TP rank corresponding
|
||||
to said tensor file
|
||||
expected_tp_size If specified, the TP size of the tensor files is checked
|
||||
against this and an error is raised if they don't match.
|
||||
Returns a dictionary mapping TP ranks to their relevant KV cache scales.
|
||||
The per-rank scales are themselves represented as a dictionary of layer
|
||||
indices to the respective per-layer scale.
|
||||
"""
|
||||
for char in rank_keyword:
|
||||
assert not char.isdecimal(
|
||||
), f"Rank keyword {rank_keyword} contains a numeric character!"
|
||||
rank_scales_map: Dict[int, Dict[int, float]] = {}
|
||||
for tensor_file in hf_tensor_files:
|
||||
try:
|
||||
rank_idx = tensor_file.find(rank_keyword)
|
||||
if rank_idx != -1:
|
||||
start_idx = rank_idx + len(rank_keyword)
|
||||
stop_idx = start_idx
|
||||
while stop_idx < len(
|
||||
tensor_file) and tensor_file[stop_idx].isdecimal():
|
||||
stop_idx += 1
|
||||
if stop_idx == start_idx:
|
||||
raise RuntimeError("Did not find rank # in filename.")
|
||||
rank = int(tensor_file[start_idx:stop_idx])
|
||||
elif len(hf_tensor_files) == 1:
|
||||
# Since there is only one tensor file, we can assume
|
||||
# that it's intended for TP rank 0
|
||||
rank = 0
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Filename does not contain '{rank_keyword}'.")
|
||||
except RuntimeError:
|
||||
print("Unable to determine TP rank "
|
||||
f"corresponding to file '{tensor_file}'")
|
||||
raise
|
||||
|
||||
if rank not in rank_scales_map:
|
||||
layer_scales_map: Dict[int, float] = {}
|
||||
rank_scales_map[rank] = layer_scales_map
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Tensor file '{tensor_file}' shares TP rank {rank} "
|
||||
"with another tensor file.")
|
||||
|
||||
module_delimiter = ":" if args.load_format == "npz" else "."
|
||||
for name, param in _hf_tensorfile_iterator(tensor_file,
|
||||
args.load_format,
|
||||
use_safetensors):
|
||||
if "kv_cache_scaling_factor" in name:
|
||||
nums = [
|
||||
int(s) for s in name.split(module_delimiter)
|
||||
if s.isdecimal()
|
||||
]
|
||||
assert len(
|
||||
nums) == 1, f"Could not determine layer idx for {name}"
|
||||
layer_idx = nums[0]
|
||||
assert layer_idx not in layer_scales_map, f"Duplicate scaling"\
|
||||
f" factor corresponding to layer {layer_idx}"
|
||||
try:
|
||||
layer_scales_map[layer_idx] = param.item()
|
||||
except RuntimeError:
|
||||
print(
|
||||
"This utility supports only per-tensor scalar scales "
|
||||
f"for now. The tensor\n {name} = {param} \nis an "
|
||||
"invalid scale factor.")
|
||||
raise
|
||||
|
||||
if all(
|
||||
len(layer_scales_map) == 0
|
||||
for layer_scales_map in rank_scales_map.values()):
|
||||
# Note: this is true even if the rank_scales_map is empty
|
||||
print("WARNING: No KV cache scale factors found. No output saved.")
|
||||
return None
|
||||
empirical_tp_world_size = max(rank_scales_map.keys()) + 1
|
||||
if expected_tp_size is not None:
|
||||
assert expected_tp_size == empirical_tp_world_size, \
|
||||
f"User expected TP world size = {expected_tp_size} " \
|
||||
"from model but tool is expecting TP world size = " \
|
||||
f"{empirical_tp_world_size} from model instead."
|
||||
for i in range(empirical_tp_world_size):
|
||||
assert i in rank_scales_map, "Expected TP world size = "\
|
||||
f"{empirical_tp_world_size} but did not find KV " \
|
||||
f"cache scaling factors for TP rank {i}"
|
||||
print(f"Found TP world size = {empirical_tp_world_size} "
|
||||
"when extracting KV cache scales!")
|
||||
return rank_scales_map
|
||||
|
||||
|
||||
def _metadata_extractor(quantized_model_dir: str,
|
||||
metadata_extract_fns: \
|
||||
Dict[str, Callable[[Dict[str, Any]], Any]]) \
|
||||
-> Dict[str, Any]:
|
||||
"""
|
||||
Given a directory containing quantized model files, this function
|
||||
aims to extract metadata from the JSON files within this directory.
|
||||
Each JSON file is expected to represent a dictionary in JSON
|
||||
format (referred to as a "JSON-dictionary"). Metadata extraction is
|
||||
defined by a dictionary called metadata_extract_fns, where each
|
||||
metadata field name is mapped to an extraction function.
|
||||
|
||||
These extraction functions are designed to take a JSON-dictionary
|
||||
as their only argument and return the corresponding metadata.
|
||||
While extraction functions are permitted to raise exceptions, they
|
||||
should only raise a KeyError or ValueError if the metadata field
|
||||
cannot be extracted from the current JSON-dictionary, yet there's
|
||||
a possibility of finding it in another JSON-dictionary.
|
||||
|
||||
The function returns a dictionary that maps metadata fields to
|
||||
their extracted data. The keys of this dictionary correspond exactly
|
||||
to those in metadata_extract_fns. If any fields fail to be extracted,
|
||||
their corresponding values are set to None, and a warning is printed.
|
||||
"""
|
||||
if not os.path.isdir(quantized_model_dir):
|
||||
raise FileNotFoundError(
|
||||
f"The quantized model directory `{quantized_model_dir}` "
|
||||
"does not exist.")
|
||||
metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))
|
||||
|
||||
result: Dict[str, Any] = {}
|
||||
for file in metadata_files:
|
||||
with open(file) as f:
|
||||
try:
|
||||
metadata = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
print(f"Could not parse `{file}` as a valid metadata file,"
|
||||
" skipping it.")
|
||||
continue
|
||||
if not isinstance(metadata, dict):
|
||||
print(f"The file `{file}` does not correspond to a "
|
||||
"JSON-serialized dictionary, skipping it.")
|
||||
continue
|
||||
for metadata_name, extract_fn in metadata_extract_fns.items():
|
||||
try:
|
||||
metadata_info = extract_fn(metadata)
|
||||
if metadata_name not in result:
|
||||
result[metadata_name] = metadata_info
|
||||
elif metadata_info != result[metadata_name]:
|
||||
raise RuntimeError(
|
||||
"Metadata mismatch! Originally found "
|
||||
f"{metadata_name} = {result[metadata_name]} but "
|
||||
f"now found {metadata_name} = {metadata_info} in "
|
||||
f"`{file}`")
|
||||
except KeyError:
|
||||
# It is possible that a given file does not contain some
|
||||
# of our selected metadata as it could be located in some
|
||||
# other metadata file.
|
||||
# 'EFINAE': extract_fn failure is not an error.
|
||||
pass
|
||||
except ValueError:
|
||||
# See above.
|
||||
pass
|
||||
|
||||
# Warn if we cannot find any of the requested metadata
|
||||
for metadata_name in metadata_extract_fns:
|
||||
if metadata_name not in result:
|
||||
print("WARNING: Unable to find requested metadata field "
|
||||
f"`{metadata_name}`, setting it to None.")
|
||||
result[metadata_name] = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main(args):
|
||||
metadata_extract_fns = {
|
||||
"model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
|
||||
"tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
|
||||
"model_dtype": lambda json_dict: json_dict["dtype"]
|
||||
}
|
||||
recovered_metadata = _metadata_extractor(args.quantized_model,
|
||||
metadata_extract_fns)
|
||||
if args.tp_size is not None:
|
||||
metadata_tp_size = recovered_metadata["tp_size"]
|
||||
if metadata_tp_size is not None:
|
||||
assert args.tp_size == metadata_tp_size, \
|
||||
f"User expected TP world size = {args.tp_size} " \
|
||||
f"but found TP world size = {metadata_tp_size} from metadata!"
|
||||
expected_tp_size = args.tp_size or recovered_metadata["tp_size"]
|
||||
rank_keyword = "rank"
|
||||
hf_tensor_files, use_safetensors = _prepare_hf_weights(
|
||||
args.quantized_model, args.load_format)
|
||||
rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
|
||||
rank_keyword, expected_tp_size)
|
||||
# Postprocess: formatting to the current schema. Consider pulling it
|
||||
# out into a dedicated function should it ever become more complicated.
|
||||
rank_scales_map = {
|
||||
rank: {k: scale[k]
|
||||
for k in sorted(scale.keys())}
|
||||
for rank, scale in rank_scales_map.items()
|
||||
}
|
||||
# TODO: Expand this with activation and weights scaling factors when
|
||||
# they are used in the future
|
||||
schema = QuantParamSchema(
|
||||
model_type=recovered_metadata["model_type"],
|
||||
kv_cache={
|
||||
"dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
|
||||
recovered_metadata["model_dtype"]),
|
||||
"scaling_factor":
|
||||
rank_scales_map
|
||||
},
|
||||
)
|
||||
|
||||
if args.output_dir is None:
|
||||
output_file = os.path.join(args.quantized_model, args.output_name)
|
||||
else:
|
||||
if not os.path.isdir(args.output_dir):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
output_file = os.path.join(args.output_dir, args.output_name)
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(schema.model_dump_json(indent=4))
|
||||
print(f"Completed! KV cache scaling factors saved to {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="This simple utility extracts the "
|
||||
"KV cache scaling factors from a quantized HF model "
|
||||
"and saves them to a JSON file compatible with later "
|
||||
"use by vLLM (pass this file to the appropriate "
|
||||
"runtime typically using the argument "
|
||||
"--quantization-param-path <filename>). This is only used "
|
||||
"if the KV cache dtype is FP8 and on ROCm (AMD GPU).")
|
||||
parser.add_argument(
|
||||
"--quantized-model",
|
||||
help="Specify the directory containing a single quantized HF model. "
|
||||
"It is expected that the quantization format is FP8_E4M3, for use "
|
||||
"on ROCm (AMD GPU).",
|
||||
required=True)
|
||||
parser.add_argument(
|
||||
"--load_format",
|
||||
help="Optionally specify the format of the model's tensor files "
|
||||
"containing the KV cache scaling factors.",
|
||||
choices=["auto", "safetensors", "npz", "pt"],
|
||||
default="auto")
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
help="Optionally specify the output directory. By default the "
|
||||
"KV cache scaling factors will be saved in the model directory, "
|
||||
"however you can override this behavior here.",
|
||||
default=None)
|
||||
parser.add_argument(
|
||||
"--output-name",
|
||||
help="Optionally specify the output filename.",
|
||||
# TODO: Change this once additional scaling factors are enabled
|
||||
default="kv_cache_scales.json")
|
||||
parser.add_argument(
|
||||
"--tp-size",
|
||||
help="Optionally specify the tensor-parallel (TP) size that the "
|
||||
"quantized model should correspond to. If specified, during KV "
|
||||
"cache scaling factor extraction the observed TP size will be "
|
||||
"checked against this and an error will be raised if there is "
|
||||
"a mismatch. If not specified, the quantized model's expected "
|
||||
"TP size is instead inferred from the largest TP rank observed. "
|
||||
"The expected TP size is cross-checked against the TP ranks "
|
||||
"observed in the quantized model and an error is raised if any "
|
||||
"discrepancies are found.",
|
||||
default=None,
|
||||
type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@ -1,32 +0,0 @@
|
||||
### Quantizer Utilities
|
||||
`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported
|
||||
from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py)
|
||||
|
||||
### Prerequisite
|
||||
|
||||
#### AMMO (AlgorithMic Model Optimization) Installation: nvidia-ammo 0.7.1 or later
|
||||
`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo`
|
||||
|
||||
#### AMMO Download (code and docs)
|
||||
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz`
|
||||
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz`
|
||||
|
||||
### Usage
|
||||
|
||||
#### Run on H100 system for speed if FP8; number of GPUs depends on the model size
|
||||
|
||||
#### Example: quantize Llama2-7b model from HF to FP8 with FP8 KV Cache:
|
||||
`python quantize.py --model-dir ./ll2-7b --dtype float16 --qformat fp8 --kv-cache-dtype fp8 --output-dir ./ll2_7b_fp8 --calib-size 512 --tp-size 1`
|
||||
|
||||
Outputs: model structure, quantized model & parameters (with scaling factors) are in JSON and Safetensors (npz is generated only for the reference)
|
||||
```
|
||||
# ll ./ll2_7b_fp8/
|
||||
total 19998244
|
||||
drwxr-xr-x 2 root root 4096 Feb 7 01:08 ./
|
||||
drwxrwxr-x 8 1060 1061 4096 Feb 7 01:08 ../
|
||||
-rw-r--r-- 1 root root 176411 Feb 7 01:08 llama_tp1.json
|
||||
-rw-r--r-- 1 root root 13477087480 Feb 7 01:09 llama_tp1_rank0.npz
|
||||
-rw-r--r-- 1 root root 7000893272 Feb 7 01:08 rank0.safetensors
|
||||
#
|
||||
```
|
||||
|
||||
@ -1,367 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Adapted from examples/quantization/hf_ptq.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
|
||||
import ammo.torch.quantization as atq
|
||||
import numpy as np
|
||||
import torch
|
||||
from ammo.torch.export import export_model_config
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
RAND_SEED = 1234
|
||||
MAX_SEQ_LEN = 2048
|
||||
|
||||
EMPTY_CFG = {
|
||||
"quant_cfg": {
|
||||
"*weight_quantizer": {
|
||||
"enable": False,
|
||||
},
|
||||
"*input_quantizer": {
|
||||
"enable": False
|
||||
},
|
||||
"*lm_head*": {
|
||||
"enable": False
|
||||
},
|
||||
"*output_layer*": {
|
||||
"enable": False
|
||||
},
|
||||
"default": {
|
||||
"enable": False
|
||||
},
|
||||
},
|
||||
"algorithm": "max",
|
||||
}
|
||||
|
||||
KV_CACHE_CFG = {
|
||||
"*.query_key_value.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.Wqkv.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.W_pack.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.c_attn.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.k_proj.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.v_proj.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
}
|
||||
|
||||
QUANT_CFG_CHOICES = {
|
||||
"int8_sq": atq.INT8_SMOOTHQUANT_CFG,
|
||||
"fp8": atq.FP8_DEFAULT_CFG,
|
||||
"int4_awq": atq.INT4_AWQ_CFG,
|
||||
"w4a8_awq": atq.W4A8_AWQ_BETA_CFG,
|
||||
"int8_wo": EMPTY_CFG,
|
||||
"int4_wo": EMPTY_CFG,
|
||||
"full_prec": EMPTY_CFG,
|
||||
}
|
||||
|
||||
MODEL_NAME_PATTERN_MAP = {
|
||||
"GPT2": "gpt2",
|
||||
"Xverse": "llama",
|
||||
"Llama": "llama",
|
||||
"Mistral": "llama",
|
||||
"GPTJ": "gptj",
|
||||
"FalconForCausalLM": "falcon",
|
||||
"RWForCausalLM": "falcon",
|
||||
"baichuan": "baichuan",
|
||||
"MPT": "mpt",
|
||||
"Bloom": "bloom",
|
||||
"ChatGLM": "chatglm",
|
||||
"QWen": "qwen",
|
||||
}
|
||||
|
||||
|
||||
def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
|
||||
print(f"Initializing tokenizer from {ckpt_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
ckpt_path,
|
||||
model_max_length=max_seq_len,
|
||||
padding_side="left",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if model_type and model_type == "qwen":
|
||||
# qwen use token id 151643 as pad and eos tokens
|
||||
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643)
|
||||
tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643)
|
||||
|
||||
# can't set attribute 'pad_token' for "<unk>"
|
||||
if tokenizer.pad_token != "<unk>":
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
assert (tokenizer.pad_token
|
||||
is not None), f"Pad token for {model_type} cannot be set!"
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(ckpt_path, dtype="fp16", device="cuda"):
|
||||
print(f"Initializing model from {ckpt_path}")
|
||||
if dtype == "bf16" or dtype == "bfloat16":
|
||||
dtype = torch.bfloat16
|
||||
elif dtype == "fp16" or dtype == "float16":
|
||||
dtype = torch.float16
|
||||
elif dtype == "fp32" or dtype == "float32":
|
||||
dtype = torch.float32
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown dtype {dtype}")
|
||||
|
||||
# model_kwargs = {"torch_dtype": dtype}
|
||||
model_kwargs = {"torch_dtype": "auto"}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt_path,
|
||||
device_map="auto",
|
||||
**model_kwargs,
|
||||
trust_remote_code=True)
|
||||
model.eval()
|
||||
|
||||
model_dtype = next(model.parameters()).dtype
|
||||
if dtype != model_dtype:
|
||||
print("[TensorRT-LLM][WARNING] The manually set model data type is "
|
||||
f"{dtype}, but the data type of the HuggingFace model is "
|
||||
f"{model_dtype}.")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_model_type(model):
|
||||
for k, v in MODEL_NAME_PATTERN_MAP.items():
|
||||
if k.lower() in type(model).__name__.lower():
|
||||
return v
|
||||
return None
|
||||
|
||||
|
||||
def get_calib_dataloader(data="cnn_dailymail",
|
||||
tokenizer=None,
|
||||
batch_size=1,
|
||||
calib_size=512,
|
||||
block_size=512,
|
||||
device=None):
|
||||
print("Loading calibration dataset")
|
||||
if data == "pileval":
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
|
||||
split="train")
|
||||
dataset = dataset["text"][:calib_size]
|
||||
elif data == "cnn_dailymail":
|
||||
dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
|
||||
dataset = dataset["article"][:calib_size]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
batch_encoded = tokenizer.batch_encode_plus(dataset,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=block_size)
|
||||
if device:
|
||||
batch_encoded = batch_encoded.to(device)
|
||||
batch_encoded = batch_encoded["input_ids"]
|
||||
|
||||
calib_dataloader = DataLoader(batch_encoded,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
return calib_dataloader
|
||||
|
||||
|
||||
def quantize_model(model, quant_cfg, calib_dataloader=None):
|
||||
|
||||
def calibrate_loop():
|
||||
if calib_dataloader is None:
|
||||
return
|
||||
"""Adjusts weights and scaling factors based on selected algorithms."""
|
||||
for idx, data in enumerate(calib_dataloader):
|
||||
print(f"Calibrating batch {idx}")
|
||||
model(data)
|
||||
|
||||
print("Starting quantization...")
|
||||
start_time = time.time()
|
||||
atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
|
||||
end_time = time.time()
|
||||
print("Quantization done. Total time used: {:.2f} s.".format(end_time -
|
||||
start_time))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def main(args):
|
||||
if not torch.cuda.is_available():
|
||||
raise OSError("GPU is required for inference.")
|
||||
|
||||
random.seed(RAND_SEED)
|
||||
np.random.seed(RAND_SEED)
|
||||
|
||||
model = get_model(args.model_dir, args.dtype, args.device)
|
||||
model_type = get_model_type(model)
|
||||
tokenizer = get_tokenizer(args.model_dir, model_type=model_type)
|
||||
|
||||
if args.qformat in ["full_prec", "int8_wo", "int4_wo"
|
||||
] and args.kv_cache_dtype is None:
|
||||
print(f"No quantization applied, export {args.dtype} model")
|
||||
else:
|
||||
if "awq" in args.qformat:
|
||||
if args.calib_size > 32:
|
||||
print("AWQ calibration could take longer with calib_size = "
|
||||
f"{args.calib_size}, Using calib_size=32 instead")
|
||||
args.calib_size = 32
|
||||
print("\nAWQ calibration could take longer than other calibration "
|
||||
"methods. Please increase the batch size to speed up the "
|
||||
"calibration process. Batch size can be set by adding the "
|
||||
"argument --batch_size <batch_size> to the command line.\n")
|
||||
|
||||
calib_dataloader = get_calib_dataloader(
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
calib_size=args.calib_size,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
if args.qformat in QUANT_CFG_CHOICES:
|
||||
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quantization format: {args.qformat}")
|
||||
|
||||
if "awq" in args.qformat:
|
||||
quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat])
|
||||
weight_quantizer = quant_cfg["quant_cfg"][
|
||||
"*weight_quantizer"] # type: ignore
|
||||
if isinstance(weight_quantizer, list):
|
||||
weight_quantizer = weight_quantizer[0]
|
||||
weight_quantizer["block_sizes"][-1] = args.awq_block_size
|
||||
|
||||
if args.kv_cache_dtype is not None:
|
||||
if args.kv_cache_dtype == "fp8":
|
||||
for value in KV_CACHE_CFG.values():
|
||||
value.update({"num_bits": (4, 3)}) # type: ignore
|
||||
quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore
|
||||
|
||||
print(quant_cfg)
|
||||
|
||||
model = quantize_model(model, quant_cfg, calib_dataloader)
|
||||
|
||||
with torch.inference_mode():
|
||||
if model_type is None:
|
||||
print(f"Unknown model type {type(model).__name__}. Continue "
|
||||
"exporting...")
|
||||
model_type = f"unknown:{type(model).__name__}"
|
||||
|
||||
export_path = args.output_dir
|
||||
start_time = time.time()
|
||||
|
||||
if args.qformat == "int4_awq" and model_type == "qwen":
|
||||
torch.save(model.state_dict(), export_path)
|
||||
else:
|
||||
export_npz = (model_type not in [
|
||||
'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan'
|
||||
])
|
||||
|
||||
# export safetensors
|
||||
export_model_config(
|
||||
model,
|
||||
model_type,
|
||||
getattr(torch, args.dtype),
|
||||
export_dir=export_path,
|
||||
inference_tensor_parallel=args.tp_size,
|
||||
inference_pipeline_parallel=args.pp_size,
|
||||
# export_tensorrt_llm_config=(not export_npz),
|
||||
export_tensorrt_llm_config=False,
|
||||
export_npz=export_npz)
|
||||
|
||||
# Workaround for wo quantization
|
||||
if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
|
||||
with open(f"{export_path}/config.json") as f:
|
||||
tensorrt_llm_config = json.load(f)
|
||||
if args.qformat == "int8_wo":
|
||||
tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'
|
||||
elif args.qformat == "int4_wo":
|
||||
tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16'
|
||||
else:
|
||||
tensorrt_llm_config["quantization"]["quant_algo"] = None
|
||||
with open(f"{export_path}/config.json", "w") as f:
|
||||
json.dump(tensorrt_llm_config, f, indent=4)
|
||||
|
||||
end_time = time.time()
|
||||
print("Quantized model exported to {} \nTotal time used {:.2f} s.".
|
||||
format(export_path, end_time - start_time))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--model-dir",
|
||||
help="Specify where the HuggingFace model is",
|
||||
required=True)
|
||||
parser.add_argument("--device", default="cuda")
|
||||
parser.add_argument("--dtype", help="Model data type.", default="float16")
|
||||
parser.add_argument(
|
||||
"--qformat",
|
||||
help="Quantization format.",
|
||||
default="full_prec",
|
||||
choices=[
|
||||
"fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo",
|
||||
"full_prec"
|
||||
],
|
||||
)
|
||||
parser.add_argument("--batch-size",
|
||||
help="Batch size for calibration.",
|
||||
type=int,
|
||||
default=1)
|
||||
parser.add_argument("--calib-size",
|
||||
help="Number of samples for calibration.",
|
||||
type=int,
|
||||
default=512)
|
||||
parser.add_argument("--output-dir", default="exported_model")
|
||||
parser.add_argument("--tp-size", type=int, default=1)
|
||||
parser.add_argument("--pp-size", type=int, default=1)
|
||||
parser.add_argument("--awq-block-size", type=int, default=128)
|
||||
parser.add_argument("--kv-cache-dtype",
|
||||
help="KV Cache dtype.",
|
||||
default=None,
|
||||
choices=["int8", "fp8", None])
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
23
examples/template_deepseek_vl2.jinja
Normal file
23
examples/template_deepseek_vl2.jinja
Normal file
@ -0,0 +1,23 @@
|
||||
{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- set system_message = messages[0]['content'] -%}
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{% set system_message = '' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{ bos_token + system_message }}
|
||||
{%- for message in messages -%}
|
||||
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
|
||||
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{ '<|User|>: ' + message['content'] + '\n' }}
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{{ '<|Assistant|>: ' + message['content'] + eos_token + '\n' }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{{ '<|Assistant|>: ' }}
|
||||
{% endif %}
|
||||
324
format.sh
324
format.sh
@ -1,321 +1,5 @@
|
||||
#!/usr/bin/env bash
|
||||
# YAPF formatter, adapted from ray and skypilot.
|
||||
#
|
||||
# Usage:
|
||||
# # Do work and commit your work.
|
||||
#!/bin/bash
|
||||
|
||||
# # Format files that differ from origin/main.
|
||||
# bash format.sh
|
||||
|
||||
# # Commit changed files with message 'Run yapf and ruff'
|
||||
#
|
||||
#
|
||||
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
|
||||
# You are encouraged to run this locally before pushing changes for review.
|
||||
|
||||
# Cause the script to exit if a single command fails
|
||||
set -eo pipefail
|
||||
|
||||
# this stops git rev-parse from failing if we run this from the .git directory
|
||||
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
|
||||
ROOT="$(git rev-parse --show-toplevel)"
|
||||
builtin cd "$ROOT" || exit 1
|
||||
|
||||
check_command() {
|
||||
if ! command -v "$1" &> /dev/null; then
|
||||
echo "❓❓$1 is not installed, please run \`pip install -r requirements-lint.txt\`"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_command yapf
|
||||
check_command ruff
|
||||
check_command mypy
|
||||
check_command codespell
|
||||
check_command isort
|
||||
check_command clang-format
|
||||
|
||||
YAPF_VERSION=$(yapf --version | awk '{print $2}')
|
||||
RUFF_VERSION=$(ruff --version | awk '{print $2}')
|
||||
MYPY_VERSION=$(mypy --version | awk '{print $2}')
|
||||
CODESPELL_VERSION=$(codespell --version)
|
||||
ISORT_VERSION=$(isort --vn)
|
||||
CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')
|
||||
PYMARKDOWNLNT_VERSION=$(pymarkdownlnt version | awk '{print $1}')
|
||||
|
||||
# # params: tool name, tool version, required version
|
||||
tool_version_check() {
|
||||
expected=$(grep "$1" requirements-lint.txt | cut -d'=' -f3)
|
||||
if [[ "$2" != "$expected" ]]; then
|
||||
echo "❓❓Wrong $1 version installed: $expected is required, not $2."
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
tool_version_check "yapf" "$YAPF_VERSION"
|
||||
tool_version_check "ruff" "$RUFF_VERSION"
|
||||
tool_version_check "mypy" "$MYPY_VERSION"
|
||||
tool_version_check "isort" "$ISORT_VERSION"
|
||||
tool_version_check "codespell" "$CODESPELL_VERSION"
|
||||
tool_version_check "clang-format" "$CLANGFORMAT_VERSION"
|
||||
tool_version_check "pymarkdownlnt" "$PYMARKDOWNLNT_VERSION"
|
||||
|
||||
YAPF_FLAGS=(
|
||||
'--recursive'
|
||||
'--parallel'
|
||||
)
|
||||
|
||||
YAPF_EXCLUDES=(
|
||||
'--exclude' 'build/**'
|
||||
)
|
||||
|
||||
# Format specified files
|
||||
format() {
|
||||
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
|
||||
}
|
||||
|
||||
# Format files that differ from main branch. Ignores dirs that are not slated
|
||||
# for autoformat yet.
|
||||
format_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause yapf to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
|
||||
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
|
||||
fi
|
||||
|
||||
}
|
||||
|
||||
# Format all files
|
||||
format_all() {
|
||||
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" .
|
||||
}
|
||||
|
||||
## This flag formats individual files. --files *must* be the first command line
|
||||
## arg to use this option.
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
format "${@:2}"
|
||||
# If `--all` is passed, then any further arguments are ignored and the
|
||||
# entire python directory is formatted.
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
format_all
|
||||
else
|
||||
# Format only the files that changed in last commit.
|
||||
format_changed
|
||||
fi
|
||||
echo 'vLLM yapf: Done'
|
||||
|
||||
# Run mypy
|
||||
echo 'vLLM mypy:'
|
||||
tools/mypy.sh
|
||||
echo 'vLLM mypy: Done'
|
||||
|
||||
|
||||
# If git diff returns a file that is in the skip list, the file may be checked anyway:
|
||||
# https://github.com/codespell-project/codespell/issues/1915
|
||||
# Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem
|
||||
CODESPELL_EXCLUDES=(
|
||||
'--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**'
|
||||
)
|
||||
|
||||
# check spelling of specified files
|
||||
spell_check() {
|
||||
codespell "$@"
|
||||
}
|
||||
|
||||
spell_check_all(){
|
||||
codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}"
|
||||
}
|
||||
|
||||
# Spelling check of files that differ from main branch.
|
||||
spell_check_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause ruff to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
|
||||
codespell "${CODESPELL_EXCLUDES[@]}"
|
||||
fi
|
||||
}
|
||||
|
||||
# Run Codespell
|
||||
## This flag runs spell check of individual files. --files *must* be the first command line
|
||||
## arg to use this option.
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
spell_check "${@:2}"
|
||||
# If `--all` is passed, then any further arguments are ignored and the
|
||||
# entire python directory is linted.
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
spell_check_all
|
||||
else
|
||||
# Check spelling only of the files that changed in last commit.
|
||||
spell_check_changed
|
||||
fi
|
||||
echo 'vLLM codespell: Done'
|
||||
|
||||
|
||||
# Lint specified files
|
||||
lint() {
|
||||
ruff check "$@"
|
||||
}
|
||||
|
||||
# Lint files that differ from main branch. Ignores dirs that are not slated
|
||||
# for autolint yet.
|
||||
lint_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause ruff to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
|
||||
ruff check
|
||||
fi
|
||||
|
||||
}
|
||||
|
||||
# Run Ruff
|
||||
### This flag lints individual files. --files *must* be the first command line
|
||||
### arg to use this option.
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
lint "${@:2}"
|
||||
# If `--all` is passed, then any further arguments are ignored and the
|
||||
# entire python directory is linted.
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
lint vllm tests
|
||||
else
|
||||
# Format only the files that changed in last commit.
|
||||
lint_changed
|
||||
fi
|
||||
echo 'vLLM ruff: Done'
|
||||
|
||||
# check spelling of specified files
|
||||
isort_check() {
|
||||
isort "$@"
|
||||
}
|
||||
|
||||
isort_check_all(){
|
||||
isort .
|
||||
}
|
||||
|
||||
# Spelling check of files that differ from main branch.
|
||||
isort_check_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause ruff to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
|
||||
isort
|
||||
fi
|
||||
}
|
||||
|
||||
# Run Isort
|
||||
# This flag runs spell check of individual files. --files *must* be the first command line
|
||||
# arg to use this option.
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
isort_check "${@:2}"
|
||||
# If `--all` is passed, then any further arguments are ignored and the
|
||||
# entire python directory is linted.
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
isort_check_all
|
||||
else
|
||||
# Check spelling only of the files that changed in last commit.
|
||||
isort_check_changed
|
||||
fi
|
||||
echo 'vLLM isort: Done'
|
||||
|
||||
# Clang-format section
|
||||
# Exclude some files for formatting because they are vendored
|
||||
# NOTE: Keep up to date with .github/workflows/clang-format.yml
|
||||
CLANG_FORMAT_EXCLUDES=(
|
||||
'csrc/moe/topk_softmax_kernels.cu'
|
||||
'csrc/quantization/gguf/ggml-common.h'
|
||||
'csrc/quantization/gguf/dequantize.cuh'
|
||||
'csrc/quantization/gguf/vecdotq.cuh'
|
||||
'csrc/quantization/gguf/mmq.cuh'
|
||||
'csrc/quantization/gguf/mmvq.cuh'
|
||||
)
|
||||
|
||||
# Format specified files with clang-format
|
||||
clang_format() {
|
||||
clang-format -i "$@"
|
||||
}
|
||||
|
||||
# Format files that differ from main branch with clang-format.
|
||||
clang_format_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause clang-format to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
|
||||
# Get the list of changed files, excluding the specified ones
|
||||
changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | (grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") || echo -e))
|
||||
if [ -n "$changed_files" ]; then
|
||||
echo "$changed_files" | xargs -P 5 clang-format -i
|
||||
fi
|
||||
}
|
||||
|
||||
# Format all files with clang-format
|
||||
clang_format_all() {
|
||||
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
|
||||
| grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \
|
||||
| xargs clang-format -i
|
||||
}
|
||||
|
||||
# Run clang-format
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
clang_format "${@:2}"
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
clang_format_all
|
||||
else
|
||||
clang_format_changed
|
||||
fi
|
||||
echo 'vLLM clang-format: Done'
|
||||
|
||||
echo 'vLLM actionlint:'
|
||||
tools/actionlint.sh -color
|
||||
echo 'vLLM actionlint: Done'
|
||||
|
||||
echo 'vLLM shellcheck:'
|
||||
tools/shellcheck.sh
|
||||
echo 'vLLM shellcheck: Done'
|
||||
|
||||
echo 'excalidraw png check:'
|
||||
tools/png-lint.sh
|
||||
echo 'excalidraw png check: Done'
|
||||
|
||||
if ! git diff --quiet &>/dev/null; then
|
||||
echo
|
||||
echo "🔍🔍There are files changed by the format checker or by you that are not added and committed:"
|
||||
git --no-pager diff --name-only
|
||||
echo "🔍🔍Format checker passed, but please add, commit and push all the files above to include changes made by the format checker."
|
||||
|
||||
exit 1
|
||||
else
|
||||
echo "✨🎉 Format check passed! Congratulations! 🎉✨"
|
||||
fi
|
||||
|
||||
echo 'vLLM doc-lint:'
|
||||
tools/doc-lint.sh
|
||||
echo 'vLLM doc-lint: Done'
|
||||
echo "vLLM linting system has been moved from format.sh to pre-commit hook."
|
||||
echo "Please run 'pip install -r requirements-lint.txt' and 'pre-commit install' to install the pre-commit hook."
|
||||
echo "Then linters will run automatically before each commit."
|
||||
|
||||
@ -15,6 +15,11 @@ build-backend = "setuptools.build_meta"
|
||||
[tool.setuptools_scm]
|
||||
# version_file = "vllm/_version.py" # currently handled by `setup.py:get_version()`
|
||||
|
||||
[tool.yapfignore]
|
||||
ignore_patterns = [
|
||||
"build/**",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
# Allow lines to be as long as 80.
|
||||
line-length = 80
|
||||
@ -52,6 +57,9 @@ ignore = [
|
||||
"B007",
|
||||
# f-string format
|
||||
"UP032",
|
||||
# Python 3.8 typing
|
||||
"UP006", "UP035",
|
||||
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
|
||||
@ -19,7 +19,7 @@ pillow # Required for image processing
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||
lm-format-enforcer >= 0.10.9, < 0.11
|
||||
outlines == 0.1.11 # Requires pytorch
|
||||
outlines == 0.1.11
|
||||
lark == 1.2.2
|
||||
xgrammar >= 0.1.6; platform_machine == "x86_64"
|
||||
typing_extensions >= 4.10
|
||||
@ -34,6 +34,6 @@ pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||
einops # Required for Qwen2-VL.
|
||||
compressed-tensors == 0.8.1 # required for compressed-tensors, requires pytorch
|
||||
compressed-tensors == 0.9.0 # required for compressed-tensors
|
||||
depyf==0.18.0 # required for profiling and debugging with compilation config
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
# Dependencies for HPU code
|
||||
ray
|
||||
triton
|
||||
triton==3.1.0
|
||||
pandas
|
||||
tabulate
|
||||
setuptools>=61
|
||||
|
||||
@ -1,15 +1,2 @@
|
||||
# formatting
|
||||
yapf==0.32.0
|
||||
toml==0.10.2
|
||||
tomli==2.0.2
|
||||
ruff==0.6.5
|
||||
codespell==2.3.0
|
||||
isort==5.13.2
|
||||
clang-format==18.1.5
|
||||
pymarkdownlnt==0.9.26
|
||||
|
||||
# type checking
|
||||
mypy==1.11.1
|
||||
types-PyYAML
|
||||
types-requests
|
||||
types-setuptools
|
||||
pre-commit==4.0.1
|
||||
|
||||
@ -29,4 +29,7 @@ lm-eval[api]==0.4.4 # required for model evaluation test
|
||||
bitsandbytes>=0.45.0
|
||||
buildkite-test-collector==0.1.9
|
||||
|
||||
genai_perf==0.0.8
|
||||
tritonclient==2.51.0
|
||||
|
||||
numpy < 2.0.0
|
||||
|
||||
@ -37,7 +37,7 @@ audioread==3.0.1
|
||||
# via librosa
|
||||
awscli==1.35.23
|
||||
# via -r requirements-test.in
|
||||
bitsandbytes>=0.45.0
|
||||
bitsandbytes==0.45.0
|
||||
# via -r requirements-test.in
|
||||
black==24.10.0
|
||||
# via datamodel-code-generator
|
||||
@ -75,6 +75,8 @@ colorama==0.4.6
|
||||
# tqdm-multiprocess
|
||||
contourpy==1.3.0
|
||||
# via matplotlib
|
||||
cramjam==2.9.0
|
||||
# via fastparquet
|
||||
cupy-cuda12x==13.3.0
|
||||
# via ray
|
||||
cycler==0.12.1
|
||||
@ -109,6 +111,8 @@ email-validator==2.2.0
|
||||
# via pydantic
|
||||
evaluate==0.4.3
|
||||
# via lm-eval
|
||||
fastparquet==2024.11.0
|
||||
# via genai-perf
|
||||
fastrlock==0.8.2
|
||||
# via cupy-cuda12x
|
||||
filelock==3.16.1
|
||||
@ -130,8 +134,11 @@ fsspec[http]==2024.9.0
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# huggingface-hub
|
||||
# torch
|
||||
genai-perf==0.0.8
|
||||
# via -r requirements-test.in
|
||||
genson==1.3.0
|
||||
# via datamodel-code-generator
|
||||
h11==0.14.0
|
||||
@ -186,6 +193,8 @@ jsonschema==4.23.0
|
||||
# ray
|
||||
jsonschema-specifications==2024.10.1
|
||||
# via jsonschema
|
||||
kaleido==0.2.1
|
||||
# via genai-perf
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
lazy-loader==0.4
|
||||
@ -200,6 +209,8 @@ lm-eval[api]==0.4.4
|
||||
# via -r requirements-test.in
|
||||
lxml==5.3.0
|
||||
# via sacrebleu
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
markupsafe==3.0.2
|
||||
# via jinja2
|
||||
matplotlib==3.9.2
|
||||
@ -209,6 +220,8 @@ mbstrdecoder==1.1.3
|
||||
# dataproperty
|
||||
# pytablewriter
|
||||
# typepy
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common[opencv]==1.5.1
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
@ -249,6 +262,8 @@ numpy==1.26.4
|
||||
# datasets
|
||||
# decord
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# genai-perf
|
||||
# librosa
|
||||
# matplotlib
|
||||
# mistral-common
|
||||
@ -256,15 +271,18 @@ numpy==1.26.4
|
||||
# numexpr
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# patsy
|
||||
# peft
|
||||
# rouge-score
|
||||
# sacrebleu
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# soxr
|
||||
# statsmodels
|
||||
# tensorizer
|
||||
# torchvision
|
||||
# transformers
|
||||
# tritonclient
|
||||
nvidia-cublas-cu12==12.4.5.8
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
@ -306,30 +324,39 @@ packaging==24.1
|
||||
# datamodel-code-generator
|
||||
# datasets
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# huggingface-hub
|
||||
# lazy-loader
|
||||
# matplotlib
|
||||
# peft
|
||||
# plotly
|
||||
# pooch
|
||||
# pytest
|
||||
# pytest-rerunfailures
|
||||
# ray
|
||||
# statsmodels
|
||||
# transformers
|
||||
# typepy
|
||||
pandas==2.2.3
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# genai-perf
|
||||
# statsmodels
|
||||
pathspec==0.12.1
|
||||
# via black
|
||||
pathvalidate==3.2.1
|
||||
# via pytablewriter
|
||||
patsy==1.0.1
|
||||
# via statsmodels
|
||||
peft==0.13.2
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# lm-eval
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# genai-perf
|
||||
# matplotlib
|
||||
# mistral-common
|
||||
# sentence-transformers
|
||||
@ -338,6 +365,8 @@ platformdirs==4.3.6
|
||||
# via
|
||||
# black
|
||||
# pooch
|
||||
plotly==5.24.1
|
||||
# via genai-perf
|
||||
pluggy==1.5.0
|
||||
# via pytest
|
||||
pooch==1.8.2
|
||||
@ -360,7 +389,9 @@ psutil==6.1.0
|
||||
py==1.11.0
|
||||
# via pytest-forked
|
||||
pyarrow==18.0.0
|
||||
# via datasets
|
||||
# via
|
||||
# datasets
|
||||
# genai-perf
|
||||
pyasn1==0.6.1
|
||||
# via rsa
|
||||
pybind11==2.13.6
|
||||
@ -373,6 +404,8 @@ pydantic[email]==2.9.2
|
||||
# mistral-common
|
||||
pydantic-core==2.23.4
|
||||
# via pydantic
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyparsing==3.2.0
|
||||
# via matplotlib
|
||||
pytablewriter==1.2.0
|
||||
@ -381,14 +414,18 @@ pytest==8.3.3
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# buildkite-test-collector
|
||||
# genai-perf
|
||||
# pytest-asyncio
|
||||
# pytest-forked
|
||||
# pytest-mock
|
||||
# pytest-rerunfailures
|
||||
# pytest-shard
|
||||
pytest-asyncio==0.24.0
|
||||
# via -r requirements-test.in
|
||||
pytest-forked==1.6.0
|
||||
# via -r requirements-test.in
|
||||
pytest-mock==3.14.0
|
||||
# via genai-perf
|
||||
pytest-rerunfailures==14.0
|
||||
# via -r requirements-test.in
|
||||
pytest-shard==0.1.2
|
||||
@ -399,6 +436,8 @@ python-dateutil==2.9.0.post0
|
||||
# matplotlib
|
||||
# pandas
|
||||
# typepy
|
||||
python-rapidjson==1.20
|
||||
# via tritonclient
|
||||
pytz==2024.2
|
||||
# via
|
||||
# pandas
|
||||
@ -409,9 +448,11 @@ pyyaml==6.0.2
|
||||
# awscli
|
||||
# datamodel-code-generator
|
||||
# datasets
|
||||
# genai-perf
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# ray
|
||||
# responses
|
||||
# timm
|
||||
# transformers
|
||||
ray[adag]==2.40.0
|
||||
@ -438,8 +479,13 @@ requests==2.32.3
|
||||
# mistral-common
|
||||
# pooch
|
||||
# ray
|
||||
# responses
|
||||
# tiktoken
|
||||
# transformers
|
||||
responses==0.25.3
|
||||
# via genai-perf
|
||||
rich==13.9.4
|
||||
# via genai-perf
|
||||
rouge-score==0.1.2
|
||||
# via lm-eval
|
||||
rpds-py==0.20.1
|
||||
@ -470,6 +516,7 @@ scipy==1.13.1
|
||||
# librosa
|
||||
# scikit-learn
|
||||
# sentence-transformers
|
||||
# statsmodels
|
||||
sentence-transformers==3.2.1
|
||||
# via -r requirements-test.in
|
||||
sentencepiece==0.2.0
|
||||
@ -490,6 +537,8 @@ soxr==0.5.0.post1
|
||||
# via librosa
|
||||
sqlitedict==2.1.0
|
||||
# via lm-eval
|
||||
statsmodels==0.14.4
|
||||
# via genai-perf
|
||||
sympy==1.13.1
|
||||
# via torch
|
||||
tabledata==1.3.3
|
||||
@ -499,7 +548,9 @@ tabulate==0.9.0
|
||||
tcolorpy==0.1.6
|
||||
# via pytablewriter
|
||||
tenacity==9.0.0
|
||||
# via lm-eval
|
||||
# via
|
||||
# lm-eval
|
||||
# plotly
|
||||
tensorizer==2.9.0
|
||||
# via -r requirements-test.in
|
||||
threadpoolctl==3.5.0
|
||||
@ -540,6 +591,7 @@ tqdm-multiprocess==0.0.11
|
||||
# via lm-eval
|
||||
transformers==4.47.0
|
||||
# via
|
||||
# genai-perf
|
||||
# lm-eval
|
||||
# peft
|
||||
# sentence-transformers
|
||||
@ -548,6 +600,10 @@ transformers-stream-generator==0.0.5
|
||||
# via -r requirements-test.in
|
||||
triton==3.1.0
|
||||
# via torch
|
||||
tritonclient==2.51.0
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# genai-perf
|
||||
typepy[datetime]==1.3.2
|
||||
# via
|
||||
# dataproperty
|
||||
@ -555,6 +611,7 @@ typepy[datetime]==1.3.2
|
||||
# tabledata
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# bitsandbytes
|
||||
# huggingface-hub
|
||||
# librosa
|
||||
# mistral-common
|
||||
@ -563,10 +620,12 @@ typing-extensions==4.12.2
|
||||
# torch
|
||||
tzdata==2024.2
|
||||
# via pandas
|
||||
urllib3==1.26.20
|
||||
urllib3==2.2.3
|
||||
# via
|
||||
# botocore
|
||||
# requests
|
||||
# responses
|
||||
# tritonclient
|
||||
word2number==1.1
|
||||
# via lm-eval
|
||||
xxhash==3.5.0
|
||||
|
||||
26
setup.py
26
setup.py
@ -228,8 +228,11 @@ class cmake_build_ext(build_ext):
|
||||
|
||||
# CMake appends the extension prefix to the install path,
|
||||
# and outdir already contains that prefix, so we need to remove it.
|
||||
# We assume only the final component of extension prefix is added by
|
||||
# CMake, this is currently true for current extensions but may not
|
||||
# always be the case.
|
||||
prefix = outdir
|
||||
for i in range(ext.name.count('.')):
|
||||
if '.' in ext.name:
|
||||
prefix = prefix.parent
|
||||
|
||||
# prefix here should actually be the same for all components
|
||||
@ -298,9 +301,11 @@ class repackage_wheel(build_ext):
|
||||
files_to_copy = [
|
||||
"vllm/_C.abi3.so",
|
||||
"vllm/_moe_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/flash_attn_interface.py",
|
||||
"vllm/vllm_flash_attn/__init__.py",
|
||||
"vllm/cumem_allocator.abi3.so",
|
||||
# "vllm/_version.py", # not available in nightly wheels yet
|
||||
]
|
||||
file_members = filter(lambda x: x.filename in files_to_copy,
|
||||
@ -472,13 +477,9 @@ def get_gaudi_sw_version():
|
||||
|
||||
|
||||
def get_vllm_version() -> str:
|
||||
# TODO: Revisit this temporary approach: https://github.com/vllm-project/vllm/issues/9182#issuecomment-2404860236
|
||||
try:
|
||||
version = get_version(
|
||||
write_to="vllm/_version.py", # TODO: move this to pyproject.toml
|
||||
)
|
||||
except LookupError:
|
||||
version = "0.0.0"
|
||||
version = get_version(
|
||||
write_to="vllm/_version.py", # TODO: move this to pyproject.toml
|
||||
)
|
||||
|
||||
sep = "+" if "+" not in version else "." # dev versions might contain +
|
||||
|
||||
@ -553,7 +554,7 @@ def get_requirements() -> List[str]:
|
||||
return resolved_requirements
|
||||
|
||||
if _no_device():
|
||||
requirements = _read_requirements("requirements-cuda.txt")
|
||||
requirements = _read_requirements("requirements-cpu.txt")
|
||||
elif _is_cuda():
|
||||
requirements = _read_requirements("requirements-cuda.txt")
|
||||
cuda_major, cuda_minor = torch.version.cuda.split(".")
|
||||
@ -596,8 +597,9 @@ if _is_hip():
|
||||
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
|
||||
|
||||
if _is_cuda():
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
|
||||
if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
|
||||
112
tests/basic_correctness/test_cumem.py
Normal file
112
tests/basic_correctness/test_cumem.py
Normal file
@ -0,0 +1,112 @@
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.utils import GiB_bytes
|
||||
|
||||
from ..utils import fork_new_process_for_each_test
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
def test_basic_cumem():
|
||||
# some tensors from default memory pool
|
||||
shape = (1024, 1024)
|
||||
x = torch.empty(shape, device='cuda')
|
||||
x.zero_()
|
||||
|
||||
# some tensors from custom memory pool
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
with allocator.use_memory_pool():
|
||||
# custom memory pool
|
||||
y = torch.empty(shape, device='cuda')
|
||||
y.zero_()
|
||||
y += 1
|
||||
z = torch.empty(shape, device='cuda')
|
||||
z.zero_()
|
||||
z += 2
|
||||
|
||||
# they can be used together
|
||||
output = x + y + z
|
||||
assert torch.allclose(output, torch.ones_like(output) * 3)
|
||||
|
||||
free_bytes = torch.cuda.mem_get_info()[0]
|
||||
allocator.sleep()
|
||||
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
|
||||
assert free_bytes_after_sleep > free_bytes
|
||||
allocator.wake_up()
|
||||
|
||||
# they can be used together
|
||||
output = x + y + z
|
||||
assert torch.allclose(output, torch.ones_like(output) * 3)
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
def test_cumem_with_cudagraph():
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
with allocator.use_memory_pool():
|
||||
weight = torch.eye(1024, device='cuda')
|
||||
with allocator.use_memory_pool(tag="discard"):
|
||||
cache = torch.empty(1024, 1024, device='cuda')
|
||||
|
||||
def model(x):
|
||||
out = x @ weight
|
||||
cache[:out.size(0)].copy_(out)
|
||||
return out + 1
|
||||
|
||||
x = torch.empty(128, 1024, device='cuda')
|
||||
|
||||
# warmup
|
||||
model(x)
|
||||
|
||||
# capture cudagraph
|
||||
model_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(model_graph):
|
||||
y = model(x)
|
||||
|
||||
free_bytes = torch.cuda.mem_get_info()[0]
|
||||
allocator.sleep()
|
||||
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
|
||||
assert free_bytes_after_sleep > free_bytes
|
||||
allocator.wake_up()
|
||||
|
||||
# after waking up, the content in the weight tensor
|
||||
# should be restored, but the content in the cache tensor
|
||||
# should be discarded
|
||||
|
||||
# this operation is also compatible with cudagraph
|
||||
|
||||
x.random_()
|
||||
model_graph.replay()
|
||||
|
||||
# cache content is as expected
|
||||
assert torch.allclose(x, cache[:x.size(0)])
|
||||
|
||||
# output content is as expected
|
||||
assert torch.allclose(y, x + 1)
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
def test_end_to_end():
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
used_bytes_baseline = total - free # in case other process is running
|
||||
llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True)
|
||||
prompt = "How are you?"
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
output = llm.generate(prompt, sampling_params)
|
||||
|
||||
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
|
||||
# which is difficult to measure in the test. therefore, we only
|
||||
# test sleep level 1 here.
|
||||
llm.sleep(level=1)
|
||||
|
||||
free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
|
||||
# now the memory usage is mostly cudagraph memory pool,
|
||||
# and it should be less than the model weights (1B model, 2GiB weights)
|
||||
assert used_bytes < 2 * GiB_bytes
|
||||
|
||||
llm.wake_up()
|
||||
output2 = llm.generate(prompt, sampling_params)
|
||||
|
||||
# cmp output
|
||||
assert output[0].outputs[0].text == output2[0].outputs[0].text
|
||||
@ -244,6 +244,7 @@ def video_assets() -> _VideoAssets:
|
||||
|
||||
|
||||
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
class HfRunner:
|
||||
@ -930,6 +931,10 @@ class VllmRunner:
|
||||
req_outputs = self.model.score(text_1, text_2)
|
||||
return [req_output.outputs.score for req_output in req_outputs]
|
||||
|
||||
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
|
||||
executor = self.model.llm_engine.model_executor
|
||||
return executor.apply_model(func)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
||||
@ -796,6 +796,44 @@ class TestPrefixCachingBlockAllocator:
|
||||
block_hashes=block_hashes_seq1)
|
||||
assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks
|
||||
|
||||
# Test reset prefix cache
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [10])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
def test_reset_prefix_cache(num_blocks: int, block_size: int):
|
||||
"""This test case simulates the case of resetting the prefix cache."""
|
||||
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
token_ids = list(range(3 * block_size))
|
||||
|
||||
first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
|
||||
# Free each block in the first chain.
|
||||
for block in first_chain:
|
||||
allocator.free(block)
|
||||
|
||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||
assert not allocator.reset_prefix_cache()
|
||||
assert allocator.get_prefix_cache_hit_rate() > 0.0
|
||||
|
||||
# Free each block in the second chain.
|
||||
for block in second_chain:
|
||||
allocator.free(block)
|
||||
|
||||
# Reset prefix cache.
|
||||
assert allocator.reset_prefix_cache()
|
||||
assert allocator.get_prefix_cache_hit_rate() == 0.0
|
||||
|
||||
@staticmethod
|
||||
def create_immutable_chain(
|
||||
block_size: int,
|
||||
|
||||
56
tests/distributed/test_torchrun_example.py
Normal file
56
tests/distributed/test_torchrun_example.py
Normal file
@ -0,0 +1,56 @@
|
||||
# unit test for `examples/offline_inference/torchrun_example.py`
|
||||
|
||||
import random
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
|
||||
# Create prompts
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
|
||||
# to test if all ranks agree on the same kv cache configuration.
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="external_launcher",
|
||||
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
||||
swap_space=random.randint(1, 4))
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
cpu_group = get_world_group().cpu_group
|
||||
|
||||
torch_rank = dist.get_rank(group=cpu_group)
|
||||
|
||||
|
||||
def test_consistent_across_ranks(obj):
|
||||
if torch_rank == 0:
|
||||
dist.broadcast_object_list([obj], src=0, group=cpu_group)
|
||||
else:
|
||||
container = [None]
|
||||
dist.broadcast_object_list(container, src=0, group=cpu_group)
|
||||
assert container[0] == obj
|
||||
|
||||
|
||||
test_consistent_across_ranks(
|
||||
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
|
||||
test_consistent_across_ranks(
|
||||
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
|
||||
|
||||
# all ranks should have the same outputs
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
test_consistent_across_ranks(prompt)
|
||||
test_consistent_across_ranks(generated_text)
|
||||
print(f"Rank {torch_rank}, Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}")
|
||||
@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
|
||||
@ -18,7 +18,7 @@ class Mock:
|
||||
class CustomUniExecutor(UniProcExecutor):
|
||||
|
||||
def collective_rpc(self,
|
||||
method: str,
|
||||
method: Union[str, Callable],
|
||||
timeout: Optional[float] = None,
|
||||
args: Tuple = (),
|
||||
kwargs: Optional[Dict] = None) -> List[Any]:
|
||||
@ -51,7 +51,9 @@ def test_custom_executor(model, tmp_path):
|
||||
assert not os.path.exists(".marker")
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model, distributed_executor_backend=CustomUniExecutor)
|
||||
model=model,
|
||||
distributed_executor_backend=CustomUniExecutor,
|
||||
)
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
sampling_params = SamplingParams(max_tokens=1)
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ class DummyWorkerWrapper(WorkerWrapperBase):
|
||||
# simulate error case
|
||||
raise worker_input
|
||||
|
||||
return self.rank, input
|
||||
return self.rpc_rank, input
|
||||
|
||||
|
||||
def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
|
||||
|
||||
36
tests/entrypoints/llm/test_collective_rpc.py
Normal file
36
tests/entrypoints/llm/test_collective_rpc.py
Normal file
@ -0,0 +1,36 @@
|
||||
import pytest
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
from ...utils import fork_new_process_for_each_test
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tp_size", [1, 2])
|
||||
@pytest.mark.parametrize("backend", ["mp", "ray"])
|
||||
@fork_new_process_for_each_test
|
||||
def test_collective_rpc(tp_size, backend):
|
||||
if tp_size == 1 and backend == "ray":
|
||||
pytest.skip("Skip duplicate test case")
|
||||
if tp_size == 1:
|
||||
backend = None
|
||||
|
||||
# intentionally define the method and class in the test function,
|
||||
# to test if they can be serialized and sent to the workers
|
||||
def echo_rank(self):
|
||||
return self.rank
|
||||
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
class MyWorker(Worker):
|
||||
|
||||
def echo_rank(self):
|
||||
return self.rank
|
||||
|
||||
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
load_format="dummy",
|
||||
tensor_parallel_size=tp_size,
|
||||
distributed_executor_backend=backend,
|
||||
worker_cls=MyWorker)
|
||||
for method in ["echo_rank", echo_rank]:
|
||||
assert llm.collective_rpc(method) == list(range(tp_size))
|
||||
@ -105,3 +105,10 @@ def test_multiple_pooling_params(llm: LLM):
|
||||
# pooling_params is None, default params should be applied
|
||||
outputs = llm.encode(PROMPTS, pooling_params=None)
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_right_side_truncation(llm: LLM):
|
||||
# Embeddings models should truncate the end of the prompt
|
||||
tokenizer = llm.get_tokenizer()
|
||||
assert tokenizer.truncation_side == "right"
|
||||
|
||||
@ -17,6 +17,33 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
# generation quality here
|
||||
LORA_NAME = "typeof/zephyr-7b-beta-lora"
|
||||
|
||||
BADREQUEST_CASES = [
|
||||
(
|
||||
"test_rank",
|
||||
{
|
||||
"r": 1024
|
||||
},
|
||||
"is greater than max_lora_rank",
|
||||
),
|
||||
(
|
||||
"test_bias",
|
||||
{
|
||||
"bias": "all"
|
||||
},
|
||||
"Adapter bias cannot be used without bias_enabled",
|
||||
),
|
||||
("test_dora", {
|
||||
"use_dora": True
|
||||
}, "does not yet support DoRA"),
|
||||
(
|
||||
"test_modules_to_save",
|
||||
{
|
||||
"modules_to_save": ["lm_head"]
|
||||
},
|
||||
"only supports modules_to_save being None",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def zephyr_lora_files():
|
||||
@ -138,32 +165,36 @@ async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
|
||||
tmp_path, zephyr_lora_files):
|
||||
invalid_rank = tmp_path / "invalid_rank"
|
||||
@pytest.mark.parametrize("test_name,config_change,expected_error",
|
||||
BADREQUEST_CASES)
|
||||
async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
|
||||
zephyr_lora_files, test_name: str,
|
||||
config_change: dict,
|
||||
expected_error: str):
|
||||
# Create test directory
|
||||
test_dir = tmp_path / test_name
|
||||
|
||||
# Copy adapter from zephyr_lora_files to invalid_rank
|
||||
shutil.copytree(zephyr_lora_files, invalid_rank)
|
||||
# Copy adapter files
|
||||
shutil.copytree(zephyr_lora_files, test_dir)
|
||||
|
||||
with open(invalid_rank / "adapter_config.json") as f:
|
||||
# Load and modify configuration
|
||||
config_path = test_dir / "adapter_config.json"
|
||||
with open(config_path) as f:
|
||||
adapter_config = json.load(f)
|
||||
# Apply configuration changes
|
||||
adapter_config.update(config_change)
|
||||
|
||||
print(adapter_config)
|
||||
|
||||
# assert False
|
||||
|
||||
# Change rank to invalid value
|
||||
adapter_config["r"] = 1024
|
||||
with open(invalid_rank / "adapter_config.json", "w") as f:
|
||||
# Save modified configuration
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(adapter_config, f)
|
||||
|
||||
with pytest.raises(openai.BadRequestError,
|
||||
match="is greater than max_lora_rank"):
|
||||
# Test loading the adapter
|
||||
with pytest.raises(openai.BadRequestError, match=expected_error):
|
||||
await client.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": "invalid-json",
|
||||
"lora_path": str(invalid_rank)
|
||||
"lora_name": test_name,
|
||||
"lora_path": str(test_dir)
|
||||
})
|
||||
|
||||
|
||||
|
||||
@ -12,6 +12,9 @@ MODEL_NAME = "BAAI/bge-reranker-v2-m3"
|
||||
def server():
|
||||
args = [
|
||||
"--enforce-eager",
|
||||
# Will be used on tests to compare prompt input length
|
||||
"--max-model-len",
|
||||
"100"
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@ -20,8 +23,7 @@ def server():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
@ -45,8 +47,7 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str):
|
||||
text_1 = [
|
||||
"What is the capital of the United States?",
|
||||
"What is the capital of France?"
|
||||
@ -73,8 +74,7 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = "The capital of France is Paris."
|
||||
|
||||
@ -91,3 +91,36 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 1
|
||||
assert score.data[0].score >= 0.9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
text_1 = "What is the capital of France?" * 20
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
assert score_response.status_code == 400
|
||||
# Assert just a small fragments of the response
|
||||
assert "Please reduce the length of the input." in \
|
||||
score_response.text
|
||||
|
||||
# Test truncation
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
"truncate_prompt_tokens": 101
|
||||
})
|
||||
assert score_response.status_code == 400
|
||||
assert "Please, select a smaller truncation size." in \
|
||||
score_response.text
|
||||
|
||||
@ -754,6 +754,7 @@ def test_resolve_content_format_hf_defined(model, expected_format):
|
||||
("template_chatglm.jinja", "string"),
|
||||
("template_chatglm2.jinja", "string"),
|
||||
("template_chatml.jinja", "string"),
|
||||
("template_deepseek_vl2.jinja", "string"),
|
||||
("template_falcon_180b.jinja", "string"),
|
||||
("template_falcon.jinja", "string"),
|
||||
("template_inkbot.jinja", "string"),
|
||||
|
||||
@ -1,90 +0,0 @@
|
||||
{
|
||||
"model_type": "llama",
|
||||
"kv_cache": {
|
||||
"dtype": "float8_e4m3fn",
|
||||
"scaling_factor": {
|
||||
"0": {
|
||||
"0": 0.0230364128947258,
|
||||
"1": 0.01979283057153225,
|
||||
"2": 0.0241350457072258,
|
||||
"3": 0.0308314748108387,
|
||||
"4": 0.0430733822286129,
|
||||
"5": 0.0370396226644516,
|
||||
"6": 0.0306222103536129,
|
||||
"7": 0.0357491634786129,
|
||||
"8": 0.0358189195394516,
|
||||
"9": 0.0443289652466774,
|
||||
"10": 0.0433175228536129,
|
||||
"11": 0.0416782945394516,
|
||||
"12": 0.0366908498108387,
|
||||
"13": 0.0432477705180645,
|
||||
"14": 0.0410505048930645,
|
||||
"15": 0.0457589291036129,
|
||||
"16": 0.0418526791036129,
|
||||
"17": 0.0432477705180645,
|
||||
"18": 0.0469447560608387,
|
||||
"19": 0.0514787957072258,
|
||||
"20": 0.0541294664144516,
|
||||
"21": 0.0587681382894516,
|
||||
"22": 0.0625,
|
||||
"23": 0.0585588738322258,
|
||||
"24": 0.0600237175822258,
|
||||
"25": 0.0588030144572258,
|
||||
"26": 0.0531180277466774,
|
||||
"27": 0.06396484375,
|
||||
"28": 0.0603027381002903,
|
||||
"29": 0.0582101047039032,
|
||||
"30": 0.0625348836183548,
|
||||
"31": 0.0585588738322258,
|
||||
"32": 0.0582798570394516,
|
||||
"33": 0.0575125589966774,
|
||||
"34": 0.0590820349752903,
|
||||
"35": 0.0614188089966774,
|
||||
"36": 0.0631975457072258,
|
||||
"37": 0.0615931935608387,
|
||||
"38": 0.0601283498108387,
|
||||
"39": 0.0571986623108387,
|
||||
"40": 0.0670340433716774,
|
||||
"41": 0.0523507259786129,
|
||||
"42": 0.0547223798930645,
|
||||
"43": 0.0631975457072258,
|
||||
"44": 0.0663713738322258,
|
||||
"45": 0.0603376142680645,
|
||||
"46": 0.0652204304933548,
|
||||
"47": 0.0734514519572258,
|
||||
"48": 0.0693708211183548,
|
||||
"49": 0.0725446492433548,
|
||||
"50": 0.0627790242433548,
|
||||
"51": 0.0691266804933548,
|
||||
"52": 0.0688825398683548,
|
||||
"53": 0.068429134786129,
|
||||
"54": 0.0605119988322258,
|
||||
"55": 0.0799386203289032,
|
||||
"56": 0.0853097140789032,
|
||||
"57": 0.0661969929933548,
|
||||
"58": 0.0689871683716774,
|
||||
"59": 0.0724051371216774,
|
||||
"60": 0.0541643425822258,
|
||||
"61": 0.0626743882894516,
|
||||
"62": 0.0628487765789032,
|
||||
"63": 0.0607212632894516,
|
||||
"64": 0.0589076466858387,
|
||||
"65": 0.0451660193502903,
|
||||
"66": 0.0453055277466774,
|
||||
"67": 0.0414341539144516,
|
||||
"68": 0.0385044664144516,
|
||||
"69": 0.0414341539144516,
|
||||
"70": 0.0466308631002903,
|
||||
"71": 0.0399693101644516,
|
||||
"72": 0.0437011756002903,
|
||||
"73": 0.0434221550822258,
|
||||
"74": 0.0428989976644516,
|
||||
"75": 0.0401785746216774,
|
||||
"76": 0.0431082621216774,
|
||||
"77": 0.0484444759786129,
|
||||
"78": 0.0417829267680645,
|
||||
"79": 0.0418178029358387
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,42 +0,0 @@
|
||||
{
|
||||
"model_type": "llama",
|
||||
"kv_cache": {
|
||||
"dtype": "float8_e4m3fn",
|
||||
"scaling_factor": {
|
||||
"0": {
|
||||
"0": 0.0152239128947258,
|
||||
"1": 0.0188860222697258,
|
||||
"2": 0.0354178324341774,
|
||||
"3": 0.0376674123108387,
|
||||
"4": 0.0418526791036129,
|
||||
"5": 0.0433175228536129,
|
||||
"6": 0.0397600457072258,
|
||||
"7": 0.0424455925822258,
|
||||
"8": 0.0415387861430645,
|
||||
"9": 0.0408412404358387,
|
||||
"10": 0.0395856611430645,
|
||||
"11": 0.0377371683716774,
|
||||
"12": 0.0400739423930645,
|
||||
"13": 0.040771484375,
|
||||
"14": 0.0393415205180645,
|
||||
"15": 0.0369001142680645,
|
||||
"16": 0.03857421875,
|
||||
"17": 0.0387486070394516,
|
||||
"18": 0.0403180830180645,
|
||||
"19": 0.0396205373108387,
|
||||
"20": 0.0375627800822258,
|
||||
"21": 0.0407366082072258,
|
||||
"22": 0.0432477705180645,
|
||||
"23": 0.0377022884786129,
|
||||
"24": 0.0399693101644516,
|
||||
"25": 0.0374581478536129,
|
||||
"26": 0.0413295216858387,
|
||||
"27": 0.0442243330180645,
|
||||
"28": 0.0424804724752903,
|
||||
"29": 0.0456891767680645,
|
||||
"30": 0.0409109964966774,
|
||||
"31": 0.0482352152466774
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user