mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 20:17:07 +08:00
Merge branch 'main' into woosuk/fa3-swa-cudagraph
This commit is contained in:
commit
06fba5410c
@ -49,6 +49,7 @@ best_throughput=0
|
||||
best_max_num_seqs=0
|
||||
best_num_batched_tokens=0
|
||||
best_goodput=0
|
||||
best_request_rate=0
|
||||
|
||||
start_server() {
|
||||
local gpu_memory_utilization=$1
|
||||
@ -57,18 +58,35 @@ start_server() {
|
||||
local vllm_log=$4
|
||||
local profile_dir=$5
|
||||
|
||||
pkill -f vllm
|
||||
pkill -if vllm
|
||||
|
||||
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir vllm serve $MODEL \
|
||||
--port 8004 \
|
||||
--gpu-memory-utilization $gpu_memory_utilization \
|
||||
--max-num-seqs $max_num_seqs \
|
||||
--max-num-batched-tokens $max_num_batched_tokens \
|
||||
--tensor-parallel-size $TP \
|
||||
--enable-prefix-caching \
|
||||
--load-format dummy \
|
||||
--download-dir "$DOWNLOAD_DIR" \
|
||||
--max-model-len $MAX_MODEL_LEN > "$vllm_log" 2>&1 &
|
||||
# Define the common arguments as a bash array.
|
||||
# Each argument and its value are separate elements.
|
||||
local common_args_array=(
|
||||
"$MODEL"
|
||||
"--disable-log-requests"
|
||||
"--port" "8004"
|
||||
"--gpu-memory-utilization" "$gpu_memory_utilization"
|
||||
"--max-num-seqs" "$max_num_seqs"
|
||||
"--max-num-batched-tokens" "$max_num_batched_tokens"
|
||||
"--tensor-parallel-size" "$TP"
|
||||
"--enable-prefix-caching"
|
||||
"--load-format" "dummy"
|
||||
"--download-dir" "$DOWNLOAD_DIR"
|
||||
"--max-model-len" "$MAX_MODEL_LEN"
|
||||
)
|
||||
|
||||
# Use the array expansion "${common_args_array[@]}"
|
||||
# This correctly passes each element as a separate argument.
|
||||
if [[ -n "$profile_dir" ]]; then
|
||||
# Start server with profiling enabled
|
||||
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
|
||||
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
|
||||
else
|
||||
# Start server without profiling
|
||||
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 \
|
||||
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
|
||||
fi
|
||||
|
||||
# wait for 10 minutes...
|
||||
server_started=0
|
||||
@ -82,6 +100,7 @@ start_server() {
|
||||
sleep 10
|
||||
fi
|
||||
done
|
||||
|
||||
if (( ! server_started )); then
|
||||
echo "server did not start within 10 minutes. Please check server log at $vllm_log".
|
||||
return 1
|
||||
@ -90,37 +109,20 @@ start_server() {
|
||||
fi
|
||||
}
|
||||
|
||||
update_best_profile() {
|
||||
local profile_dir=$1
|
||||
local profile_index=$2
|
||||
sorted_paths=($(find "$profile_dir" -maxdepth 1 -not -path "$profile_dir" | sort))
|
||||
selected_profile_file=
|
||||
if [[ "$SYSTEM" == "TPU" ]]; then
|
||||
selected_profile_file="${sorted_paths[$profile_index]}/*.xplane.pb"
|
||||
fi
|
||||
if [[ "$SYSTEM" == "GPU" ]]; then
|
||||
selected_profile_file="${sorted_paths[$profile_index]}"
|
||||
fi
|
||||
rm -f $PROFILE_PATH/*
|
||||
cp $selected_profile_file $PROFILE_PATH
|
||||
}
|
||||
|
||||
run_benchmark() {
|
||||
local max_num_seqs=$1
|
||||
local max_num_batched_tokens=$2
|
||||
local gpu_memory_utilization=$3
|
||||
echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
|
||||
local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt"
|
||||
local profile_dir="$LOG_FOLDER/profile_${max_num_seqs}_${max_num_batched_tokens}"
|
||||
echo "vllm_log: $vllm_log"
|
||||
echo
|
||||
rm -f $vllm_log
|
||||
mkdir -p $profile_dir
|
||||
pkill -f vllm
|
||||
local profile_index=0
|
||||
pkill -if vllm
|
||||
|
||||
echo "starting server..."
|
||||
start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log $profile_dir
|
||||
# Call start_server without a profile_dir to avoid profiling overhead
|
||||
start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log ""
|
||||
result=$?
|
||||
if [[ "$result" -eq 1 ]]; then
|
||||
echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
|
||||
@ -134,7 +136,8 @@ run_benchmark() {
|
||||
# get a basic qps by using request-rate inf
|
||||
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt"
|
||||
prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 ))
|
||||
adjusted_input_len=$(( INPUT_LEN - prefix_len ))
|
||||
adjusted_input_len=$(( INPUT_LEN - prefix_len ))
|
||||
# --profile flag is removed from this call
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model $MODEL \
|
||||
@ -148,8 +151,7 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len ))
|
||||
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||
--num-prompts 1000 \
|
||||
--random-prefix-len $prefix_len \
|
||||
--port 8004 \
|
||||
--profile &> "$bm_log"
|
||||
--port 8004 &> "$bm_log"
|
||||
throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
|
||||
goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
@ -163,7 +165,6 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len ))
|
||||
# start from request-rate as int(throughput) + 1
|
||||
request_rate=$((${throughput%.*} + 1))
|
||||
while ((request_rate > 0)); do
|
||||
profile_index=$((profile_index+1))
|
||||
# clear prefix cache
|
||||
curl -X POST http://0.0.0.0:8004/reset_prefix_cache
|
||||
sleep 5
|
||||
@ -201,12 +202,7 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len ))
|
||||
best_max_num_seqs=$max_num_seqs
|
||||
best_num_batched_tokens=$max_num_batched_tokens
|
||||
best_goodput=$goodput
|
||||
if [[ "$SYSTEM" == "TPU" ]]; then
|
||||
update_best_profile "$profile_dir/plugins/profile" $profile_index
|
||||
fi
|
||||
if [[ "$SYSTEM" == "GPU" ]]; then
|
||||
update_best_profile "$profile_dir" $profile_index
|
||||
fi
|
||||
best_request_rate=$request_rate
|
||||
fi
|
||||
else
|
||||
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}"
|
||||
@ -215,7 +211,7 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len ))
|
||||
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"
|
||||
|
||||
pkill vllm
|
||||
pkill -if vllm
|
||||
sleep 10
|
||||
printf '=%.0s' $(seq 1 20)
|
||||
return 0
|
||||
@ -228,7 +224,8 @@ read -r -a num_batched_tokens_list <<< "$NUM_BATCHED_TOKENS_LIST"
|
||||
gpu_memory_utilization=0.98
|
||||
find_gpu_memory_utilization=0
|
||||
while (( $(echo "$gpu_memory_utilization >= 0.9" | bc -l) )); do
|
||||
start_server $gpu_memory_utilization "${num_seqs_list[-1]}" "${num_batched_tokens_list[-1]}" "$LOG_FOLDER/vllm_log_gpu_memory_utilization_$gpu_memory_utilization.log"
|
||||
# Pass empty string for profile_dir argument
|
||||
start_server $gpu_memory_utilization "${num_seqs_list[-1]}" "${num_batched_tokens_list[-1]}" "$LOG_FOLDER/vllm_log_gpu_memory_utilization_$gpu_memory_utilization.log" ""
|
||||
result=$?
|
||||
if [[ "$result" -eq 0 ]]; then
|
||||
find_gpu_memory_utilization=1
|
||||
@ -251,5 +248,45 @@ for num_seqs in "${num_seqs_list[@]}"; do
|
||||
done
|
||||
done
|
||||
echo "finish permutations"
|
||||
|
||||
# =================================================================================
|
||||
# FINAL PROFILING RUN FOR THE BEST CONFIGURATION
|
||||
# =================================================================================
|
||||
if (( $(echo "$best_throughput > 0" | bc -l) )); then
|
||||
echo
|
||||
echo "Benchmark tuning finished. Now running profiling on the best configuration found..."
|
||||
echo "Best config: max_num_seqs: $best_max_num_seqs, max_num_batched_tokens: $best_num_batched_tokens, throughput: $best_throughput"
|
||||
echo
|
||||
|
||||
vllm_log="$LOG_FOLDER/vllm_log_BEST_PROFILE.txt"
|
||||
bm_log="$LOG_FOLDER/bm_log_BEST_PROFILE.txt"
|
||||
|
||||
# Start server with the best params and profiling ENABLED
|
||||
echo "Starting server for profiling..."
|
||||
start_server $gpu_memory_utilization $best_max_num_seqs $best_num_batched_tokens "$vllm_log" "$PROFILE_PATH"
|
||||
|
||||
# Run benchmark with the best params and the --profile flag
|
||||
echo "Running benchmark with profiling..."
|
||||
prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 ))
|
||||
adjusted_input_len=$(( INPUT_LEN - prefix_len ))
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model $MODEL \
|
||||
--dataset-name random \
|
||||
--random-input-len $adjusted_input_len \
|
||||
--random-output-len $OUTPUT_LEN \
|
||||
--ignore-eos \
|
||||
--disable-tqdm \
|
||||
--request-rate $best_request_rate \
|
||||
--percentile-metrics ttft,tpot,itl,e2el \
|
||||
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||
--num-prompts 100 \
|
||||
--random-prefix-len $prefix_len \
|
||||
--port 8004 \
|
||||
--profile &> "$bm_log"
|
||||
else
|
||||
echo "No configuration met the latency requirements. Skipping final profiling run."
|
||||
fi
|
||||
pkill -if vllm
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT"
|
||||
|
||||
@ -120,7 +120,7 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/clas
|
||||
### `LLM.score`
|
||||
|
||||
The [score][vllm.LLM.score] method outputs similarity scores between sentence pairs.
|
||||
It is designed for embedding models and cross encoder models. Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems.
|
||||
It is designed for embedding models and cross-encoder models. Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems.
|
||||
|
||||
!!! note
|
||||
vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
|
||||
|
||||
@ -311,6 +311,8 @@ See [this page](generative_models.md) for more information on how to use generat
|
||||
|
||||
#### Text Generation
|
||||
|
||||
These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate) API. Chat/Instruct models additionally support the [`LLM.chat`](./generative_models.md#llmchat) API.
|
||||
|
||||
<style>
|
||||
th {
|
||||
white-space: nowrap;
|
||||
@ -328,7 +330,7 @@ th {
|
||||
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
|
||||
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
|
||||
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -348,8 +350,8 @@ th {
|
||||
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
|
||||
| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ |
|
||||
| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -419,7 +421,9 @@ See [this page](./pooling_models.md) for more information on how to use pooling
|
||||
Since some model architectures support both generative and pooling tasks,
|
||||
you should explicitly specify `--runner pooling` to ensure that the model is used in pooling mode instead of generative mode.
|
||||
|
||||
#### Text Embedding
|
||||
#### Embedding
|
||||
|
||||
These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
@ -457,28 +461,10 @@ If your model is not in the above list, we will try to automatically convert the
|
||||
[as_embedding_model][vllm.model_executor.models.adapters.as_embedding_model]. By default, the embeddings
|
||||
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
|
||||
|
||||
#### Reward Modeling
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
|
||||
|
||||
<sup>C</sup> Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion))
|
||||
\* Feature support is the same as that of the original model.
|
||||
|
||||
If your model is not in the above list, we will try to automatically convert the model using
|
||||
[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly.
|
||||
|
||||
!!! important
|
||||
For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
|
||||
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
|
||||
|
||||
#### Classification
|
||||
|
||||
These models primarily support the [`LLM.classify`](./pooling_models.md#llmclassify) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
|
||||
@ -491,7 +477,10 @@ If your model is not in the above list, we will try to automatically convert the
|
||||
If your model is not in the above list, we will try to automatically convert the model using
|
||||
[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
|
||||
|
||||
#### Sentence Pair Scoring
|
||||
#### Cross-encoder / Reranker
|
||||
|
||||
Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
|
||||
These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
@ -501,6 +490,7 @@ If your model is not in the above list, we will try to automatically convert the
|
||||
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | |
|
||||
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
|
||||
|
||||
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
|
||||
\* Feature support is the same as that of the original model.
|
||||
@ -526,6 +516,28 @@ If your model is not in the above list, we will try to automatically convert the
|
||||
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
|
||||
```
|
||||
|
||||
#### Reward Modeling
|
||||
|
||||
These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
|
||||
|
||||
<sup>C</sup> Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion))
|
||||
\* Feature support is the same as that of the original model.
|
||||
|
||||
If your model is not in the above list, we will try to automatically convert the model using
|
||||
[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly.
|
||||
|
||||
!!! important
|
||||
For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
|
||||
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
|
||||
|
||||
[](){ #supported-mm-models }
|
||||
|
||||
## List of Multimodal Language Models
|
||||
@ -579,6 +591,8 @@ See [this page](generative_models.md) for more information on how to use generat
|
||||
|
||||
#### Text Generation
|
||||
|
||||
These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate) API. Chat/Instruct models additionally support the [`LLM.chat`](./generative_models.md#llmchat) API.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | ✅︎ |
|
||||
@ -589,8 +603,8 @@ See [this page](generative_models.md) for more information on how to use generat
|
||||
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
|
||||
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
|
||||
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4MoeForCausalLM` | GLM-4.5 | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V-Air`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -720,11 +734,9 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
|
||||
|
||||
See [this page](./pooling_models.md) for more information on how to use pooling models.
|
||||
|
||||
!!! important
|
||||
Since some model architectures support both generative and pooling tasks,
|
||||
you should explicitly specify `--runner pooling` to ensure that the model is used in pooling mode instead of generative mode.
|
||||
#### Embedding
|
||||
|
||||
#### Text Embedding
|
||||
These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API.
|
||||
|
||||
!!! note
|
||||
To get the best results, you should use pooling models that are specifically trained as such.
|
||||
@ -742,7 +754,10 @@ The following table lists those that are tested in vLLM.
|
||||
|
||||
---
|
||||
|
||||
#### Scoring
|
||||
#### Cross-encoder / Reranker
|
||||
|
||||
Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
|
||||
These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
|
||||
|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------|
|
||||
|
||||
@ -221,7 +221,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
# GLM-4v
|
||||
def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "THUDM/glm-4v-9b"
|
||||
model_name = "zai-org/glm-4v-9b"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
@ -250,7 +250,7 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
# GLM-4.1V
|
||||
def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model_name = "THUDM/GLM-4.1V-9B-Thinking"
|
||||
model_name = "zai-org/GLM-4.1V-9B-Thinking"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
|
||||
@ -154,7 +154,7 @@ TEXT_GENERATION_MODELS = {
|
||||
"baichuan-inc/Baichuan-7B": PPTestSettings.fast(),
|
||||
"baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(),
|
||||
"bigscience/bloomz-1b1": PPTestSettings.fast(),
|
||||
"THUDM/chatglm3-6b": PPTestSettings.fast(),
|
||||
"zai-org/chatglm3-6b": PPTestSettings.fast(),
|
||||
"CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"),
|
||||
"databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"),
|
||||
"Deci/DeciLM-7B-instruct": PPTestSettings.fast(),
|
||||
@ -224,7 +224,7 @@ MULTIMODAL_MODELS = {
|
||||
"Salesforce/blip2-opt-6.7b": PPTestSettings.fast(),
|
||||
"facebook/chameleon-7b": PPTestSettings.fast(),
|
||||
"adept/fuyu-8b": PPTestSettings.fast(),
|
||||
"THUDM/glm-4v-9b": PPTestSettings.fast(),
|
||||
"zai-org/glm-4v-9b": PPTestSettings.fast(),
|
||||
"OpenGVLab/InternVL2-1B": PPTestSettings.fast(),
|
||||
"llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
|
||||
|
||||
@ -14,7 +14,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
MODEL_PATH = "THUDM/chatglm3-6b"
|
||||
MODEL_PATH = "zai-org/chatglm3-6b"
|
||||
LORA_RANK = 64
|
||||
DEFAULT_MAX_LORAS = 4 * 3
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from vllm.lora.request import LoRARequest
|
||||
|
||||
from ..utils import create_new_process_for_each_test, multi_gpu_test
|
||||
|
||||
MODEL_PATH = "THUDM/chatglm3-6b"
|
||||
MODEL_PATH = "zai-org/chatglm3-6b"
|
||||
|
||||
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ AITER_MODEL_LIST = [
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
|
||||
),
|
||||
pytest.param(
|
||||
"THUDM/chatglm3-6b", # chatglm (text-only)
|
||||
"zai-org/chatglm3-6b", # chatglm (text-only)
|
||||
),
|
||||
pytest.param(
|
||||
"meta-llama/Llama-3.2-1B-Instruct", # llama
|
||||
|
||||
@ -355,7 +355,7 @@ VLM_TEST_SETTINGS = {
|
||||
num_logprobs=10,
|
||||
),
|
||||
"glm4v": VLMTestInfo(
|
||||
models=["THUDM/glm-4v-9b"],
|
||||
models=["zai-org/glm-4v-9b"],
|
||||
test_type=VLMTestType.IMAGE,
|
||||
prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501
|
||||
single_image_prompts=IMAGE_ASSETS.prompts({
|
||||
@ -374,7 +374,7 @@ VLM_TEST_SETTINGS = {
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"glm4_1v": VLMTestInfo(
|
||||
models=["THUDM/GLM-4.1V-9B-Thinking"],
|
||||
models=["zai-org/GLM-4.1V-9B-Thinking"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", # noqa: E501
|
||||
@ -388,7 +388,7 @@ VLM_TEST_SETTINGS = {
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"glm4_1v-video": VLMTestInfo(
|
||||
models=["THUDM/GLM-4.1V-9B-Thinking"],
|
||||
models=["zai-org/GLM-4.1V-9B-Thinking"],
|
||||
# GLM4.1V require include video metadata for input
|
||||
test_type=VLMTestType.CUSTOM_INPUTS,
|
||||
max_model_len=4096,
|
||||
|
||||
@ -271,8 +271,8 @@ def _test_processing_correctness_one(
|
||||
"microsoft/Florence-2-base",
|
||||
"adept/fuyu-8b",
|
||||
"google/gemma-3-4b-it",
|
||||
"THUDM/glm-4v-9b",
|
||||
"THUDM/GLM-4.1V-9B-Thinking",
|
||||
"zai-org/glm-4v-9b",
|
||||
"zai-org/GLM-4.1V-9B-Thinking",
|
||||
"ibm-granite/granite-speech-3.3-2b",
|
||||
"h2oai/h2ovl-mississippi-800m",
|
||||
"internlm/Intern-S1",
|
||||
|
||||
@ -9,7 +9,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from ...utils import build_model_context
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["THUDM/GLM-4.1V-9B-Thinking"])
|
||||
@pytest.mark.parametrize("model_id", ["zai-org/GLM-4.1V-9B-Thinking"])
|
||||
@pytest.mark.parametrize("expected_toks_per_frame", [299])
|
||||
@pytest.mark.parametrize("num_frames", [32, 128])
|
||||
@pytest.mark.parametrize("fps, expected_grid_t", [(1, 5), (2, 10)])
|
||||
|
||||
@ -139,8 +139,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True),
|
||||
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
|
||||
trust_remote_code=True),
|
||||
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base",
|
||||
is_available_online=False),
|
||||
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"),
|
||||
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
|
||||
trust_remote_code=True),
|
||||
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
|
||||
@ -153,7 +152,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501
|
||||
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
|
||||
{"1b": "bigscience/bloomz-1b1"}),
|
||||
"ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b",
|
||||
"ChatGLMModel": _HfExamplesInfo("zai-org/chatglm3-6b",
|
||||
trust_remote_code=True,
|
||||
max_transformers_version="4.48"),
|
||||
"ChatGLMForConditionalGeneration": _HfExamplesInfo("thu-coai/ShieldLM-6B-chatglm3", # noqa: E501
|
||||
@ -187,8 +186,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
|
||||
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
|
||||
min_transformers_version="4.53"),
|
||||
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
|
||||
"Glm4ForCausalLM": _HfExamplesInfo("THUDM/GLM-4-9B-0414"),
|
||||
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
|
||||
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
|
||||
"Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5",
|
||||
min_transformers_version="4.54"), # noqa: E501
|
||||
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2",
|
||||
@ -380,10 +379,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
|
||||
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
|
||||
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501
|
||||
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
|
||||
"GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b",
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
|
||||
"Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking"), # noqa: E501
|
||||
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501
|
||||
"Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V-Air",
|
||||
is_available_online=False), # noqa: E501
|
||||
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
|
||||
|
||||
@ -10,7 +10,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
||||
get_cached_tokenizer)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["gpt2", "THUDM/chatglm3-6b"])
|
||||
@pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])
|
||||
def test_cached_tokenizer(model_id: str):
|
||||
reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
|
||||
trust_remote_code=True)
|
||||
|
||||
@ -17,7 +17,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
BACKENDS_TO_TEST = [
|
||||
_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
|
||||
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1
|
||||
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN
|
||||
]
|
||||
|
||||
# Remove flashinfer from the list if it's not available
|
||||
|
||||
@ -109,11 +109,11 @@ def create_common_attn_metadata(
|
||||
|
||||
def get_attention_backend(backend_name: _Backend):
|
||||
"""Set up attention backend classes for testing.
|
||||
|
||||
|
||||
Args:
|
||||
backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
|
||||
vllm_config: VllmConfig instance
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (backend_builder_class, backend_impl_class)
|
||||
"""
|
||||
@ -126,6 +126,8 @@ def get_attention_backend(backend_name: _Backend):
|
||||
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
|
||||
_Backend.TRITON_ATTN_VLLM_V1:
|
||||
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
|
||||
_Backend.TREE_ATTN:
|
||||
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
|
||||
}
|
||||
|
||||
if backend_name not in backend_map:
|
||||
|
||||
@ -10,7 +10,7 @@ from vllm.assets.image import ImageAsset
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.multimodal.utils import encode_image_base64
|
||||
|
||||
MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w4a16"
|
||||
MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
|
||||
|
||||
SAMPLING_PARAMS = SamplingParams(temperature=0.0, top_k=1, max_tokens=128)
|
||||
|
||||
|
||||
@ -429,6 +429,33 @@ def test_zero_logprobs(vllm_model, example_prompts,
|
||||
assert len(prompt_token_ids) == len(prompt_logprobs)
|
||||
|
||||
|
||||
def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
    """Requesting ``logprobs=-1`` should yield one logprob per vocab entry.

    The engine is created with ``max_logprobs=-1`` and every generated
    position of every prompt is checked to carry exactly ``vocab_size``
    logprob entries.

    Args:
      example_prompts: list of example prompts (test fixture)
      monkeypatch: pytest fixture, used to set ``VLLM_USE_V1=1`` for the
        duration of the test
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        runner = VllmRunner(
            "facebook/opt-125m",
            max_logprobs=-1,
            enable_prefix_caching=False,
            # 2 other llms alive during whole session
            gpu_memory_utilization=0.15,
            max_model_len=256,
        )

        # logprobs=-1 asks for the full-vocabulary distribution.
        all_logprobs_params = SamplingParams(max_tokens=5, logprobs=-1)
        request_outputs = runner.llm.generate(
            example_prompts, sampling_params=all_logprobs_params)

        vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()

        for request_output in request_outputs:
            logprobs = request_output.outputs[0].logprobs
            assert logprobs is not None
            for logprob in logprobs:
                assert len(logprob) == vocab_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"logprobs_mode",
|
||||
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
|
||||
|
||||
@ -202,7 +202,9 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
|
||||
def test_propose(num_speculative_tokens):
|
||||
@pytest.mark.parametrize("backend",
|
||||
[_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
|
||||
def test_propose(num_speculative_tokens, backend):
|
||||
# Use GPU device
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
@ -301,8 +303,7 @@ def test_propose(num_speculative_tokens):
|
||||
device=device)
|
||||
sampling_metadata = mock.MagicMock()
|
||||
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(
|
||||
_Backend.FLASH_ATTN_VLLM_V1)
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(backend)
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
|
||||
299
tests/v1/spec_decode/test_tree_attention.py
Normal file
299
tests/v1/spec_decode/test_tree_attention.py
Normal file
@ -0,0 +1,299 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend)
|
||||
from vllm.config import ParallelConfig, SpeculativeConfig
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
|
||||
|
||||
class MockAttentionLayer(torch.nn.Module):
    """Minimal stand-in for an attention layer used by the tests below.

    It only carries unit ``_q_scale`` / ``_k_scale`` / ``_v_scale`` float32
    scalars (presumably read by the backend implementation's ``forward`` --
    confirm against the backend) and an identity ``forward``.

    Args:
        device: Device on which the scale tensors are allocated. Defaults to
            ``"cuda"``, matching the previous hard-coded behaviour.
    """

    def __init__(self, device="cuda"):
        super().__init__()
        # Allocate the scales per instance rather than in the class body:
        # class-level ``torch.tensor(..., device="cuda")`` runs at import
        # time, which breaks test collection on hosts without CUDA.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
        self._k_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    def forward(self, x):
        """Identity: return the input unchanged."""
        return x
|
||||
|
||||
|
||||
def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: _Backend,
    spec_token_tree: Optional[str] = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    """Run a single attention forward pass through the given backend.

    Builds the common and backend-specific attention metadata for a batch in
    which every request has the same query length and the same total
    sequence length, instantiates the backend implementation, and calls its
    ``forward``.

    Args:
        q: Query tensor, unpacked as
            ``(batch_size, q_len, num_heads, dim_per_head)``.
        k: Key tensor; ``k.shape[-2]`` is taken as the number of KV heads.
        v: Value tensor, flattened the same way as ``k``.
        kv_cache: Paged KV cache. A clone is handed to the backend, so the
            caller's tensor object itself is not passed to ``forward``.
        block_table: Block table forwarded into the common metadata.
        slot_mapping: Slot mapping forwarded into the common metadata.
        seqlen_k: Sequence length assigned to every request in the batch.
        backend: Attention backend to exercise.
        spec_token_tree: If not ``None``, a ``SpeculativeConfig`` (method
            ``"eagle"``) with this token tree is attached to the vLLM config.
        num_spec_tokens: ``num_speculative_tokens`` for that config.

    Returns:
        Whatever the backend implementation's ``forward`` returns.
    """
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]
    # Initialize the query and KV sequence lengths.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32)
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size, ),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_query_len = q_len
    # Equal to query_start_loc[-1], but kept as a plain Python int instead of
    # a 0-d device tensor so downstream code receives an integer count.
    num_actual_tokens = batch_size * q_len

    softmax_scale = q.shape[-1]**(-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = get_attention_backend(backend)
    # Every entry of seq_lens is seqlen_k, so that is the maximum; pass the
    # int directly rather than max() over a device tensor.
    vllm_config = create_vllm_config(model_name=model_name,
                                     max_model_len=seqlen_k)
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree)
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.cpu(),
        num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Run forward pass and return output.
    # Collapse the (batch, token) dimensions into one flat token dimension.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )
|
||||
|
||||
|
||||
def test_tree_attn_correctness() -> None:
    """Compare tree attention against per-branch flash attention.

    For each configuration, attention over the whole speculative token tree
    is computed once with the ``TREE_ATTN`` backend. Then, for every row of
    the tree mask, the tokens on that root-to-node branch are run as a plain
    chain through ``FLASH_ATTN_VLLM_V1``, and the two outputs for those
    tokens must agree within ``atol=7.81e-3``.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    device = "cuda"

    def _mask(rows):
        # Row i marks (with 1) the tree positions that query i attends to.
        return torch.tensor(rows, device=device, dtype=torch.int32)

    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]":
        _mask([
            [1, 0, 0, 0],
            [1, 1, 0, 0],
            [1, 1, 1, 0],
            [1, 1, 1, 1],
        ]),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]":
        _mask([
            [1, 0, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0],
            [1, 0, 1, 0, 0, 0, 0],
            [1, 1, 0, 1, 0, 0, 0],
            [1, 1, 0, 0, 1, 0, 0],
            [1, 0, 1, 0, 0, 1, 0],
            [1, 0, 1, 0, 0, 0, 1],
        ]),
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 128
    max_sequence_length = 8192
    randomize_blocks = True

    # Paged-KV geometry does not depend on the loop variables.
    assert max_sequence_length % block_size == 0
    max_blocks_per_batch = max_sequence_length // block_size

    def _rand_bf16(*shape):
        return torch.randn(shape, device=device, dtype=torch.bfloat16)

    for batch_size in [1, 16, 32]:

        def _positions(start, length):
            # (batch_size, length) grid of absolute token positions.
            offsets = torch.arange(length, device=device, dtype=torch.int64)
            return start + offsets.repeat(batch_size, 1)

        for num_heads in [2, 4]:
            # Assert that the number of heads is divisible
            # by the number of KV heads.
            assert num_heads % num_kv_heads == 0

            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Initialize q, k, and v (in this order, to keep the
                    # random stream identical).
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q = _rand_bf16(batch_size, tree_size_q, num_heads,
                                   dim_per_head)
                    k = _rand_bf16(batch_size, tree_size_q, num_kv_heads,
                                   dim_per_head)
                    v = _rand_bf16(batch_size, tree_size_q, num_kv_heads,
                                   dim_per_head)

                    # Setup the block table and KV cache for paged KV.
                    kv_cache = _rand_bf16(
                        2,
                        batch_size * max_blocks_per_batch,
                        block_size,
                        num_kv_heads,
                        dim_per_head,
                    )
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k /
                                                           block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        batch_size * num_alloc_blocks_per_batch,
                        device=device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        shuffle = torch.randperm(block_ids.numel())
                        block_ids = block_ids[shuffle]
                    block_table[:, :num_alloc_blocks_per_batch] = (
                        block_ids.view(-1, num_alloc_blocks_per_batch))

                    # Setup the slot mapping for the input KVs.
                    tree_slot_mapping = _gen_slot_mapping(
                        _positions(sequence_position, tree_size_q),
                        block_table, block_size)

                    # Compute attention for the tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=_Backend.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)

                    # Verify that the chain attention output for each
                    # branch of the tree (computed using FA3) matches
                    # the tree attention output.
                    for q_index in range(tree_size_q):
                        # Get the q, k, and v for the branch.
                        branch_indices = torch.nonzero(
                            tree_attn_mask[q_index, :], as_tuple=True)[0]
                        q_len = branch_indices.shape[0]

                        # Setup slot mapping for the branch.
                        branch_slot_mapping = _gen_slot_mapping(
                            _positions(sequence_position, q_len),
                            block_table, block_size)

                        # Compute flash attention for the branch.
                        flash_attn_output = forward_attention(
                            q=q[:, branch_indices],
                            k=k[:, branch_indices],
                            v=v[:, branch_indices],
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=_Backend.FLASH_ATTN_VLLM_V1,
                        ).view(batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        outputs_match = torch.allclose(
                            tree_attn_output[:, branch_indices],
                            flash_attn_output,
                            atol=7.81e-3,
                        )
                        assert outputs_match, (
                            f"outputs are not close for "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}.")
|
||||
|
||||
|
||||
def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor,
|
||||
block_size: int):
|
||||
block_indices = positions // block_size
|
||||
blocks = block_table.gather(dim=1, index=block_indices)
|
||||
return (blocks * block_size + positions % block_size).view(-1)
|
||||
@ -55,6 +55,7 @@ def kernel_unified_attention_2d(
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
@ -66,10 +67,12 @@ def kernel_unified_attention_2d(
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
qq_bias_stride_0: tl.int64, # int
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
USE_QQ_BIAS: tl.constexpr, # bool
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
@ -144,6 +147,11 @@ def kernel_unified_attention_2d(
|
||||
mask=query_mask_1,
|
||||
other=0.0)
|
||||
|
||||
# query-query attention bias
|
||||
if USE_QQ_BIAS:
|
||||
qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
|
||||
) # shape: [BLOCK_M]
|
||||
|
||||
# compute the length of the longest sequence prefix spanned by any
|
||||
# query token in the current q_block (q_block_local_idx)
|
||||
max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
|
||||
@ -223,6 +231,18 @@ def kernel_unified_attention_2d(
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
if USE_QQ_BIAS:
|
||||
# compute key positions relative to query section
|
||||
key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE]
|
||||
# load bias only for keys that correspond to queries
|
||||
is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
|
||||
qq_bias = tl.load(
|
||||
qq_bias_row_ptrs + key_rel_pos[None, :],
|
||||
mask=is_query_key[None, :], # avoid OOB for context keys
|
||||
other=0.0,
|
||||
)
|
||||
S += qq_bias
|
||||
|
||||
# compute running maximum
|
||||
# m_j : (BLOCK_M,)
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
@ -275,6 +295,7 @@ def kernel_unified_attention_3d(
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
@ -284,10 +305,12 @@ def kernel_unified_attention_3d(
|
||||
block_table_stride: tl.int64, # int
|
||||
query_stride_0: tl.int64, # int
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
qq_bias_stride_0: tl.int64, # int
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
USE_QQ_BIAS: tl.constexpr, # bool
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
@ -373,6 +396,11 @@ def kernel_unified_attention_3d(
|
||||
mask=query_mask_1,
|
||||
other=0.0)
|
||||
|
||||
# query-query attention bias
|
||||
if USE_QQ_BIAS:
|
||||
qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
|
||||
) # shape: [BLOCK_M]
|
||||
|
||||
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
|
||||
|
||||
# iterate through tiles within current segment
|
||||
@ -442,6 +470,18 @@ def kernel_unified_attention_3d(
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
if USE_QQ_BIAS:
|
||||
# compute key positions relative to query section
|
||||
key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE]
|
||||
# load bias only for keys that correspond to queries
|
||||
is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
|
||||
qq_bias = tl.load(
|
||||
qq_bias_row_ptrs + key_rel_pos[None, :],
|
||||
mask=is_query_key[None, :], # avoid OOB for context keys
|
||||
other=0.0,
|
||||
)
|
||||
S += qq_bias
|
||||
|
||||
# compute running maximum
|
||||
# m_j : (BLOCK_M,)
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
@ -586,6 +626,7 @@ def unified_attention(
|
||||
k_descale,
|
||||
v_descale,
|
||||
alibi_slopes=None,
|
||||
qq_bias=None,
|
||||
):
|
||||
assert causal, "Only causal attention is supported"
|
||||
assert q_descale is None, "Q scales not supported"
|
||||
@ -595,6 +636,7 @@ def unified_attention(
|
||||
"Block size must be at least 32 for fp8"
|
||||
|
||||
use_alibi_slopes = alibi_slopes is not None
|
||||
use_qq_bias = qq_bias is not None
|
||||
|
||||
block_size = v.shape[1]
|
||||
num_seqs = len(seqused_k)
|
||||
@ -630,6 +672,7 @@ def unified_attention(
|
||||
block_tables_ptr=block_table,
|
||||
seq_lens_ptr=seqused_k,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
qq_bias_ptr=qq_bias,
|
||||
scale=softmax_scale,
|
||||
k_scale=k_descale,
|
||||
v_scale=v_descale,
|
||||
@ -641,10 +684,12 @@ def unified_attention(
|
||||
query_stride_1=q.stride(1),
|
||||
output_stride_0=out.stride(0),
|
||||
output_stride_1=out.stride(1),
|
||||
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
|
||||
BLOCK_SIZE=block_size,
|
||||
HEAD_SIZE=head_size,
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||
USE_QQ_BIAS=use_qq_bias,
|
||||
USE_SOFTCAP=(softcap > 0),
|
||||
SLIDING_WINDOW=(1 + window_size[0]),
|
||||
stride_k_cache_0=k.stride(0),
|
||||
@ -699,6 +744,7 @@ def unified_attention(
|
||||
block_tables_ptr=block_table,
|
||||
seq_lens_ptr=seqused_k,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
qq_bias_ptr=qq_bias,
|
||||
scale=softmax_scale,
|
||||
k_scale=k_descale,
|
||||
v_scale=v_descale,
|
||||
@ -708,10 +754,12 @@ def unified_attention(
|
||||
block_table_stride=block_table.stride(0),
|
||||
query_stride_0=q.stride(0),
|
||||
query_stride_1=q.stride(1),
|
||||
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
|
||||
BLOCK_SIZE=block_size,
|
||||
HEAD_SIZE=head_size,
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||
USE_QQ_BIAS=use_qq_bias,
|
||||
USE_SOFTCAP=(softcap > 0),
|
||||
SLIDING_WINDOW=(1 + window_size[0]),
|
||||
stride_k_cache_0=k.stride(0),
|
||||
|
||||
@ -377,7 +377,8 @@ class ModelConfig:
|
||||
max_logprobs: int = 20
|
||||
"""Maximum number of log probabilities to return when `logprobs` is
|
||||
specified in `SamplingParams`. The default value comes from the default for the
|
||||
OpenAI Chat Completions API."""
|
||||
OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
|
||||
vocab_size) logprobs are allowed to be returned and it may cause OOM."""
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs"
|
||||
"""Indicates the content returned in the logprobs and prompt_logprobs.
|
||||
Supported mode:
|
||||
@ -1585,7 +1586,7 @@ class ModelConfig:
|
||||
"""
|
||||
This method attempts to retrieve the non-default values of the
|
||||
generation config for this model.
|
||||
|
||||
|
||||
The generation config can contain information about special tokens, as
|
||||
well as sampling parameters. Which is why this method exists separately
|
||||
to `get_diff_sampling_param`.
|
||||
@ -2066,7 +2067,7 @@ class ParallelConfig:
|
||||
and when data_parallel_size > 0. Enables running an AsyncLLM
|
||||
and API server on a "per-node" basis where vLLM load balances
|
||||
between local data parallel ranks, but an external LB balances
|
||||
between vLLM nodes/replicas. Set explicitly in conjunction with
|
||||
between vLLM nodes/replicas. Set explicitly in conjunction with
|
||||
--data-parallel-start-rank."""
|
||||
enable_expert_parallel: bool = False
|
||||
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
|
||||
@ -3049,6 +3050,19 @@ class SpeculativeConfig:
|
||||
f"num_speculative_tokens:{self.num_speculative_tokens}"
|
||||
f" must be divisible by {n_predict=}")
|
||||
|
||||
if self.speculative_token_tree is None:
|
||||
# Generate chain of tokens.
|
||||
self.speculative_token_tree = str([
|
||||
(i + 1) * (0, )
|
||||
for i in range(self.num_speculative_tokens)
|
||||
])
|
||||
else:
|
||||
# Sort the token tree breadth-first.
|
||||
tree_choices = ast.literal_eval(
|
||||
self.speculative_token_tree)
|
||||
self.speculative_token_tree = str(
|
||||
sorted(tree_choices, key=lambda t: (len(t), t)))
|
||||
|
||||
self.draft_tensor_parallel_size = \
|
||||
SpeculativeConfig._verify_and_get_draft_tp(
|
||||
self.target_parallel_config,
|
||||
|
||||
@ -1454,7 +1454,6 @@ class EngineArgs:
|
||||
"Please consider using other speculative decoding methods "
|
||||
"such as ngram, medusa, eagle, or deepseek_mtp.")
|
||||
|
||||
# No XFormers so far.
|
||||
V1_BACKENDS = [
|
||||
"FLASH_ATTN_VLLM_V1",
|
||||
"FLASH_ATTN",
|
||||
@ -1469,6 +1468,7 @@ class EngineArgs:
|
||||
"ROCM_AITER_MLA",
|
||||
"TORCH_SDPA_VLLM_V1",
|
||||
"FLEX_ATTENTION",
|
||||
"TREE_ATTN",
|
||||
]
|
||||
if (envs.is_set("VLLM_ATTENTION_BACKEND")
|
||||
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
|
||||
|
||||
@ -90,8 +90,17 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
logger.info("Using default chat sampling params from %s: %s",
|
||||
source, self.default_sampling_params)
|
||||
|
||||
# False by default.
|
||||
# If False (default), the "store" option is (silently) ignored and the
|
||||
# response is not stored. If True, the response is stored in memory.
|
||||
# NOTE(woosuk): This may not be intuitive for users, as the default
|
||||
# behavior in OpenAI's Responses API is to store the response, but
|
||||
# vLLM's default behavior is not.
|
||||
self.enable_store = envs.VLLM_ENABLE_RESPONSES_API_STORE
|
||||
if self.enable_store:
|
||||
logger.warning_once(
|
||||
"`VLLM_ENABLE_RESPONSES_API_STORE` is enabled. This may "
|
||||
"cause a memory leak since we never remove responses from "
|
||||
"the store.")
|
||||
# HACK(woosuk): This is a hack. We should use a better store.
|
||||
# FIXME: If enable_store=True, this may cause a memory leak since we
|
||||
# never remove responses from the store.
|
||||
@ -121,9 +130,25 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
# If store is not enabled, return an error.
|
||||
if request.store and not self.enable_store:
|
||||
return self._make_store_not_supported_error()
|
||||
if request.background:
|
||||
return self.create_error_response(
|
||||
err_type="invalid_request_error",
|
||||
message=(
|
||||
"This vLLM engine does not support `store=True` and "
|
||||
"therefore does not support the background mode. To "
|
||||
"enable these features, set the environment variable "
|
||||
"`VLLM_ENABLE_RESPONSES_API_STORE=1` when launching "
|
||||
"the vLLM server."),
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
# Disable the store option.
|
||||
# NOTE(woosuk): Although returning an error is possible, we opted
|
||||
# to implicitly disable store and process the request anyway, as
|
||||
# we assume most users do not intend to actually store the response
|
||||
# (i.e., their request's `store=True` just because it's the default
|
||||
# value).
|
||||
request.store = False
|
||||
|
||||
# Handle the previous response ID.
|
||||
prev_response_id = request.previous_response_id
|
||||
|
||||
@ -1060,7 +1060,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
|
||||
# Enables support for the "store" option in the OpenAI Responses API.
|
||||
# When set to 1, vLLM's OpenAI server will retain the input and output
|
||||
# messages for those requests in memory. By default, this is disabled (0).
|
||||
# messages for those requests in memory. By default, this is disabled (0),
|
||||
# and the "store" option is ignored.
|
||||
# NOTE/WARNING:
|
||||
# 1. Messages are kept in memory only (not persisted to disk) and will be
|
||||
# lost when the vLLM server shuts down.
|
||||
|
||||
@ -1079,9 +1079,6 @@ class FusedMoE(torch.nn.Module):
|
||||
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
|
||||
f"got {shard_id}.")
|
||||
|
||||
WEIGHT_SCALE_SUPPORTED = [
|
||||
e.value for e in FusedMoeWeightScaleSupported
|
||||
]
|
||||
# Fetch the dim to shard the parameter/loaded weight
|
||||
# based on the shard id. This will be whatever
|
||||
# dimension intermediate_size_per_partition is used.
|
||||
@ -1230,6 +1227,9 @@ class FusedMoE(torch.nn.Module):
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id)
|
||||
else:
|
||||
WEIGHT_SCALE_SUPPORTED = [
|
||||
e.value for e in FusedMoeWeightScaleSupported
|
||||
]
|
||||
raise ValueError(
|
||||
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
|
||||
return True if return_success else None
|
||||
|
||||
@ -24,10 +24,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer,
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
|
||||
@ -260,6 +262,81 @@ class ArceeModel(nn.Module):
|
||||
return hidden_states, aux_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this module's parameters.

    Separate ``q_proj`` / ``k_proj`` / ``v_proj`` checkpoint entries are
    routed into the fused ``qkv_proj`` parameter using a shard id. Rotary
    embedding caches are skipped, KV-cache scales are handled through the
    quantization config when one is set, and parameters that are absent
    (extra biases, layers missing on this pipeline-parallel rank) are
    ignored.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The set of parameter names that were actually loaded.
    """
    # (fused parameter suffix, checkpoint suffix, shard id)
    qkv_shard_mapping = [
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
    ]
    skipped_substrings = (
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    )

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        if any(fragment in name for fragment in skipped_substrings):
            continue

        # KV-cache scale stored under a quantization-specific name.
        scale_name = (self.quant_config.get_cache_scale(name)
                      if self.quant_config is not None else None)
        if scale_name:
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # Use the tensor as-is when it is 0-d, else its first element.
            if loaded_weight.dim() != 0:
                loaded_weight = loaded_weight[0]
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue

        if "scale" in name:
            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
            if remapped_name is None:
                continue
            name = remapped_name

        for fused_suffix, ckpt_suffix, shard_id in qkv_shard_mapping:
            if ckpt_suffix not in name:
                continue

            # First matching projection wins; every path below leaves the
            # loop, so the ``else`` branch is skipped for q/k/v weights.
            name = name.replace(ckpt_suffix, fused_suffix)

            if name.endswith(".bias") and name not in params_dict:
                break
            if is_pp_missing_parameter(name, self):
                break

            param = params_dict[name]
            weight_loader = param.weight_loader  # type: ignore[attr-defined]
            weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
            break
        else:
            # Not a q/k/v projection: load it as an ordinary parameter.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

    return loaded_params
|
||||
|
||||
|
||||
class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"""Arcee Model for causal language modeling, integrated with vLLM
|
||||
@ -304,8 +381,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
else:
|
||||
# Placeholder for lm_head on non-last ranks
|
||||
self.lm_head = PPMissingLayer()
|
||||
# Provide a reference to the model's method for generating empty
|
||||
# tensors (used in pipeline parallel schedule)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -316,7 +392,6 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
# Forward pass through the Arcee model backbone
|
||||
model_output = self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/ChatGLM2-6B
|
||||
# https://github.com/zai-org/ChatGLM2-6B
|
||||
"""Inference-only ChatGLM model compatible with THUDM weights."""
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
@ -86,10 +86,10 @@ class GLMAttention(nn.Module):
|
||||
prefix=f"{prefix}.dense",
|
||||
)
|
||||
|
||||
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
|
||||
# https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
|
||||
rope_ratio = getattr(config, "rope_ratio", 1.0)
|
||||
max_positions = getattr(config, "seq_length", 8192)
|
||||
# NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
|
||||
# NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
|
||||
# which is equivalent to is_neox_style=True
|
||||
is_neox_style = not config.original_rope
|
||||
self.rotary_emb = get_rope(
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/CogAgent
|
||||
# https://github.com/zai-org/CogAgent
|
||||
"""Inference-only CogAgent model compatible with THUDM weights."""
|
||||
from argparse import Namespace
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
@ -393,7 +393,7 @@ def merge_multimodal_embeddings_from_map(
|
||||
inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
|
||||
placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
|
||||
"""
|
||||
Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
|
||||
Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
|
||||
placeholder map .
|
||||
|
||||
Note:
|
||||
@ -418,17 +418,23 @@ def _merge_multimodal_embeddings(
|
||||
Note:
|
||||
This updates ``inputs_embeds`` in place.
|
||||
"""
|
||||
num_expected_tokens = is_multimodal.sum().item()
|
||||
assert isinstance(num_expected_tokens, int)
|
||||
|
||||
flattened = _flatten_embeddings(multimodal_embeddings)
|
||||
if flattened.shape[0] != num_expected_tokens:
|
||||
expr = _embedding_count_expression(multimodal_embeddings)
|
||||
raise ValueError(
|
||||
f"Attempted to assign {expr} = {flattened.shape[0]} "
|
||||
f"multimodal tokens to {num_expected_tokens} placeholders")
|
||||
try:
|
||||
# This is equivalent to: inputs_embeds[is_multimodal] = flattened.
|
||||
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
|
||||
except RuntimeError as e:
|
||||
num_expected_tokens = is_multimodal.sum().item()
|
||||
assert isinstance(num_expected_tokens, int)
|
||||
|
||||
if flattened.shape[0] != num_expected_tokens:
|
||||
expr = _embedding_count_expression(multimodal_embeddings)
|
||||
raise ValueError(
|
||||
f"Attempted to assign {expr} = {flattened.shape[0]} "
|
||||
f"multimodal tokens to {num_expected_tokens} placeholders"
|
||||
) from e
|
||||
else:
|
||||
raise ValueError("Error during masked scatter operation") from e
|
||||
|
||||
inputs_embeds[is_multimodal] = flattened
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
@ -478,11 +484,11 @@ def merge_multimodal_embeddings(
|
||||
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
|
||||
positions in ``inputs_embeds`` corresponding to placeholder tokens in
|
||||
``input_ids``.
|
||||
|
||||
``placeholder_token_id`` can be a list of token ids (e.g, token ids
|
||||
of img_start, img_break, and img_end tokens) when needed: This means
|
||||
the order of these tokens in the ``input_ids`` MUST MATCH the order of
|
||||
their embeddings in ``multimodal_embeddings`` since we need to
|
||||
|
||||
``placeholder_token_id`` can be a list of token ids (e.g, token ids
|
||||
of img_start, img_break, and img_end tokens) when needed: This means
|
||||
the order of these tokens in the ``input_ids`` MUST MATCH the order of
|
||||
their embeddings in ``multimodal_embeddings`` since we need to
|
||||
slice-merge instead of individually scattering.
|
||||
|
||||
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
|
||||
@ -491,9 +497,9 @@ def merge_multimodal_embeddings(
|
||||
- I is image embedding token
|
||||
- B is image break token
|
||||
- E is image end token.
|
||||
|
||||
Then the image embeddings (that correspond to I's) from vision encoder
|
||||
must be padded with embeddings of S, B, and E in the same order of
|
||||
|
||||
Then the image embeddings (that correspond to I's) from vision encoder
|
||||
must be padded with embeddings of S, B, and E in the same order of
|
||||
input_ids for a correct embedding merge.
|
||||
|
||||
Note:
|
||||
|
||||
@ -270,6 +270,7 @@ class CudaPlatformBase(Platform):
|
||||
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
|
||||
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
|
||||
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
|
||||
|
||||
if selected_backend == _Backend.FLASHINFER:
|
||||
logger.info_once("Using FlashInfer backend on V1 engine.")
|
||||
@ -287,6 +288,9 @@ class CudaPlatformBase(Platform):
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
logger.info_once("Using Flash Attention backend on V1 engine.")
|
||||
return FLASH_ATTN_V1
|
||||
elif selected_backend == _Backend.TREE_ATTN:
|
||||
logger.info_once("Using Tree Attention backend on V1 engine.")
|
||||
return TREE_ATTN_V1
|
||||
|
||||
from vllm.attention.selector import is_attn_backend_supported
|
||||
|
||||
|
||||
@ -62,6 +62,7 @@ class _Backend(enum.Enum):
|
||||
DIFFERENTIAL_FLASH_ATTN = enum.auto()
|
||||
NO_ATTENTION = enum.auto()
|
||||
FLEX_ATTENTION = enum.auto()
|
||||
TREE_ATTN = enum.auto()
|
||||
|
||||
|
||||
class PlatformEnum(enum.Enum):
|
||||
|
||||
@ -156,6 +156,7 @@ class SamplingParams(
|
||||
Note that the implementation follows the OpenAI API: The API will
|
||||
always return the log probability of the sampled token, so there
|
||||
may be up to `logprobs+1` elements in the response.
|
||||
When set to -1, return all `vocab_size` log probabilities.
|
||||
prompt_logprobs: Number of log probabilities to return per prompt token.
|
||||
detokenize: Whether to detokenize the output. Defaults to True.
|
||||
skip_special_tokens: Whether to skip special tokens in the output.
|
||||
@ -414,9 +415,10 @@ class SamplingParams(
|
||||
raise ValueError(
|
||||
f"min_tokens must be less than or equal to "
|
||||
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
|
||||
if self.logprobs is not None and self.logprobs < 0:
|
||||
if (self.logprobs is not None and self.logprobs != -1
|
||||
and self.logprobs < 0):
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
f"logprobs must be non-negative or -1, got {self.logprobs}.")
|
||||
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
|
||||
raise ValueError(f"prompt_logprobs must be non-negative, got "
|
||||
f"{self.prompt_logprobs}.")
|
||||
|
||||
@ -118,7 +118,7 @@ MODELS_ON_S3 = [
|
||||
"stabilityai/stablelm-zephyr-3b",
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
|
||||
"THUDM/glm-4v-9b",
|
||||
"zai-org/glm-4v-9b",
|
||||
"TIGER-Lab/Mantis-8B-siglip-llama3",
|
||||
"TIGER-Lab/VLM2Vec-Full",
|
||||
"tiiuae/falcon-40b",
|
||||
|
||||
@ -290,20 +290,29 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
|
||||
|
||||
|
||||
def maybe_override_with_speculators_target_model(
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None) -> tuple[str, str]:
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
If running a speculators config, override running model with target model
|
||||
"""
|
||||
is_gguf = check_gguf_file(model)
|
||||
if is_gguf:
|
||||
kwargs["gguf_file"] = Path(model).name
|
||||
gguf_model_repo = Path(model).parent
|
||||
else:
|
||||
gguf_model_repo = None
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model,
|
||||
model if gguf_model_repo is None else gguf_model_repo,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token=_get_hf_token(),
|
||||
**kwargs,
|
||||
)
|
||||
spec_config = config_dict.get("speculators_config")
|
||||
spec_config = config_dict.get("speculators_config", None)
|
||||
# Return the target model
|
||||
if spec_config is not None:
|
||||
model = tokenizer = spec_config["verifier"]["name_or_path"]
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/ChatGLM2-6B
|
||||
# https://github.com/zai-org/ChatGLM2-6B
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
|
||||
@ -271,7 +271,7 @@ def get_tokenizer(
|
||||
}
|
||||
tokenizer.add_special_tokens(special_tokens_map)
|
||||
|
||||
# NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324
|
||||
# NOTE: We can remove this after https://github.com/zai-org/ChatGLM3/issues/1324
|
||||
if type(tokenizer).__name__ in ("ChatGLMTokenizer",
|
||||
"ChatGLM4Tokenizer"):
|
||||
assert isinstance(tokenizer, PreTrainedTokenizer)
|
||||
|
||||
452
vllm/v1/attention/backends/tree_attn.py
Normal file
452
vllm/v1/attention/backends/tree_attn.py
Normal file
@ -0,0 +1,452 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with TreeAttention."""
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TreeAttentionBackend(AttentionBackend):
    """Attention backend descriptor for tree attention.

    Exposes the implementation, metadata and metadata-builder classes
    defined in this module, together with the dtype / head-size / KV-cache
    layout constraints of the backend.
    """

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the tensor dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes: multiples of 32 up to 256."""
        return list(range(32, 257, 32))

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Check that ``head_size`` is supported.

        Raises:
            ValueError: if ``head_size`` is not one of
                ``get_supported_head_sizes()``.
        """
        allowed = cls.get_supported_head_sizes()
        if head_size in allowed:
            return
        # Derive a display name from the class name, e.g. "TreeAttention".
        backend_label = cls.__name__.removesuffix("Backend")
        message = (f"Head size {head_size} is not supported by "
                   f"{backend_label}. "
                   f"Supported head sizes are: {allowed}. "
                   "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                   "FlexAttention backend which supports all head sizes.")
        raise ValueError(message)

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "TREE_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        """Return the attention implementation class."""
        return TreeAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-step attention metadata class."""
        return TreeAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the KV-cache tensor shape.

        The leading dimension of size 2 is unbound into the key cache and
        the value cache by ``TreeAttentionImpl.forward``.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return kv_cache_shape

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return TreeAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Always ``False``; all arguments are ignored."""
        return False
|
||||
|
||||
|
||||
@dataclass
class TreeAttentionMetadata:
    """Attention metadata for one batch, with lazily derived sub-views.

    The batch is laid out with decode requests first and prefill requests
    after them: ``decode_metadata`` takes the first ``num_decodes`` entries
    and ``prefill_metadata`` takes everything from ``num_decodes`` onward.
    Each derived view is built on first access and then cached on this
    instance.
    """

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    # Cumulative query offsets; consecutive differences are the per-request
    # query lengths (see the ``torch.diff`` calls below).
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    # Only forwarded to the decode view; the prefill view leaves it unset.
    tree_attn_bias: Optional[torch.Tensor] = None

    # Lazily populated caches for the two derived views.
    _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
    _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata restricted to the prefill requests, or ``None`` if the
        batch has no prefills."""
        if self.num_prefills == 0:
            return None

        view = self._cached_prefill_metadata
        if view is None:
            first_prefill = self.num_decodes
            starts = self.query_start_loc[first_prefill:]
            query_lens = torch.diff(starts)
            kv_lens = self.seq_lens[first_prefill:]
            view = TreeAttentionMetadata(
                num_actual_tokens=self.num_prefill_tokens,
                max_query_len=int(query_lens.max().item()),
                # Rebase the offsets so the first prefill starts at 0.
                query_start_loc=starts - starts[0],
                max_seq_len=int(kv_lens.max().item()),
                seq_lens=kv_lens,
                block_table=self.block_table[first_prefill:],
                slot_mapping=self.slot_mapping[self.num_decode_tokens:],
            )
            self._cached_prefill_metadata = view
        return view

    @property
    def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata restricted to the decode requests, or ``None`` if the
        batch has no decode tokens."""
        if self.num_decode_tokens == 0:
            return None

        view = self._cached_decode_metadata
        if view is None:
            decode_end = self.num_decodes
            # One extra offset is needed to delimit the last decode request.
            starts = self.query_start_loc[:decode_end + 1]
            query_lens = torch.diff(starts)
            kv_lens = self.seq_lens[:decode_end]
            view = TreeAttentionMetadata(
                num_actual_tokens=self.num_decode_tokens,
                max_query_len=int(query_lens.max().item()),
                query_start_loc=starts,
                max_seq_len=int(kv_lens.max().item()),
                seq_lens=kv_lens,
                block_table=self.block_table[:decode_end],
                slot_mapping=self.slot_mapping[:self.num_decode_tokens],
                tree_attn_bias=self.tree_attn_bias,
            )
            self._cached_decode_metadata = view
        return view
|
||||
|
||||
|
||||
class TreeAttentionMetadataBuilder(
        AttentionMetadataBuilder[TreeAttentionMetadata]):
    """Builds ``TreeAttentionMetadata`` from common attention metadata.

    The tree attention bias is computed once in ``__init__`` from the
    configured speculative token tree. Its first dimension (number of tree
    nodes including the root) is also used as the decode threshold when
    splitting or reordering a batch into decodes and prefills.
    """

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """
        Args:
            kv_cache_spec: KV-cache spec; its ``block_size`` is recorded.
            layer_names: Not used by this builder.
            vllm_config: Source of ``speculative_config``.
            device: Device on which the tree attention bias is allocated.
        """
        self.kv_cache_spec = kv_cache_spec
        self.block_size = kv_cache_spec.block_size

        # ``speculative_token_tree`` is a string holding a Python literal
        # (parsed with ``ast.literal_eval``). When no speculative config or
        # no tree is present, fall back to a single-node tree.
        spec_config = vllm_config.speculative_config
        spec_token_tree = spec_config and spec_config.speculative_token_tree
        tree_choices: list[tuple[int, ...]] = (
            ast.literal_eval(spec_token_tree)
            if spec_token_tree is not None else [(0, )])
        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Reorder the batch so decodes precede prefills.

        Delegates to ``reorder_batch_to_split_decodes_and_prefills`` with
        the tree size as the decode threshold and returns its result.
        """
        return reorder_batch_to_split_decodes_and_prefills(
            input_batch,
            scheduler_output,
            decode_threshold=self.tree_attn_bias.shape[0])

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        """Build the metadata for one batch.

        Args:
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Batch-level metadata to copy from.
            fast_build: Not used by this builder.

        Returns:
            A ``TreeAttentionMetadata`` carrying the current
            ``self.tree_attn_bias``.
        """
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=decode_threshold))

        num_actual_tokens = common_attn_metadata.num_actual_tokens
        q_start_loc = common_attn_metadata.query_start_loc
        max_query_len = common_attn_metadata.max_query_len
        kv_seqlens = common_attn_metadata.seq_lens
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        block_table = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        return TreeAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=max_query_len,
            query_start_loc=q_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=kv_seqlens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        """Build metadata for one drafting step.

        ``self.tree_attn_bias`` is swapped for a step-specific bias while
        ``build`` runs and is always restored afterwards, including when
        ``build`` raises.

        Args:
            common_attn_metadata: Batch-level metadata for the draft step.
            draft_index: 0 for the root level; otherwise the offset at
                which the bias is sliced.
        """
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias

        try:
            if draft_index == 0:
                # Use prefill for drafting at the root level: an empty bias
                # gives a decode threshold of 0 in ``build``.
                # NOTE(review): this tensor is created on the default
                # device rather than the builder's device - confirm that is
                # intended.
                self.tree_attn_bias = torch.empty(0)
            else:
                # Slice the tree attention bias for drafting.
                query_len = common_attn_metadata.max_query_len
                start, end = draft_index, draft_index + query_len
                self.tree_attn_bias = orig_tree_attn_bias[
                    start:end, start:end].contiguous()

            # Build attention bias.
            return self.build(0, common_attn_metadata, fast_build=True)
        finally:
            # Reset the tree attention bias to the original value even if
            # ``build`` failed, so later calls do not see a stale slice.
            self.tree_attn_bias = orig_tree_attn_bias
|
||||
|
||||
|
||||
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
|
||||
# Count the number of choices at each depth of the tree.
|
||||
depth_counts = []
|
||||
prev_depth = 0
|
||||
for path in sorted_tree_choices:
|
||||
depth = len(path)
|
||||
if depth != prev_depth:
|
||||
depth_counts.append(0)
|
||||
depth_counts[depth - 1] += 1
|
||||
prev_depth = depth
|
||||
return depth_counts
|
||||
|
||||
|
||||
def _prepare_tree_attn_bias(
|
||||
sorted_tree_choices: list[tuple[int, ...]],
|
||||
depth_counts: list[int],
|
||||
dtype: Optional[torch.dtype],
|
||||
device: Optional[torch.device],
|
||||
) -> torch.Tensor:
|
||||
# +1 comes from the additional root node.
|
||||
tree_len = len(sorted_tree_choices) + 1
|
||||
tree_attn_mask = torch.full((tree_len, tree_len),
|
||||
-torch.inf,
|
||||
device=device,
|
||||
dtype=dtype)
|
||||
|
||||
# Set diagonal to all zeros. Each token should
|
||||
# attend to itself.
|
||||
mask_val = 0
|
||||
for i in range(tree_len):
|
||||
tree_attn_mask[i, i] = mask_val
|
||||
|
||||
# Set root to all zeros. All tokens attend to it.
|
||||
tree_attn_mask[:, 0] = mask_val
|
||||
|
||||
# Set all ancestors to zeros.
|
||||
start = 0
|
||||
for i in range(len(depth_counts)):
|
||||
for j in range(depth_counts[i]):
|
||||
cur_tree_choice = sorted_tree_choices[start + j]
|
||||
# Retrieve ancestor position.
|
||||
if len(cur_tree_choice) == 1:
|
||||
continue
|
||||
ancestor_idx = []
|
||||
for c in range(len(cur_tree_choice) - 1):
|
||||
ancestor_idx.append(
|
||||
sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1)
|
||||
tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
|
||||
start += depth_counts[i]
|
||||
return tree_attn_mask
|
||||
|
||||
|
||||
class TreeAttentionImpl(AttentionImpl):
    """Attention implementation that runs ``unified_attention`` separately
    over the prefill and decode portions of a batch, applying the tree
    attention bias to the decode portion only."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Validate the configuration and record it on the instance.

        ``use_irope`` is accepted but not read by this implementation.

        Raises:
            ValueError: if ``blocksparse_params`` is given, or if
                ``head_size`` is rejected by
                ``TreeAttentionBackend.validate_head_size``.
            NotImplementedError: if ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "TreeAttention does not support block-sparse attention.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = num_heads // num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))

        # Setting logits_soft_cap to 0 means no soft cap.
        self.logits_soft_cap = (0
                                if logits_soft_cap is None else logits_soft_cap)

        # (-1, -1) stands for "no sliding window".
        self.sliding_window = ((-1, -1) if sliding_window is None else
                               (sliding_window - 1, 0))

        TreeAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TreeAttentionImpl.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run tree attention, writing the result into ``output``.

        Args:
            layer: Module providing the ``_k_scale`` / ``_v_scale`` tensors.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. ``None`` marks a
                profiling run, in which case ``output`` is returned as is.
            output: Pre-allocated result buffer; required.
            output_scale: Must be ``None``.
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            NotImplementedError: if ``output_scale`` is given.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TreeAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        total_tokens = attn_metadata.num_actual_tokens
        decode_tokens = attn_metadata.num_decode_tokens
        # NOTE(review): this shape is derived from the full batch and is
        # used unchanged for both the prefill and the decode call below.
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
                         key.shape[1])

        def _attend(meta: TreeAttentionMetadata, lo: int, hi: int,
                    **extra: Any) -> None:
            # Shared call into the kernel for tokens [lo, hi); ``extra``
            # carries arguments that only one of the two phases passes.
            unified_attention(
                q=query[lo:hi],
                k=key_cache,
                v=value_cache,
                out=output[lo:hi],
                cu_seqlens_q=meta.query_start_loc,
                max_seqlen_q=meta.max_query_len,
                seqused_k=meta.seq_lens,
                max_seqlen_k=meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                **extra,
            )

        # Prefill tokens sit after the decode tokens in the batch.
        prefill_meta = attn_metadata.prefill_metadata
        if prefill_meta:
            _attend(prefill_meta, decode_tokens, total_tokens)

        # Decode tokens come first and additionally get the tree bias.
        decode_meta = attn_metadata.decode_metadata
        if decode_meta:
            _attend(decode_meta,
                    0,
                    decode_tokens,
                    qq_bias=decode_meta.tree_attn_bias)

        return output
|
||||
@ -214,6 +214,26 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
return self.build(common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata)
|
||||
|
||||
def build_for_drafting(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    draft_index: int,
) -> M:
    """
    Build the attention metadata used when running the draft model.

    This default implementation does not read ``draft_index``; it simply
    calls ``build`` with no common prefix and ``fast_build=True``.

    Args:
        common_attn_metadata: The common attention metadata.
        draft_index: The index of the current draft operation.
            When speculating a chain of tokens, this index refers to the
            draft attempt for the i-th token.
            For tree-based attention, this index instead refers to the
            draft attempt for the i-th level in the tree of tokens.

    Returns:
        Whatever ``build`` returns for these arguments.
    """
    drafting_metadata = self.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
        fast_build=True,
    )
    return drafting_metadata
|
||||
|
||||
def use_cascade_attention(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
|
||||
@ -138,7 +138,7 @@ class LogprobsProcessor:
|
||||
|
||||
def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
|
||||
"""Pop and return all request prompt logprobs
|
||||
|
||||
|
||||
The logprobs processor aggregates prompt chunk logprobs
|
||||
over one or more prefill chunks. This method returns
|
||||
all prompt logprobs at once and then forgets them.
|
||||
@ -176,7 +176,8 @@ class LogprobsProcessor:
|
||||
Returns:
|
||||
dict[token id, Logprob]
|
||||
"""
|
||||
|
||||
if num_logprobs == -1:
|
||||
num_logprobs = len(logprobs)
|
||||
# We do not need a special case for the sampled token
|
||||
# being in the topk, since inserting duplicated data
|
||||
# into a dictionary twice is the same as doing it once.
|
||||
|
||||
@ -65,8 +65,11 @@ class Processor:
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
max_logprobs = self.model_config.max_logprobs
|
||||
if max_logprobs == -1:
|
||||
return
|
||||
# Validate sample logprobs.
|
||||
if params.logprobs and params.logprobs > max_logprobs:
|
||||
if params.logprobs and (params.logprobs == -1
|
||||
or params.logprobs > max_logprobs):
|
||||
raise ValueError(
|
||||
f"Requested sample logprobs of {params.logprobs}, "
|
||||
f"which is greater than max allowed: {max_logprobs}")
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
from dataclasses import replace
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
@ -17,6 +19,8 @@ from vllm.model_executor.models import supports_multimodal
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
|
||||
TreeAttentionMetadataBuilder)
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
@ -74,18 +78,52 @@ class EagleProposer:
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
dtype=self.dtype,
|
||||
device=device)
|
||||
# We need +1 here because the arange is used to set query_start_loc,
|
||||
# which has one more element than batch_size.
|
||||
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
|
||||
1,
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.arange = torch.arange(
|
||||
# We need +1 here because the arange is used to set query_start_loc,
|
||||
# which has one more element than batch_size.
|
||||
max_batch_size + 1,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
dtype=self.dtype,
|
||||
device=device)
|
||||
|
||||
# Parse the speculative token tree.
|
||||
spec_token_tree = self.speculative_config.speculative_token_tree
|
||||
self.tree_choices: list[tuple[int,
|
||||
...]] = ast.literal_eval(spec_token_tree)
|
||||
tree_depth = len(self.tree_choices[-1])
|
||||
# Precompute per-level properties of the tree.
|
||||
num_drafts_per_level = [0] * tree_depth
|
||||
for node in self.tree_choices:
|
||||
num_drafts_per_level[len(node) - 1] += 1
|
||||
self.cu_drafts_per_level = [num_drafts_per_level[0]]
|
||||
self.child_drafts_per_level = [num_drafts_per_level[0]]
|
||||
for level in range(1, tree_depth):
|
||||
self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] +
|
||||
num_drafts_per_level[level])
|
||||
self.child_drafts_per_level.append(num_drafts_per_level[level] //
|
||||
num_drafts_per_level[level - 1])
|
||||
# Find the first level where the tree branches off into one or more
|
||||
# children.
|
||||
self.first_branching_level = None
|
||||
for level in range(tree_depth):
|
||||
if self.cu_drafts_per_level[level] > level + 1:
|
||||
self.first_branching_level = level
|
||||
break
|
||||
# Precompute draft position offsets in flattened tree.
|
||||
self.tree_draft_pos_offsets = torch.arange(
|
||||
1,
|
||||
len(self.tree_choices) + 1,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
).repeat(max_batch_size, 1)
|
||||
|
||||
def propose(
|
||||
self,
|
||||
# [num_tokens]
|
||||
@ -120,11 +158,9 @@ class EagleProposer:
|
||||
assert self.runner is not None
|
||||
|
||||
# FIXME: need to consider multiple kv_cache_groups
|
||||
attn_metadata = self.runner.attn_metadata_builders[0].build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
fast_build=True,
|
||||
)
|
||||
attn_metadata = self.runner.attn_metadata_builders[
|
||||
0].build_for_drafting(common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0)
|
||||
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
@ -167,6 +203,22 @@ class EagleProposer:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
positions = target_positions[last_token_indices]
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
if self.first_branching_level == 0:
|
||||
# Branching has occurred at the root level. Draft using tree
|
||||
# attention.
|
||||
draft_token_ids_list = self.propose_tree(
|
||||
tree_root_level=0,
|
||||
batch_size=batch_size,
|
||||
logits=logits,
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
# [batch_size, num_tree_tokens]
|
||||
return torch.cat(draft_token_ids_list, dim=1)
|
||||
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
@ -178,16 +230,15 @@ class EagleProposer:
|
||||
# one layer. Adapt this code to support multiple layers once
|
||||
# there's a multi-layer MTP module.
|
||||
|
||||
# Currently FlashAttention is the only backend that supports
|
||||
# multi-token eagle spec decode. This is because the code below
|
||||
# Currently, only FlashAttention and TreeAttention support multi-token
|
||||
# eagle spec decode. This is because the code below
|
||||
# makes assumptions about attn_metadata attributes available.
|
||||
assert isinstance(attn_metadata, FlashAttentionMetadata)
|
||||
assert isinstance(attn_metadata,
|
||||
(FlashAttentionMetadata, TreeAttentionMetadata))
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
|
||||
positions = target_positions[last_token_indices]
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
if self.use_cuda_graph and \
|
||||
batch_size <= self.cudagraph_batch_sizes[-1]:
|
||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
||||
@ -196,7 +247,7 @@ class EagleProposer:
|
||||
attn_metadata.num_actual_tokens = batch_size
|
||||
attn_metadata.max_query_len = 1
|
||||
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||
for _ in range(self.num_speculative_tokens - 1):
|
||||
for token_index in range(self.num_speculative_tokens - 1):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
@ -265,7 +316,20 @@ class EagleProposer:
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size],
|
||||
None)
|
||||
|
||||
# TODO(wenlong): get more than one token for tree attention
|
||||
if self.first_branching_level == token_index + 1:
|
||||
# Branching has occurred. The remaining tokens are drafted
|
||||
# using tree attention.
|
||||
draft_token_ids_list += self.propose_tree(
|
||||
tree_root_level=token_index + 1,
|
||||
batch_size=batch_size,
|
||||
logits=logits,
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
# [batch_size, num_tree_tokens]
|
||||
return torch.cat(draft_token_ids_list, dim=1)
|
||||
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
@ -273,6 +337,175 @@ class EagleProposer:
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
return draft_token_ids
|
||||
|
||||
def propose_tree(
|
||||
self,
|
||||
tree_root_level: int,
|
||||
batch_size: int,
|
||||
# [num_tokens, vocab_size]
|
||||
logits: torch.Tensor,
|
||||
# [num_tokens]
|
||||
positions: torch.Tensor,
|
||||
# [num_tokens, hidden_size]
|
||||
hidden_states: torch.Tensor,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[torch.Tensor]:
|
||||
tree_attn_metadata_builder = self.runner.attn_metadata_builders[0]
|
||||
assert isinstance(tree_attn_metadata_builder,
|
||||
TreeAttentionMetadataBuilder)
|
||||
|
||||
total_num_drafts = self.cu_drafts_per_level[tree_root_level]
|
||||
level_num_drafts = total_num_drafts
|
||||
# Sample a draft token for each child at the tree root level.
|
||||
num_children = self.child_drafts_per_level[tree_root_level]
|
||||
if num_children == 1:
|
||||
draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
|
||||
else:
|
||||
draft_token_ids = torch.topk(logits, num_children,
|
||||
dim=-1).indices.view(batch_size, -1)
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
draft_hidden_states = hidden_states.view(batch_size, 1, -1)
|
||||
|
||||
# Initialize empty tensors for concatenation with the level outputs.
|
||||
tree_input_ids = torch.empty(0,
|
||||
device=self.input_ids.device,
|
||||
dtype=self.input_ids.dtype)
|
||||
tree_positions = torch.empty(0,
|
||||
device=self.positions.device,
|
||||
dtype=self.positions.dtype)
|
||||
tree_hidden_states = torch.empty(0,
|
||||
device=self.hidden_states.device,
|
||||
dtype=self.hidden_states.dtype)
|
||||
# Precompute the draft token positions.
|
||||
flattened_draft_positions = (
|
||||
positions.view(batch_size, -1) +
|
||||
self.tree_draft_pos_offsets[:batch_size, :])
|
||||
tree_depth = len(self.cu_drafts_per_level)
|
||||
for level in range(tree_root_level, tree_depth - 1):
|
||||
# Get draft positions for RoPE.
|
||||
draft_positions = positions + (level + 1)
|
||||
exceeds_max_model_len = (positions +
|
||||
total_num_drafts) >= self.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_draft_positions = torch.where(
|
||||
exceeds_max_model_len,
|
||||
0,
|
||||
draft_positions,
|
||||
)
|
||||
if level_num_drafts > 1:
|
||||
# Repeat the positions for each draft at this level.
|
||||
draft_positions = clamped_draft_positions.repeat_interleave(
|
||||
level_num_drafts).reshape(batch_size, -1)
|
||||
|
||||
if num_children > 1:
|
||||
# Repeat draft hidden states for each child.
|
||||
draft_hidden_states = draft_hidden_states.repeat_interleave(
|
||||
num_children, dim=1)
|
||||
|
||||
# Concatenate the draft tokens, positions, and hidden states.
|
||||
tree_input_ids = torch.cat([tree_input_ids, draft_token_ids],
|
||||
dim=1)
|
||||
tree_positions = torch.cat([tree_positions, draft_positions],
|
||||
dim=1)
|
||||
tree_hidden_states = torch.cat(
|
||||
[tree_hidden_states, draft_hidden_states], dim=1)
|
||||
|
||||
# Build new attention metadata for the next level of drafts.
|
||||
# This is necessary to support tree attention.
|
||||
query_len = total_num_drafts - tree_root_level
|
||||
common_attn_metadata = replace(
|
||||
common_attn_metadata,
|
||||
query_start_loc=query_len * self.arange[:batch_size + 1],
|
||||
seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
|
||||
num_actual_tokens=batch_size * query_len,
|
||||
max_query_len=query_len,
|
||||
)
|
||||
attn_metadata = tree_attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=tree_root_level + 1,
|
||||
)
|
||||
|
||||
# Apply new attention metadata to all layers.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
# Consider max model length.
|
||||
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
|
||||
self.max_model_len)
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
|
||||
|
||||
# Compute the slot mapping.
|
||||
query_positions = flattened_draft_positions[:, level:level +
|
||||
query_len]
|
||||
block_numbers = query_positions // self.block_size
|
||||
block_ids = attn_metadata.block_table.gather(dim=1,
|
||||
index=block_numbers)
|
||||
slot_mapping = (block_ids * self.block_size +
|
||||
query_positions % self.block_size)
|
||||
# Mask out the slot mappings that exceed the max model length.
|
||||
# Otherwise, the KV cache will be inadvertently updated with the
|
||||
# padding tokens.
|
||||
slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
|
||||
attn_metadata.slot_mapping = slot_mapping.view(-1)
|
||||
|
||||
# Copy inputs to buffer for cudagraph.
|
||||
num_tokens = attn_metadata.num_actual_tokens
|
||||
input_ids = tree_input_ids.view(-1)
|
||||
self.input_ids[:num_tokens] = input_ids
|
||||
self.positions[:num_tokens] = tree_positions.view(-1)
|
||||
self.hidden_states[:num_tokens] = tree_hidden_states.view(
|
||||
num_tokens, -1)
|
||||
|
||||
if self.use_cuda_graph and \
|
||||
num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
num_tokens)
|
||||
else:
|
||||
num_input_tokens = num_tokens
|
||||
# Run the model.
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
inputs_embeds=None,
|
||||
)
|
||||
|
||||
# Get the output hidden states for the draft tokens.
|
||||
draft_hidden_states = hidden_states[:num_tokens].view(
|
||||
batch_size, query_len, -1)[:, -level_num_drafts:]
|
||||
draft_last_hidden_states = last_hidden_states[:num_tokens].view(
|
||||
batch_size, query_len, -1)[:, -level_num_drafts:]
|
||||
|
||||
# Get the output logits for the draft tokens.
|
||||
logits = self.model.compute_logits(
|
||||
draft_last_hidden_states.reshape(batch_size * level_num_drafts,
|
||||
-1),
|
||||
None,
|
||||
)
|
||||
|
||||
# Sample a draft token for each child at the next tree level.
|
||||
num_children = self.child_drafts_per_level[level + 1]
|
||||
if num_children == 1:
|
||||
draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
|
||||
else:
|
||||
draft_token_ids = torch.topk(logits, num_children,
|
||||
dim=-1).indices.view(
|
||||
batch_size, -1)
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
# Update the number-of-drafts counters for the next tree level.
|
||||
level_num_drafts = self.cu_drafts_per_level[level +
|
||||
1] - total_num_drafts
|
||||
total_num_drafts = self.cu_drafts_per_level[level + 1]
|
||||
|
||||
return draft_token_ids_list
|
||||
|
||||
def prepare_inputs(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
|
||||
@ -34,22 +34,22 @@ class ConstantList(Generic[T], Sequence):
|
||||
self._x = x
|
||||
|
||||
def append(self, item):
|
||||
raise Exception("Cannot append to a constant list")
|
||||
raise TypeError("Cannot append to a constant list")
|
||||
|
||||
def extend(self, item):
|
||||
raise Exception("Cannot extend a constant list")
|
||||
raise TypeError("Cannot extend a constant list")
|
||||
|
||||
def insert(self, item):
|
||||
raise Exception("Cannot insert into a constant list")
|
||||
raise TypeError("Cannot insert into a constant list")
|
||||
|
||||
def pop(self, item):
|
||||
raise Exception("Cannot pop from a constant list")
|
||||
raise TypeError("Cannot pop from a constant list")
|
||||
|
||||
def remove(self, item):
|
||||
raise Exception("Cannot remove from a constant list")
|
||||
raise TypeError("Cannot remove from a constant list")
|
||||
|
||||
def clear(self):
|
||||
raise Exception("Cannot clear a constant list")
|
||||
raise TypeError("Cannot clear a constant list")
|
||||
|
||||
def index(self,
|
||||
item: T,
|
||||
@ -78,10 +78,10 @@ class ConstantList(Generic[T], Sequence):
|
||||
...
|
||||
|
||||
def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
|
||||
raise Exception("Cannot set item in a constant list")
|
||||
raise TypeError("Cannot set item in a constant list")
|
||||
|
||||
def __delitem__(self, item):
|
||||
raise Exception("Cannot delete item from a constant list")
|
||||
raise TypeError("Cannot delete item from a constant list")
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._x)
|
||||
|
||||
@ -337,7 +337,9 @@ class InputBatch:
|
||||
self.generators[req_index] = request.generator
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
self.num_logprobs[req_id] = sampling_params.logprobs
|
||||
self.num_logprobs[req_id] = (self.vocab_size
|
||||
if sampling_params.logprobs == -1
|
||||
else sampling_params.logprobs)
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[
|
||||
req_id] = sampling_params.prompt_logprobs
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user