Merge branch 'main' into woosuk/fa3-swa-cudagraph

This commit is contained in:
Woosuk Kwon 2025-08-04 12:47:49 -07:00
commit 06fba5410c
43 changed files with 1458 additions and 182 deletions

View File

@ -49,6 +49,7 @@ best_throughput=0
best_max_num_seqs=0
best_num_batched_tokens=0
best_goodput=0
best_request_rate=0
start_server() {
local gpu_memory_utilization=$1
@ -57,18 +58,35 @@ start_server() {
local vllm_log=$4
local profile_dir=$5
pkill -f vllm
pkill -if vllm
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir vllm serve $MODEL \
--port 8004 \
--gpu-memory-utilization $gpu_memory_utilization \
--max-num-seqs $max_num_seqs \
--max-num-batched-tokens $max_num_batched_tokens \
--tensor-parallel-size $TP \
--enable-prefix-caching \
--load-format dummy \
--download-dir "$DOWNLOAD_DIR" \
--max-model-len $MAX_MODEL_LEN > "$vllm_log" 2>&1 &
# Define the common arguments as a bash array.
# Each argument and its value are separate elements.
local common_args_array=(
"$MODEL"
"--disable-log-requests"
"--port" "8004"
"--gpu-memory-utilization" "$gpu_memory_utilization"
"--max-num-seqs" "$max_num_seqs"
"--max-num-batched-tokens" "$max_num_batched_tokens"
"--tensor-parallel-size" "$TP"
"--enable-prefix-caching"
"--load-format" "dummy"
"--download-dir" "$DOWNLOAD_DIR"
"--max-model-len" "$MAX_MODEL_LEN"
)
# Use the array expansion "${common_args_array[@]}"
# This correctly passes each element as a separate argument.
if [[ -n "$profile_dir" ]]; then
# Start server with profiling enabled
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
else
# Start server without profiling
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 \
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
fi
# wait for 10 minutes...
server_started=0
@ -82,6 +100,7 @@ start_server() {
sleep 10
fi
done
if (( ! server_started )); then
echo "server did not start within 10 minutes. Please check server log at $vllm_log".
return 1
@ -90,37 +109,20 @@ start_server() {
fi
}
#######################################
# Replace the contents of $PROFILE_PATH with the profile captured for one
# benchmark run.
# Globals:
#   SYSTEM       (read) "TPU" or "GPU"; selects how the profile is located.
#   PROFILE_PATH (read) destination directory; its current entries are
#                removed before the new profile is copied in.
# Arguments:
#   $1 - directory whose immediate children are the per-run profile outputs
#   $2 - index into those children after sorting them by path
# Outputs:
#   Diagnostics on stderr when no profile can be selected.
# Returns:
#   0 on success; non-zero if no profile can be selected or the copy fails.
#######################################
update_best_profile() {
  local profile_dir=$1
  local profile_index=$2
  local -a sorted_paths=()
  local -a selected_files=()
  local selected_entry

  # Refuse to run with an empty destination: "rm -f /*" must never happen.
  if [[ -z "${PROFILE_PATH:-}" ]]; then
    echo "update_best_profile: PROFILE_PATH is not set" >&2
    return 1
  fi

  # Immediate children of profile_dir, one per line, sorted by path.
  mapfile -t sorted_paths < <(find "$profile_dir" -mindepth 1 -maxdepth 1 | sort)

  selected_entry="${sorted_paths[$profile_index]:-}"
  if [[ -z "$selected_entry" ]]; then
    echo "update_best_profile: no profile at index $profile_index in $profile_dir" >&2
    return 1
  fi

  case "${SYSTEM:-}" in
    TPU)
      # On TPU the selected entry is a directory holding *.xplane.pb files.
      selected_files=("$selected_entry"/*.xplane.pb)
      ;;
    GPU)
      # On GPU the selected entry itself is what gets copied.
      selected_files=("$selected_entry")
      ;;
    *)
      echo "update_best_profile: unsupported SYSTEM '${SYSTEM:-}'" >&2
      return 1
      ;;
  esac

  # An unmatched glob stays literal (or vanishes under nullglob); either way
  # the first element will not exist on disk.
  if [[ ! -e "${selected_files[0]:-}" ]]; then
    echo "update_best_profile: no profile file found under $selected_entry" >&2
    return 1
  fi

  # Only discard the previous best profile once a replacement is confirmed.
  rm -f -- "${PROFILE_PATH:?}"/*
  cp -- "${selected_files[@]}" "$PROFILE_PATH"
}
run_benchmark() {
local max_num_seqs=$1
local max_num_batched_tokens=$2
local gpu_memory_utilization=$3
echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt"
local profile_dir="$LOG_FOLDER/profile_${max_num_seqs}_${max_num_batched_tokens}"
echo "vllm_log: $vllm_log"
echo
rm -f $vllm_log
mkdir -p $profile_dir
pkill -f vllm
local profile_index=0
pkill -if vllm
echo "starting server..."
start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log $profile_dir
# Call start_server without a profile_dir to avoid profiling overhead
start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log ""
result=$?
if [[ "$result" -eq 1 ]]; then
echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
@ -134,7 +136,8 @@ run_benchmark() {
# get a basic qps by using request-rate inf
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt"
prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 ))
adjusted_input_len=$(( INPUT_LEN - prefix_len ))
adjusted_input_len=$(( INPUT_LEN - prefix_len ))
# --profile flag is removed from this call
vllm bench serve \
--backend vllm \
--model $MODEL \
@ -148,8 +151,7 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len ))
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
--num-prompts 1000 \
--random-prefix-len $prefix_len \
--port 8004 \
--profile &> "$bm_log"
--port 8004 &> "$bm_log"
throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
@ -163,7 +165,6 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len ))
# start from request-rate as int(throughput) + 1
request_rate=$((${throughput%.*} + 1))
while ((request_rate > 0)); do
profile_index=$((profile_index+1))
# clear prefix cache
curl -X POST http://0.0.0.0:8004/reset_prefix_cache
sleep 5
@ -201,12 +202,7 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len ))
best_max_num_seqs=$max_num_seqs
best_num_batched_tokens=$max_num_batched_tokens
best_goodput=$goodput
if [[ "$SYSTEM" == "TPU" ]]; then
update_best_profile "$profile_dir/plugins/profile" $profile_index
fi
if [[ "$SYSTEM" == "GPU" ]]; then
update_best_profile "$profile_dir" $profile_index
fi
best_request_rate=$request_rate
fi
else
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}"
@ -215,7 +211,7 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len ))
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"
pkill vllm
pkill -if vllm
sleep 10
printf '=%.0s' $(seq 1 20)
return 0
@ -228,7 +224,8 @@ read -r -a num_batched_tokens_list <<< "$NUM_BATCHED_TOKENS_LIST"
gpu_memory_utilization=0.98
find_gpu_memory_utilization=0
while (( $(echo "$gpu_memory_utilization >= 0.9" | bc -l) )); do
start_server $gpu_memory_utilization "${num_seqs_list[-1]}" "${num_batched_tokens_list[-1]}" "$LOG_FOLDER/vllm_log_gpu_memory_utilization_$gpu_memory_utilization.log"
# Pass empty string for profile_dir argument
start_server $gpu_memory_utilization "${num_seqs_list[-1]}" "${num_batched_tokens_list[-1]}" "$LOG_FOLDER/vllm_log_gpu_memory_utilization_$gpu_memory_utilization.log" ""
result=$?
if [[ "$result" -eq 0 ]]; then
find_gpu_memory_utilization=1
@ -251,5 +248,45 @@ for num_seqs in "${num_seqs_list[@]}"; do
done
done
echo "finish permutations"
# =================================================================================
# FINAL PROFILING RUN FOR THE BEST CONFIGURATION
# =================================================================================
if (( $(echo "$best_throughput > 0" | bc -l) )); then
echo
echo "Benchmark tuning finished. Now running profiling on the best configuration found..."
echo "Best config: max_num_seqs: $best_max_num_seqs, max_num_batched_tokens: $best_num_batched_tokens, throughput: $best_throughput"
echo
vllm_log="$LOG_FOLDER/vllm_log_BEST_PROFILE.txt"
bm_log="$LOG_FOLDER/bm_log_BEST_PROFILE.txt"
# Start server with the best params and profiling ENABLED
echo "Starting server for profiling..."
start_server $gpu_memory_utilization $best_max_num_seqs $best_num_batched_tokens "$vllm_log" "$PROFILE_PATH"
# Run benchmark with the best params and the --profile flag
echo "Running benchmark with profiling..."
prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 ))
adjusted_input_len=$(( INPUT_LEN - prefix_len ))
vllm bench serve \
--backend vllm \
--model $MODEL \
--dataset-name random \
--random-input-len $adjusted_input_len \
--random-output-len $OUTPUT_LEN \
--ignore-eos \
--disable-tqdm \
--request-rate $best_request_rate \
--percentile-metrics ttft,tpot,itl,e2el \
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
--num-prompts 100 \
--random-prefix-len $prefix_len \
--port 8004 \
--profile &> "$bm_log"
else
echo "No configuration met the latency requirements. Skipping final profiling run."
fi
pkill -if vllm
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH"
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT"

View File

@ -120,7 +120,7 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/clas
### `LLM.score`
The [score][vllm.LLM.score] method outputs similarity scores between sentence pairs.
It is designed for embedding models and cross encoder models. Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems.
It is designed for embedding models and cross-encoder models. Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems.
!!! note
vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.

View File

@ -311,6 +311,8 @@ See [this page](generative_models.md) for more information on how to use generat
#### Text Generation
These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate) API. Chat/Instruct models additionally support the [`LLM.chat`](./generative_models.md#llmchat) API.
<style>
th {
white-space: nowrap;
@ -328,7 +330,7 @@ th {
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ |
@ -348,8 +350,8 @@ th {
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ |
| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ |
@ -419,7 +421,9 @@ See [this page](./pooling_models.md) for more information on how to use pooling
Since some model architectures support both generative and pooling tasks,
you should explicitly specify `--runner pooling` to ensure that the model is used in pooling mode instead of generative mode.
#### Text Embedding
#### Embedding
These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API.
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
@ -457,28 +461,10 @@ If your model is not in the above list, we will try to automatically convert the
[as_embedding_model][vllm.model_executor.models.adapters.as_embedding_model]. By default, the embeddings
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
#### Reward Modeling
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
<sup>C</sup> Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion))
\* Feature support is the same as that of the original model.
If your model is not in the above list, we will try to automatically convert the model using
[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly.
!!! important
For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
#### Classification
These models primarily support the [`LLM.classify`](./pooling_models.md#llmclassify) API.
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
@ -491,7 +477,10 @@ If your model is not in the above list, we will try to automatically convert the
If your model is not in the above list, we will try to automatically convert the model using
[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
#### Sentence Pair Scoring
#### Cross-encoder / Reranker
Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
@ -501,6 +490,7 @@ If your model is not in the above list, we will try to automatically convert the
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
\* Feature support is the same as that of the original model.
@ -526,6 +516,28 @@ If your model is not in the above list, we will try to automatically convert the
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
```
#### Reward Modeling
These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward) API.
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
<sup>C</sup> Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion))
\* Feature support is the same as that of the original model.
If your model is not in the above list, we will try to automatically convert the model using
[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly.
!!! important
For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
[](){ #supported-mm-models }
## List of Multimodal Language Models
@ -579,6 +591,8 @@ See [this page](generative_models.md) for more information on how to use generat
#### Text Generation
These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate) API. Chat/Instruct models additionally support the [`LLM.chat`](./generative_models.md#llmchat) API.
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | ✅︎ |
@ -589,8 +603,8 @@ See [this page](generative_models.md) for more information on how to use generat
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4MoeForCausalLM` | GLM-4.5 | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V-Air`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
@ -720,11 +734,9 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
See [this page](./pooling_models.md) for more information on how to use pooling models.
!!! important
Since some model architectures support both generative and pooling tasks,
you should explicitly specify `--runner pooling` to ensure that the model is used in pooling mode instead of generative mode.
#### Embedding
#### Text Embedding
These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API.
!!! note
To get the best results, you should use pooling models that are specifically trained as such.
@ -742,7 +754,10 @@ The following table lists those that are tested in vLLM.
---
#### Scoring
#### Cross-encoder / Reranker
Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------|

View File

@ -221,7 +221,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
# GLM-4v
def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "THUDM/glm-4v-9b"
model_name = "zai-org/glm-4v-9b"
engine_args = EngineArgs(
model=model_name,
@ -250,7 +250,7 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
# GLM-4.1V
def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData:
model_name = "THUDM/GLM-4.1V-9B-Thinking"
model_name = "zai-org/GLM-4.1V-9B-Thinking"
engine_args = EngineArgs(
model=model_name,

View File

@ -154,7 +154,7 @@ TEXT_GENERATION_MODELS = {
"baichuan-inc/Baichuan-7B": PPTestSettings.fast(),
"baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(),
"bigscience/bloomz-1b1": PPTestSettings.fast(),
"THUDM/chatglm3-6b": PPTestSettings.fast(),
"zai-org/chatglm3-6b": PPTestSettings.fast(),
"CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"),
"databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"),
"Deci/DeciLM-7B-instruct": PPTestSettings.fast(),
@ -224,7 +224,7 @@ MULTIMODAL_MODELS = {
"Salesforce/blip2-opt-6.7b": PPTestSettings.fast(),
"facebook/chameleon-7b": PPTestSettings.fast(),
"adept/fuyu-8b": PPTestSettings.fast(),
"THUDM/glm-4v-9b": PPTestSettings.fast(),
"zai-org/glm-4v-9b": PPTestSettings.fast(),
"OpenGVLab/InternVL2-1B": PPTestSettings.fast(),
"llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
"llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),

View File

@ -14,7 +14,7 @@ from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils import merge_async_iterators
MODEL_PATH = "THUDM/chatglm3-6b"
MODEL_PATH = "zai-org/chatglm3-6b"
LORA_RANK = 64
DEFAULT_MAX_LORAS = 4 * 3

View File

@ -6,7 +6,7 @@ from vllm.lora.request import LoRARequest
from ..utils import create_new_process_for_each_test, multi_gpu_test
MODEL_PATH = "THUDM/chatglm3-6b"
MODEL_PATH = "zai-org/chatglm3-6b"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501

View File

@ -53,7 +53,7 @@ AITER_MODEL_LIST = [
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
pytest.param(
"THUDM/chatglm3-6b", # chatglm (text-only)
"zai-org/chatglm3-6b", # chatglm (text-only)
),
pytest.param(
"meta-llama/Llama-3.2-1B-Instruct", # llama

View File

@ -355,7 +355,7 @@ VLM_TEST_SETTINGS = {
num_logprobs=10,
),
"glm4v": VLMTestInfo(
models=["THUDM/glm-4v-9b"],
models=["zai-org/glm-4v-9b"],
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts({
@ -374,7 +374,7 @@ VLM_TEST_SETTINGS = {
marks=[large_gpu_mark(min_gb=32)],
),
"glm4_1v": VLMTestInfo(
models=["THUDM/GLM-4.1V-9B-Thinking"],
models=["zai-org/GLM-4.1V-9B-Thinking"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501
img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", # noqa: E501
@ -388,7 +388,7 @@ VLM_TEST_SETTINGS = {
marks=[large_gpu_mark(min_gb=32)],
),
"glm4_1v-video": VLMTestInfo(
models=["THUDM/GLM-4.1V-9B-Thinking"],
models=["zai-org/GLM-4.1V-9B-Thinking"],
# GLM-4.1V requires video metadata to be included in the input
test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=4096,

View File

@ -271,8 +271,8 @@ def _test_processing_correctness_one(
"microsoft/Florence-2-base",
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"THUDM/glm-4v-9b",
"THUDM/GLM-4.1V-9B-Thinking",
"zai-org/glm-4v-9b",
"zai-org/GLM-4.1V-9B-Thinking",
"ibm-granite/granite-speech-3.3-2b",
"h2oai/h2ovl-mississippi-800m",
"internlm/Intern-S1",

View File

@ -9,7 +9,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["THUDM/GLM-4.1V-9B-Thinking"])
@pytest.mark.parametrize("model_id", ["zai-org/GLM-4.1V-9B-Thinking"])
@pytest.mark.parametrize("expected_toks_per_frame", [299])
@pytest.mark.parametrize("num_frames", [32, 128])
@pytest.mark.parametrize("fps, expected_grid_t", [(1, 5), (2, 10)])

View File

@ -139,8 +139,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
trust_remote_code=True),
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base",
is_available_online=False),
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"),
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
trust_remote_code=True),
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
@ -153,7 +152,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
{"1b": "bigscience/bloomz-1b1"}),
"ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b",
"ChatGLMModel": _HfExamplesInfo("zai-org/chatglm3-6b",
trust_remote_code=True,
max_transformers_version="4.48"),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("thu-coai/ShieldLM-6B-chatglm3", # noqa: E501
@ -187,8 +186,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
min_transformers_version="4.53"),
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("THUDM/GLM-4-9B-0414"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
"Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5",
min_transformers_version="4.54"), # noqa: E501
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2",
@ -380,10 +379,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
"GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking"), # noqa: E501
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501
"Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V-Air",
is_available_online=False), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",

View File

@ -10,7 +10,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_cached_tokenizer)
@pytest.mark.parametrize("model_id", ["gpt2", "THUDM/chatglm3-6b"])
@pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])
def test_cached_tokenizer(model_id: str):
reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
trust_remote_code=True)

View File

@ -17,7 +17,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN
]
# Remove flashinfer from the list if it's not available

View File

@ -109,11 +109,11 @@ def create_common_attn_metadata(
def get_attention_backend(backend_name: _Backend):
"""Set up attention backend classes for testing.
Args:
backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
vllm_config: VllmConfig instance
Returns:
Tuple of (backend_builder_class, backend_impl_class)
"""
@ -126,6 +126,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
_Backend.TRITON_ATTN_VLLM_V1:
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
_Backend.TREE_ATTN:
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
}
if backend_name not in backend_map:

View File

@ -10,7 +10,7 @@ from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.multimodal.utils import encode_image_base64
MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w4a16"
MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
SAMPLING_PARAMS = SamplingParams(temperature=0.0, top_k=1, max_tokens=128)

View File

@ -429,6 +429,33 @@ def test_zero_logprobs(vllm_model, example_prompts,
assert len(prompt_token_ids) == len(prompt_logprobs)
def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
    """Requesting ``logprobs=-1`` should yield one entry per vocabulary token.

    Args:
      example_prompts: list of example prompts (test fixture)
      monkeypatch: pytest fixture used to set ``VLLM_USE_V1`` for this test
    """
    with monkeypatch.context() as patched_env:
        patched_env.setenv("VLLM_USE_V1", "1")

        runner = VllmRunner(
            "facebook/opt-125m",
            max_logprobs=-1,
            enable_prefix_caching=False,
            # 2 other llms alive during whole session
            gpu_memory_utilization=0.15,
            max_model_len=256,
        )
        all_logprobs_params = SamplingParams(max_tokens=5, logprobs=-1)
        request_outputs = runner.llm.generate(
            example_prompts, sampling_params=all_logprobs_params)

        model_config = runner.llm.llm_engine.get_model_config()
        expected_len = model_config.get_vocab_size()

        # Every generated position must carry a logprob for the full vocab.
        for request_output in request_outputs:
            token_logprobs = request_output.outputs[0].logprobs
            assert token_logprobs is not None
            for per_position in token_logprobs:
                assert len(per_position) == expected_len
@pytest.mark.parametrize(
"logprobs_mode",
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])

View File

@ -202,7 +202,9 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(num_speculative_tokens):
@pytest.mark.parametrize("backend",
[_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
def test_propose(num_speculative_tokens, backend):
# Use GPU device
device = torch.device(current_platform.device_type)
@ -301,8 +303,7 @@ def test_propose(num_speculative_tokens):
device=device)
sampling_metadata = mock.MagicMock()
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.FLASH_ATTN_VLLM_V1)
attn_metadata_builder_cls, _ = get_attention_backend(backend)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,

View File

@ -0,0 +1,299 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import Optional
import torch
from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
class MockAttentionLayer(torch.nn.Module):
    """Minimal stand-in for an attention layer, passed as ``layer`` to a
    backend implementation's ``forward``.

    It exposes unit ``_q_scale``, ``_k_scale`` and ``_v_scale`` float32
    scalar tensors and an identity ``forward``.

    Args:
        device: Device on which the scale tensors are allocated. Defaults
            to ``"cuda"``, matching the previous hard-coded behaviour.
    """

    def __init__(self, device="cuda"):
        super().__init__()
        # Allocate the scales per instance instead of in the class body, so
        # that merely importing this module does not initialise CUDA (and
        # does not fail on hosts without a GPU).
        self._q_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
        self._k_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    def forward(self, x):
        """Return the input unchanged."""
        return x
def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: _Backend,
    spec_token_tree: Optional[str] = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    """Run one attention forward pass through the given backend.

    Args:
        q: Query tensor; its shape is unpacked as
            ``(batch_size, q_len, num_heads, dim_per_head)``.
        k: Key tensor; ``k.shape[-2]`` is taken as the number of KV heads.
        v: Value tensor, flattened the same way as ``k``.
        kv_cache: Paged KV cache. A clone is handed to the backend.
        block_table: Block table forwarded to the attention metadata.
        slot_mapping: Slot mapping forwarded to the attention metadata.
        seqlen_k: KV sequence length assigned to every request in the batch.
        backend: Attention backend to look up via ``get_attention_backend``.
        spec_token_tree: Optional speculative token tree; when given, a
            ``SpeculativeConfig`` is attached to the vLLM config.
        num_spec_tokens: Number of speculative tokens for that config.

    Returns:
        Whatever the backend implementation's ``forward`` returns for the
        flattened ``(-1, num_heads, dim_per_head)`` query.
    """
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]

    # Initialize the query and KV sequence lengths.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32)
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size, ),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_query_len = q_len
    # Every request contributes exactly q_len tokens, so the total equals
    # query_start_loc[-1]. Compute it as a plain int rather than reading a
    # 0-d device tensor.
    num_actual_tokens = batch_size * q_len
    softmax_scale = q.shape[-1]**(-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = get_attention_backend(backend)
    # All sequences share the same length, so the maximum is seqlen_k
    # itself; pass it as an int instead of max() over a device tensor.
    vllm_config = create_vllm_config(model_name=model_name,
                                     max_model_len=seqlen_k)
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree)
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.cpu(),
        num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Run forward pass and return output.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        # Clone so that the backend is not handed the caller's cache tensor.
        kv_cache=kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )
def test_tree_attn_correctness() -> None:
    """Check tree attention against per-branch flash attention.

    For each configuration, attention over the whole speculative token
    tree is computed once with the TREE_ATTN backend. Then, for every row
    of the tree mask, the tokens selected by that row are run through the
    FLASH_ATTN_VLLM_V1 backend, and the two outputs are compared at the
    selected positions.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    device = "cuda"

    # Row i of a mask marks the tree tokens that token i attends to.
    chain_mask = torch.tensor(
        [
            [1, 0, 0, 0],
            [1, 1, 0, 0],
            [1, 1, 1, 0],
            [1, 1, 1, 1],
        ],
        device=device,
        dtype=torch.int32,
    )
    tree_mask = torch.tensor(
        [
            [1, 0, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0],
            [1, 0, 1, 0, 0, 0, 0],
            [1, 1, 0, 1, 0, 0, 0],
            [1, 1, 0, 0, 1, 0, 0],
            [1, 0, 1, 0, 0, 1, 0],
            [1, 0, 1, 0, 0, 0, 1],
        ],
        device=device,
        dtype=torch.int32,
    )
    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]": chain_mask,
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": tree_mask,
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 128
    max_sequence_length = 8192
    randomize_blocks = True

    def _positions(start, length, batch, dev):
        # [batch, length] grid of token positions start .. start+length-1.
        offsets = torch.arange(0, length, device=dev, dtype=torch.int64)
        return start + offsets.repeat(batch, 1)

    for batch_size in [1, 16, 32]:
        for num_heads in [2, 4]:
            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # The query heads must split evenly across KV heads.
                    assert num_heads % num_kv_heads == 0

                    # Random q, k and v for the tree tokens.
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q_shape = (batch_size, tree_size_q, num_heads,
                               dim_per_head)
                    kv_shape = (batch_size, tree_size_q, num_kv_heads,
                                dim_per_head)
                    q = torch.randn(q_shape,
                                    device=device,
                                    dtype=torch.bfloat16)
                    k = torch.randn(kv_shape,
                                    device=device,
                                    dtype=torch.bfloat16)
                    v = torch.randn(kv_shape,
                                    device=device,
                                    dtype=torch.bfloat16)

                    # Paged KV cache and its block table.
                    assert max_sequence_length % block_size == 0
                    max_blocks_per_batch = max_sequence_length // block_size
                    kv_cache_shape = (
                        2,
                        batch_size * max_blocks_per_batch,
                        block_size,
                        num_kv_heads,
                        dim_per_head,
                    )
                    kv_cache = torch.randn(kv_cache_shape,
                                           device=q.device,
                                           dtype=torch.bfloat16)
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k /
                                                           block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=q.device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=q.device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Shuffle so that physical blocks are not contiguous.
                        shuffled = torch.randperm(block_ids.numel())
                        block_ids = block_ids[shuffled]
                    allocated = block_ids.view(-1, num_alloc_blocks_per_batch)
                    block_table[:, :num_alloc_blocks_per_batch] = allocated

                    # Slot mapping for the tree tokens' KVs.
                    tree_positions = _positions(sequence_position,
                                                tree_size_q, batch_size,
                                                q.device)
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size)

                    # Attention over the whole tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=_Backend.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)

                    # Each mask row selects one branch; the flash attention
                    # output for that branch must match the tree attention
                    # output at the same positions.
                    for q_index in range(tree_size_q):
                        # Gather the branch's q, k and v.
                        branch_mask = tree_attn_mask[q_index]
                        branch_indices = torch.nonzero(branch_mask,
                                                       as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]

                        # Slot mapping for the branch: consecutive positions.
                        branch_positions = _positions(sequence_position,
                                                      q_len, batch_size,
                                                      q.device)
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size)

                        # Flash attention over the branch.
                        flash_attn_output = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=_Backend.FLASH_ATTN_VLLM_V1,
                        ).view(batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        outputs_match = torch.allclose(
                            tree_attn_output[:, branch_indices],
                            flash_attn_output,
                            atol=7.81e-3,
                        )
                        assert outputs_match, (
                            f"outputs are not close for "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}.")
def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor,
block_size: int):
block_indices = positions // block_size
blocks = block_table.gather(dim=1, index=block_indices)
return (blocks * block_size + positions % block_size).view(-1)

View File

@ -55,6 +55,7 @@ def kernel_unified_attention_2d(
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
scale, # float32
k_scale, # float32
v_scale, # float32
@ -66,10 +67,12 @@ def kernel_unified_attention_2d(
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
@ -144,6 +147,11 @@ def kernel_unified_attention_2d(
mask=query_mask_1,
other=0.0)
# query-query attention bias
if USE_QQ_BIAS:
qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
) # shape: [BLOCK_M]
# compute the length of the longest sequence prefix spanned by any
# query token in the current q_block (q_block_local_idx)
max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
@ -223,6 +231,18 @@ def kernel_unified_attention_2d(
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
if USE_QQ_BIAS:
# compute key positions relative to query section
key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE]
# load bias only for keys that correspond to queries
is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
qq_bias = tl.load(
qq_bias_row_ptrs + key_rel_pos[None, :],
mask=is_query_key[None, :], # avoid OOB for context keys
other=0.0,
)
S += qq_bias
# compute running maximum
# m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1))
@ -275,6 +295,7 @@ def kernel_unified_attention_3d(
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
scale, # float32
k_scale, # float32
v_scale, # float32
@ -284,10 +305,12 @@ def kernel_unified_attention_3d(
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
@ -373,6 +396,11 @@ def kernel_unified_attention_3d(
mask=query_mask_1,
other=0.0)
# query-query attention bias
if USE_QQ_BIAS:
qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
) # shape: [BLOCK_M]
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
# iterate through tiles within current segment
@ -442,6 +470,18 @@ def kernel_unified_attention_3d(
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
if USE_QQ_BIAS:
# compute key positions relative to query section
key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE]
# load bias only for keys that correspond to queries
is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
qq_bias = tl.load(
qq_bias_row_ptrs + key_rel_pos[None, :],
mask=is_query_key[None, :], # avoid OOB for context keys
other=0.0,
)
S += qq_bias
# compute running maximum
# m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1))
@ -586,6 +626,7 @@ def unified_attention(
k_descale,
v_descale,
alibi_slopes=None,
qq_bias=None,
):
assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"
@ -595,6 +636,7 @@ def unified_attention(
"Block size must be at least 32 for fp8"
use_alibi_slopes = alibi_slopes is not None
use_qq_bias = qq_bias is not None
block_size = v.shape[1]
num_seqs = len(seqused_k)
@ -630,6 +672,7 @@ def unified_attention(
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
@ -641,10 +684,12 @@ def unified_attention(
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
@ -699,6 +744,7 @@ def unified_attention(
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
@ -708,10 +754,12 @@ def unified_attention(
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),

View File

@ -377,7 +377,8 @@ class ModelConfig:
max_logprobs: int = 20
"""Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes from the default for the
OpenAI Chat Completions API."""
OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
vocab_size) logprobs are allowed to be returned and it may cause OOM."""
logprobs_mode: LogprobsMode = "raw_logprobs"
"""Indicates the content returned in the logprobs and prompt_logprobs.
Supported mode:
@ -1585,7 +1586,7 @@ class ModelConfig:
"""
This method attempts to retrieve the non-default values of the
generation config for this model.
The generation config can contain information about special tokens, as
well as sampling parameters, which is why this method exists separately
from `get_diff_sampling_param`.
@ -2066,7 +2067,7 @@ class ParallelConfig:
and when data_parallel_size > 0. Enables running an AsyncLLM
and API server on a "per-node" basis where vLLM load balances
between local data parallel ranks, but an external LB balances
between vLLM nodes/replicas. Set explicitly in conjunction with
between vLLM nodes/replicas. Set explicitly in conjunction with
--data-parallel-start-rank."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
@ -3049,6 +3050,19 @@ class SpeculativeConfig:
f"num_speculative_tokens:{self.num_speculative_tokens}"
f" must be divisible by {n_predict=}")
if self.speculative_token_tree is None:
# Generate chain of tokens.
self.speculative_token_tree = str([
(i + 1) * (0, )
for i in range(self.num_speculative_tokens)
])
else:
# Sort the token tree breadth-first.
tree_choices = ast.literal_eval(
self.speculative_token_tree)
self.speculative_token_tree = str(
sorted(tree_choices, key=lambda t: (len(t), t)))
self.draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_tp(
self.target_parallel_config,

View File

@ -1454,7 +1454,6 @@ class EngineArgs:
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp.")
# No XFormers so far.
V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1",
"FLASH_ATTN",
@ -1469,6 +1468,7 @@ class EngineArgs:
"ROCM_AITER_MLA",
"TORCH_SDPA_VLLM_V1",
"FLEX_ATTENTION",
"TREE_ATTN",
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):

View File

@ -90,8 +90,17 @@ class OpenAIServingResponses(OpenAIServing):
logger.info("Using default chat sampling params from %s: %s",
source, self.default_sampling_params)
# False by default.
# If False (default), the "store" option is (silently) ignored and the
# response is not stored. If True, the response is stored in memory.
# NOTE(woosuk): This may not be intuitive for users, as the default
# behavior in OpenAI's Responses API is to store the response, but
# vLLM's default behavior is not.
self.enable_store = envs.VLLM_ENABLE_RESPONSES_API_STORE
if self.enable_store:
logger.warning_once(
"`VLLM_ENABLE_RESPONSES_API_STORE` is enabled. This may "
"cause a memory leak since we never remove responses from "
"the store.")
# HACK(woosuk): This is a hack. We should use a better store.
# FIXME: If enable_store=True, this may cause a memory leak since we
# never remove responses from the store.
@ -121,9 +130,25 @@ class OpenAIServingResponses(OpenAIServing):
if self.engine_client.errored:
raise self.engine_client.dead_error
# If store is not enabled, return an error.
if request.store and not self.enable_store:
return self._make_store_not_supported_error()
if request.background:
return self.create_error_response(
err_type="invalid_request_error",
message=(
"This vLLM engine does not support `store=True` and "
"therefore does not support the background mode. To "
"enable these features, set the environment variable "
"`VLLM_ENABLE_RESPONSES_API_STORE=1` when launching "
"the vLLM server."),
status_code=HTTPStatus.BAD_REQUEST,
)
# Disable the store option.
# NOTE(woosuk): Although returning an error is possible, we opted
# to implicitly disable store and process the request anyway, as
# we assume most users do not intend to actually store the response
# (i.e., their request's `store=True` just because it's the default
# value).
request.store = False
# Handle the previous response ID.
prev_response_id = request.previous_response_id

View File

@ -1060,7 +1060,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Enables support for the "store" option in the OpenAI Responses API.
# When set to 1, vLLM's OpenAI server will retain the input and output
# messages for those requests in memory. By default, this is disabled (0).
# messages for those requests in memory. By default, this is disabled (0),
# and the "store" option is ignored.
# NOTE/WARNING:
# 1. Messages are kept in memory only (not persisted to disk) and will be
# lost when the vLLM server shuts down.

View File

@ -1079,9 +1079,6 @@ class FusedMoE(torch.nn.Module):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
WEIGHT_SCALE_SUPPORTED = [
e.value for e in FusedMoeWeightScaleSupported
]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size_per_partition is used.
@ -1230,6 +1227,9 @@ class FusedMoE(torch.nn.Module):
loaded_weight=loaded_weight,
expert_id=expert_id)
else:
WEIGHT_SCALE_SUPPORTED = [
e.value for e in FusedMoeWeightScaleSupported
]
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return True if return_success else None

View File

@ -24,10 +24,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer,
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
@ -260,6 +262,81 @@ class ArceeModel(nn.Module):
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into this model's parameters.

    Separate ``.q_proj`` / ``.k_proj`` / ``.v_proj`` checkpoint entries are
    routed into the fused ``.qkv_proj`` parameter with the matching shard
    id. Rotary-embedding caches are skipped, KV-cache scales are handled
    through the quantization config when one is set, and missing biases or
    parameters that live on another pipeline stage are ignored.

    Returns:
        The set of parameter names that were actually loaded.
    """
    # (fused parameter suffix, checkpoint suffix, shard id)
    stacked_params_mapping = [
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
    ]
    skipped_rotary_keys = (
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    )

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        # Rotary caches are not loaded from the checkpoint.
        if any(key in name for key in skipped_rotary_keys):
            continue

        # KV-cache scale resolved through the quantization config.
        scale_name = (self.quant_config.get_cache_scale(name)
                      if self.quant_config is not None else None)
        if scale_name:
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if loaded_weight.dim() != 0:
                loaded_weight = loaded_weight[0]
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue

        if "scale" in name:
            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
            if remapped_name is None:
                continue
            name = remapped_name

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue

            # First matching projection decides the fused target; the
            # weight is consumed here whether or not it gets loaded.
            name = name.replace(weight_name, param_name)
            skip = ((name.endswith(".bias") and name not in params_dict)
                    or is_pp_missing_parameter(name, self))
            if not skip:
                param = params_dict[name]
                weight_loader = param.weight_loader  # type: ignore[attr-defined]
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
            break
        else:
            # Not a stacked projection: load it as a regular parameter.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

    return loaded_params
class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"""Arcee Model for causal language modeling, integrated with vLLM
@ -304,8 +381,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else:
# Placeholder for lm_head on non-last ranks
self.lm_head = PPMissingLayer()
# Provide a reference to the model's method for generating empty
# tensors (used in pipeline parallel schedule)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@ -316,7 +392,6 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, IntermediateTensors]:
# Forward pass through the Arcee model backbone
model_output = self.model(input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
# https://github.com/zai-org/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
import json
from collections.abc import Iterable
@ -86,10 +86,10 @@ class GLMAttention(nn.Module):
prefix=f"{prefix}.dense",
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
# https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192)
# NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
# NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
# which is equivalent to is_neox_style=True
is_neox_style = not config.original_rope
self.rotary_emb = get_rope(

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/THUDM/CogAgent
# https://github.com/zai-org/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
from collections.abc import Mapping, Sequence

View File

@ -393,7 +393,7 @@ def merge_multimodal_embeddings_from_map(
inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
placeholder map.
Note:
@ -418,17 +418,23 @@ def _merge_multimodal_embeddings(
Note:
This updates ``inputs_embeds`` in place.
"""
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
flattened = _flatten_embeddings(multimodal_embeddings)
if flattened.shape[0] != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError(
f"Attempted to assign {expr} = {flattened.shape[0]} "
f"multimodal tokens to {num_expected_tokens} placeholders")
try:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened.
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
except RuntimeError as e:
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
if flattened.shape[0] != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError(
f"Attempted to assign {expr} = {flattened.shape[0]} "
f"multimodal tokens to {num_expected_tokens} placeholders"
) from e
else:
raise ValueError("Error during masked scatter operation") from e
inputs_embeds[is_multimodal] = flattened
return inputs_embeds
@ -478,11 +484,11 @@ def merge_multimodal_embeddings(
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``.
``placeholder_token_id`` can be a list of token ids (e.g, token ids
of img_start, img_break, and img_end tokens) when needed: This means
the order of these tokens in the ``input_ids`` MUST MATCH the order of
their embeddings in ``multimodal_embeddings`` since we need to
``placeholder_token_id`` can be a list of token ids (e.g, token ids
of img_start, img_break, and img_end tokens) when needed: This means
the order of these tokens in the ``input_ids`` MUST MATCH the order of
their embeddings in ``multimodal_embeddings`` since we need to
slice-merge instead of individually scattering.
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
@ -491,9 +497,9 @@ def merge_multimodal_embeddings(
- I is image embedding token
- B is image break token
- E is image end token.
Then the image embeddings (that correspond to I's) from vision encoder
must be padded with embeddings of S, B, and E in the same order of
Then the image embeddings (that correspond to I's) from vision encoder
must be padded with embeddings of S, B, and E in the same order of
input_ids for a correct embedding merge.
Note:

View File

@ -270,6 +270,7 @@ class CudaPlatformBase(Platform):
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.")
@ -287,6 +288,9 @@ class CudaPlatformBase(Platform):
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend on V1 engine.")
return FLASH_ATTN_V1
elif selected_backend == _Backend.TREE_ATTN:
logger.info_once("Using Tree Attention backend on V1 engine.")
return TREE_ATTN_V1
from vllm.attention.selector import is_attn_backend_supported

View File

@ -62,6 +62,7 @@ class _Backend(enum.Enum):
DIFFERENTIAL_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
class PlatformEnum(enum.Enum):

View File

@ -156,6 +156,7 @@ class SamplingParams(
Note that the implementation follows the OpenAI API: The API will
always return the log probability of the sampled token, so there
may be up to `logprobs+1` elements in the response.
When set to -1, return all `vocab_size` log probabilities.
prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output.
@ -414,9 +415,10 @@ class SamplingParams(
raise ValueError(
f"min_tokens must be less than or equal to "
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
if self.logprobs is not None and self.logprobs < 0:
if (self.logprobs is not None and self.logprobs != -1
and self.logprobs < 0):
raise ValueError(
f"logprobs must be non-negative, got {self.logprobs}.")
f"logprobs must be non-negative or -1, got {self.logprobs}.")
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.")

View File

@ -118,7 +118,7 @@ MODELS_ON_S3 = [
"stabilityai/stablelm-zephyr-3b",
"state-spaces/mamba-130m-hf",
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
"THUDM/glm-4v-9b",
"zai-org/glm-4v-9b",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"TIGER-Lab/VLM2Vec-Full",
"tiiuae/falcon-40b",

View File

@ -290,20 +290,29 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
def maybe_override_with_speculators_target_model(
model: str,
tokenizer: str,
trust_remote_code: bool,
revision: Optional[str] = None) -> tuple[str, str]:
model: str,
tokenizer: str,
trust_remote_code: bool,
revision: Optional[str] = None,
**kwargs,
) -> tuple[str, str]:
"""
If running a speculators config, override running model with target model
"""
is_gguf = check_gguf_file(model)
if is_gguf:
kwargs["gguf_file"] = Path(model).name
gguf_model_repo = Path(model).parent
else:
gguf_model_repo = None
config_dict, _ = PretrainedConfig.get_config_dict(
model,
model if gguf_model_repo is None else gguf_model_repo,
revision=revision,
trust_remote_code=trust_remote_code,
token=_get_hf_token(),
**kwargs,
)
spec_config = config_dict.get("speculators_config")
spec_config = config_dict.get("speculators_config", None)
# Return the target model
if spec_config is not None:
model = tokenizer = spec_config["verifier"]["name_or_path"]

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
# https://github.com/zai-org/ChatGLM2-6B
from transformers import PretrainedConfig

View File

@ -271,7 +271,7 @@ def get_tokenizer(
}
tokenizer.add_special_tokens(special_tokens_map)
# NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324
# NOTE: We can remove this after https://github.com/zai-org/ChatGLM3/issues/1324
if type(tokenizer).__name__ in ("ChatGLMTokenizer",
"ChatGLM4Tokenizer"):
assert isinstance(tokenizer, PreTrainedTokenizer)

View File

@ -0,0 +1,452 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with TreeAttention."""
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend):
    """Attention backend descriptor for tree attention.

    Exposes the implementation, metadata and metadata-builder classes
    together with the dtypes, head sizes and KV-cache layout this backend
    accepts.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "TREE_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        return TreeAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return TreeAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        return TreeAttentionMetadataBuilder

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        # Multiples of 32 from 32 through 256.
        return list(range(32, 257, 32))

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` if ``head_size`` is not a supported size."""
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size in supported_head_sizes:
            return
        attn_type = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {supported_head_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the KV-cache shape; the leading 2 is a size-2 first axis.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        # This backend never uses cascade attention.
        return False
@dataclass
class TreeAttentionMetadata:
    """Per-step attention metadata for a batch of decodes and prefills.

    The slicing done by the two properties below treats the first
    ``num_decodes`` requests (and the first ``num_decode_tokens`` tokens)
    as decodes and everything after them as prefills.
    """

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    tree_attn_bias: Optional[torch.Tensor] = None

    # Lazily built views of this metadata; see the properties below.
    _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
    _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata restricted to the prefill requests, or None if none.

        Built on first access and cached. Query start offsets are rebased
        so that the first prefill request starts at 0. No tree attention
        bias is attached to the prefill view.
        """
        if self.num_prefills == 0:
            return None

        prefill_view = self._cached_prefill_metadata
        if prefill_view is None:
            first_prefill = self.num_decodes
            start_locs = self.query_start_loc[first_prefill:]
            query_lens = torch.diff(start_locs)
            kv_lens = self.seq_lens[first_prefill:]

            prefill_view = TreeAttentionMetadata(
                num_actual_tokens=self.num_prefill_tokens,
                max_query_len=int(query_lens.max().item()),
                query_start_loc=start_locs - start_locs[0],
                max_seq_len=int(kv_lens.max().item()),
                seq_lens=kv_lens,
                block_table=self.block_table[first_prefill:],
                slot_mapping=self.slot_mapping[self.num_decode_tokens:],
            )
            self._cached_prefill_metadata = prefill_view
        return prefill_view

    @property
    def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata restricted to the decode requests, or None if none.

        Built on first access and cached. The tree attention bias of this
        object is carried over to the decode view.
        """
        if self.num_decode_tokens == 0:
            return None

        decode_view = self._cached_decode_metadata
        if decode_view is None:
            decode_count = self.num_decodes
            # One more entry than requests: start offsets plus the end.
            start_locs = self.query_start_loc[:decode_count + 1]
            query_lens = torch.diff(start_locs)
            kv_lens = self.seq_lens[:decode_count]

            decode_view = TreeAttentionMetadata(
                num_actual_tokens=self.num_decode_tokens,
                max_query_len=int(query_lens.max().item()),
                query_start_loc=start_locs,
                max_seq_len=int(kv_lens.max().item()),
                seq_lens=kv_lens,
                block_table=self.block_table[:decode_count],
                slot_mapping=self.slot_mapping[:self.num_decode_tokens],
                tree_attn_bias=self.tree_attn_bias,
            )
            self._cached_decode_metadata = decode_view
        return decode_view
class TreeAttentionMetadataBuilder(
        AttentionMetadataBuilder[TreeAttentionMetadata]):
    """Builds TreeAttentionMetadata and owns the tree attention bias.

    The bias is computed once at construction from the speculative token
    tree in the config and is attached to every metadata object built.
    """

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Parse the speculative token tree and precompute its bias.

        Args:
            kv_cache_spec: Source of the KV-cache block size.
            layer_names: Not used by this constructor.
            vllm_config: Its ``speculative_config.speculative_token_tree``
                (a Python-literal string) describes the tree; when it is
                absent a single-node tree ``[(0, )]`` is used.
            device: Device on which the bias tensor is allocated.
        """
        self.kv_cache_spec = kv_cache_spec
        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
        spec_token_tree = (spec_config.speculative_token_tree
                           if spec_config else None)
        tree_choices: list[tuple[int, ...]] = (
            ast.literal_eval(spec_token_tree)
            if spec_token_tree is not None else [(0, )])

        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Delegate batch reordering to the shared decode/prefill helper.

        The decode threshold is the number of rows of the tree attention
        bias (tree nodes including the root).
        """
        return reorder_batch_to_split_decodes_and_prefills(
            input_batch,
            scheduler_output,
            decode_threshold=self.tree_attn_bias.shape[0])

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        """Assemble TreeAttentionMetadata from the common metadata.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility and are not read here. The decode/prefill split uses
        the current number of rows of ``self.tree_attn_bias`` as threshold.
        """
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=decode_threshold))

        num_actual_tokens = common_attn_metadata.num_actual_tokens
        q_start_loc = common_attn_metadata.query_start_loc
        max_query_len = common_attn_metadata.max_query_len
        kv_seqlens = common_attn_metadata.seq_lens
        # Read the maximum from the CPU copy of the sequence lengths.
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        block_table = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        return TreeAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=max_query_len,
            query_start_loc=q_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=kv_seqlens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        """Build metadata for one drafting pass of the draft model.

        ``self.tree_attn_bias`` is swapped for the duration of the build:
        an empty tensor for ``draft_index == 0`` (decode threshold becomes
        0), otherwise the square sub-block of the full bias starting at
        ``draft_index`` and spanning ``max_query_len`` rows/columns. The
        original bias is restored before returning, also when ``build``
        raises, so a failed build cannot leave the builder with a
        truncated bias.
        """
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias
        try:
            if draft_index == 0:
                # Use prefill for drafting at the root level.
                self.tree_attn_bias = torch.empty(0)
            else:
                # Slice the tree attention bias for drafting.
                query_len = common_attn_metadata.max_query_len
                start, end = draft_index, draft_index + query_len
                self.tree_attn_bias = orig_tree_attn_bias[
                    start:end, start:end].contiguous()

            # Build attention bias.
            return self.build(0, common_attn_metadata, fast_build=True)
        finally:
            # Reset the tree attention bias to the original value.
            self.tree_attn_bias = orig_tree_attn_bias
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
# Count the number of choices at each depth of the tree.
depth_counts = []
prev_depth = 0
for path in sorted_tree_choices:
depth = len(path)
if depth != prev_depth:
depth_counts.append(0)
depth_counts[depth - 1] += 1
prev_depth = depth
return depth_counts
def _prepare_tree_attn_bias(
sorted_tree_choices: list[tuple[int, ...]],
depth_counts: list[int],
dtype: Optional[torch.dtype],
device: Optional[torch.device],
) -> torch.Tensor:
# +1 comes from the additional root node.
tree_len = len(sorted_tree_choices) + 1
tree_attn_mask = torch.full((tree_len, tree_len),
-torch.inf,
device=device,
dtype=dtype)
# Set diagonal to all zeros. Each token should
# attend to itself.
mask_val = 0
for i in range(tree_len):
tree_attn_mask[i, i] = mask_val
# Set root to all zeros. All tokens attend to it.
tree_attn_mask[:, 0] = mask_val
# Set all ancestors to zeros.
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_tree_choice = sorted_tree_choices[start + j]
# Retrieve ancestor position.
if len(cur_tree_choice) == 1:
continue
ancestor_idx = []
for c in range(len(cur_tree_choice) - 1):
ancestor_idx.append(
sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1)
tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
start += depth_counts[i]
return tree_attn_mask
class TreeAttentionImpl(AttentionImpl):
    """Attention implementation that applies a tree bias to decode tokens.

    forward() writes the new keys/values into the paged KV cache and then
    runs ``unified_attention`` up to twice: once for the prefill part of
    the batch and once, with the tree attention bias, for the decode part.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Store the layer configuration and validate it.

        Raises:
            ValueError: If ``blocksparse_params`` is given or ``head_size``
                is not supported by TreeAttentionBackend.
            NotImplementedError: If ``attn_type`` is not
                ``AttentionType.DECODER``.

        ``use_irope`` is accepted but not read in this constructor.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "TreeAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if logits_soft_cap is None:
            # Setting logits_soft_cap to 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        # NOTE(review): presumably a (left, right) window pair where -1
        # means unbounded -- confirm against unified_attention.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)

        TreeAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TreeAttentionImpl.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with TreeAttention.

        Args:
            layer: Module providing the ``_k_scale`` / ``_v_scale`` tensors.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. When None (profiling
                run) ``output`` is returned untouched.
            output: Pre-allocated buffer the result is written into;
                required.
            output_scale: Must be None; fused output quantization is not
                supported.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: If ``output_scale`` is given.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TreeAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        # (number of requests in the whole batch, num_kv_heads); the same
        # shape is used for both the prefill and the decode call below.
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
                         key.shape[1])
        # Prefill tokens occupy [num_decode_tokens, num_actual_tokens);
        # no tree bias is passed for them.
        if prefill_meta := attn_metadata.prefill_metadata:
            unified_attention(
                q=query[num_decode_tokens:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[num_decode_tokens:num_actual_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        # Decode tokens occupy [0, num_decode_tokens); the tree attention
        # bias is passed as the query-query bias.
        if decode_meta := attn_metadata.decode_metadata:
            unified_attention(
                q=query[:num_decode_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_decode_tokens],
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_query_len,
                seqused_k=decode_meta.seq_lens,
                max_seqlen_k=decode_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                qq_bias=decode_meta.tree_attn_bias,
                window_size=self.sliding_window,
                block_table=decode_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
        return output

View File

@ -214,6 +214,26 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
return self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
def build_for_drafting(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    draft_index: int,
) -> M:
    """
    Build attention metadata for the draft model.

    This default implementation ignores ``draft_index`` and forwards to
    ``build`` with no common prefix and ``fast_build`` enabled.

    Args:
        common_attn_metadata: The common attention metadata.
        draft_index: The index of the current draft operation.
            When speculating a chain of tokens, this index refers to the
            draft attempt for the i-th token.
            For tree-based attention, this index instead refers to the
            draft attempt for the i-th level in the tree of tokens.

    Returns:
        Whatever ``build`` returns for these arguments.
    """
    drafting_metadata = self.build(
        common_attn_metadata=common_attn_metadata,
        common_prefix_len=0,
        fast_build=True,
    )
    return drafting_metadata
def use_cascade_attention(
self,
common_prefix_len: int,

View File

@ -138,7 +138,7 @@ class LogprobsProcessor:
def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
"""Pop and return all request prompt logprobs
The logprobs processor aggregates prompt chunk logprobs
over one or more prefill chunks. This method returns
all prompt logprobs at once and then forgets them.
@ -176,7 +176,8 @@ class LogprobsProcessor:
Returns:
dict[token id, Logprob]
"""
if num_logprobs == -1:
num_logprobs = len(logprobs)
# We do not need a special case for the sampled token
# being in the topk, since inserting duplicated data
# into a dictionary twice is the same as doing it once.

View File

@ -65,8 +65,11 @@ class Processor:
params: SamplingParams,
) -> None:
max_logprobs = self.model_config.max_logprobs
if max_logprobs == -1:
return
# Validate sample logprobs.
if params.logprobs and params.logprobs > max_logprobs:
if params.logprobs and (params.logprobs == -1
or params.logprobs > max_logprobs):
raise ValueError(
f"Requested sample logprobs of {params.logprobs}, "
f"which is greater than max allowed: {max_logprobs}")

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from dataclasses import replace
from typing import Optional
import numpy as np
@ -17,6 +19,8 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
@ -74,18 +78,52 @@ class EagleProposer:
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
1,
device=device,
dtype=torch.int32)
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.arange = torch.arange(
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_batch_size + 1,
device=device,
dtype=torch.int32,
)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
# Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree
self.tree_choices: list[tuple[int,
...]] = ast.literal_eval(spec_token_tree)
tree_depth = len(self.tree_choices[-1])
# Precompute per-level properties of the tree.
num_drafts_per_level = [0] * tree_depth
for node in self.tree_choices:
num_drafts_per_level[len(node) - 1] += 1
self.cu_drafts_per_level = [num_drafts_per_level[0]]
self.child_drafts_per_level = [num_drafts_per_level[0]]
for level in range(1, tree_depth):
self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] +
num_drafts_per_level[level])
self.child_drafts_per_level.append(num_drafts_per_level[level] //
num_drafts_per_level[level - 1])
# Find the first level where the tree branches off into one or more
# children.
self.first_branching_level = None
for level in range(tree_depth):
if self.cu_drafts_per_level[level] > level + 1:
self.first_branching_level = level
break
# Precompute draft position offsets in flattened tree.
self.tree_draft_pos_offsets = torch.arange(
1,
len(self.tree_choices) + 1,
device=device,
dtype=torch.int32,
).repeat(max_batch_size, 1)
def propose(
self,
# [num_tokens]
@ -120,11 +158,9 @@ class EagleProposer:
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
fast_build=True,
)
attn_metadata = self.runner.attn_metadata_builders[
0].build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
@ -167,6 +203,22 @@ class EagleProposer:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
if self.first_branching_level == 0:
# Branching has occurred at the root level. Draft using tree
# attention.
draft_token_ids_list = self.propose_tree(
tree_root_level=0,
batch_size=batch_size,
logits=logits,
positions=positions,
hidden_states=hidden_states,
common_attn_metadata=common_attn_metadata,
)
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
@ -178,16 +230,15 @@ class EagleProposer:
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Currently FlashAttention is the only backend that supports
# multi-token eagle spec decode. This is because the code below
# Currently, only FlashAttention and TreeAttention support multi-token
# eagle spec decode. This is because the code below
# makes assumptions about attn_metadata attributes available.
assert isinstance(attn_metadata, FlashAttentionMetadata)
assert isinstance(attn_metadata,
(FlashAttentionMetadata, TreeAttentionMetadata))
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
@ -196,7 +247,7 @@ class EagleProposer:
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for _ in range(self.num_speculative_tokens - 1):
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
@ -265,7 +316,20 @@ class EagleProposer:
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
# TODO(wenlong): get more than one token for tree attention
if self.first_branching_level == token_index + 1:
# Branching has occurred. The remaining tokens are drafted
# using tree attention.
draft_token_ids_list += self.propose_tree(
tree_root_level=token_index + 1,
batch_size=batch_size,
logits=logits,
positions=positions,
hidden_states=hidden_states,
common_attn_metadata=common_attn_metadata,
)
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
@ -273,6 +337,175 @@ class EagleProposer:
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def propose_tree(
    self,
    tree_root_level: int,
    batch_size: int,
    # [num_tokens, vocab_size]
    logits: torch.Tensor,
    # [num_tokens]
    positions: torch.Tensor,
    # [num_tokens, hidden_size]
    hidden_states: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
) -> list[torch.Tensor]:
    """Draft the remaining levels of the speculative token tree.

    Starting from the logits at ``tree_root_level``, the draft tokens of
    each level are sampled (argmax for a single child, top-k otherwise),
    appended to the flattened tree inputs, and the draft model is re-run
    over the accumulated tree tokens to obtain the logits for the next
    level.

    Args:
        tree_root_level: Tree level at which tree drafting starts.
        batch_size: Number of requests in the batch.
        logits: Logits used to sample the drafts of ``tree_root_level``.
        positions: Position of the last token of each request.
        hidden_states: Hidden state of the last token of each request.
        common_attn_metadata: Metadata from which the per-level attention
            metadata is derived.

    Returns:
        One tensor of draft token ids per drafted level, each viewed as
        ``[batch_size, -1]``.
    """
    tree_attn_metadata_builder = self.runner.attn_metadata_builders[0]
    assert isinstance(tree_attn_metadata_builder,
                      TreeAttentionMetadataBuilder)

    total_num_drafts = self.cu_drafts_per_level[tree_root_level]
    level_num_drafts = total_num_drafts
    # Sample a draft token for each child at the tree root level.
    num_children = self.child_drafts_per_level[tree_root_level]
    if num_children == 1:
        draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
    else:
        draft_token_ids = torch.topk(logits, num_children,
                                     dim=-1).indices.view(batch_size, -1)
    draft_token_ids_list = [draft_token_ids]
    draft_hidden_states = hidden_states.view(batch_size, 1, -1)

    # Initialize empty tensors for concatenation with the level outputs.
    tree_input_ids = torch.empty(0,
                                 device=self.input_ids.device,
                                 dtype=self.input_ids.dtype)
    tree_positions = torch.empty(0,
                                 device=self.positions.device,
                                 dtype=self.positions.dtype)
    tree_hidden_states = torch.empty(0,
                                     device=self.hidden_states.device,
                                     dtype=self.hidden_states.dtype)
    # Precompute the draft token positions.
    flattened_draft_positions = (
        positions.view(batch_size, -1) +
        self.tree_draft_pos_offsets[:batch_size, :])
    tree_depth = len(self.cu_drafts_per_level)
    for level in range(tree_root_level, tree_depth - 1):
        # Get draft positions for RoPE.
        draft_positions = positions + (level + 1)
        exceeds_max_model_len = (positions +
                                 total_num_drafts) >= self.max_model_len
        # Mask out the position ids that exceed the max model length.
        # Otherwise, we may get out-of-range error in RoPE.
        # The clamped result is always used and always shaped
        # [batch_size, 1], so that a level with a single draft is also
        # clamped and can be concatenated along dim=1 below.
        draft_positions = torch.where(
            exceeds_max_model_len,
            0,
            draft_positions,
        ).view(batch_size, -1)
        if level_num_drafts > 1:
            # Repeat the positions for each draft at this level.
            draft_positions = draft_positions.repeat_interleave(
                level_num_drafts, dim=1)

        if num_children > 1:
            # Repeat draft hidden states for each child.
            draft_hidden_states = draft_hidden_states.repeat_interleave(
                num_children, dim=1)

        # Concatenate the draft tokens, positions, and hidden states.
        tree_input_ids = torch.cat([tree_input_ids, draft_token_ids],
                                   dim=1)
        tree_positions = torch.cat([tree_positions, draft_positions],
                                   dim=1)
        tree_hidden_states = torch.cat(
            [tree_hidden_states, draft_hidden_states], dim=1)

        # Build new attention metadata for the next level of drafts.
        # This is necessary to support tree attention.
        query_len = total_num_drafts - tree_root_level
        common_attn_metadata = replace(
            common_attn_metadata,
            query_start_loc=query_len * self.arange[:batch_size + 1],
            seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
            num_actual_tokens=batch_size * query_len,
            max_query_len=query_len,
        )
        attn_metadata = tree_attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata,
            draft_index=tree_root_level + 1,
        )

        # Apply new attention metadata to all layers.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        # Consider max model length.
        attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                        self.max_model_len)
        # For the requests that exceed the max model length, we set the
        # sequence length to 1 to minimize their overheads in attention.
        attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

        # Compute the slot mapping.
        query_positions = flattened_draft_positions[:, level:level +
                                                    query_len]
        block_numbers = query_positions // self.block_size
        block_ids = attn_metadata.block_table.gather(dim=1,
                                                     index=block_numbers)
        slot_mapping = (block_ids * self.block_size +
                        query_positions % self.block_size)
        # Mask out the slot mappings that exceed the max model length.
        # Otherwise, the KV cache will be inadvertently updated with the
        # padding tokens.
        slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
        attn_metadata.slot_mapping = slot_mapping.view(-1)

        # Copy inputs to buffer for cudagraph.
        num_tokens = attn_metadata.num_actual_tokens
        input_ids = tree_input_ids.view(-1)
        self.input_ids[:num_tokens] = input_ids
        self.positions[:num_tokens] = tree_positions.view(-1)
        self.hidden_states[:num_tokens] = tree_hidden_states.view(
            num_tokens, -1)

        if self.use_cuda_graph and \
                num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_tokens)
        else:
            num_input_tokens = num_tokens
        # Run the model.
        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            last_hidden_states, hidden_states = self.model(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self.positions[:num_input_tokens],
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=None,
            )

        # Get the output hidden states for the draft tokens.
        # Only the trailing level_num_drafts entries of each request
        # belong to the level that was just appended.
        draft_hidden_states = hidden_states[:num_tokens].view(
            batch_size, query_len, -1)[:, -level_num_drafts:]
        draft_last_hidden_states = last_hidden_states[:num_tokens].view(
            batch_size, query_len, -1)[:, -level_num_drafts:]

        # Get the output logits for the draft tokens.
        logits = self.model.compute_logits(
            draft_last_hidden_states.reshape(batch_size * level_num_drafts,
                                             -1),
            None,
        )

        # Sample a draft token for each child at the next tree level.
        num_children = self.child_drafts_per_level[level + 1]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children,
                                         dim=-1).indices.view(
                                             batch_size, -1)
        draft_token_ids_list.append(draft_token_ids)

        # Update the # drafts counters for the next tree level.
        level_num_drafts = self.cu_drafts_per_level[level +
                                                    1] - total_num_drafts
        total_num_drafts = self.cu_drafts_per_level[level + 1]
    return draft_token_ids_list
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,

View File

@ -34,22 +34,22 @@ class ConstantList(Generic[T], Sequence):
self._x = x
def append(self, item):
raise Exception("Cannot append to a constant list")
raise TypeError("Cannot append to a constant list")
def extend(self, item):
raise Exception("Cannot extend a constant list")
raise TypeError("Cannot extend a constant list")
def insert(self, item):
raise Exception("Cannot insert into a constant list")
raise TypeError("Cannot insert into a constant list")
def pop(self, item):
raise Exception("Cannot pop from a constant list")
raise TypeError("Cannot pop from a constant list")
def remove(self, item):
raise Exception("Cannot remove from a constant list")
raise TypeError("Cannot remove from a constant list")
def clear(self):
raise Exception("Cannot clear a constant list")
raise TypeError("Cannot clear a constant list")
def index(self,
item: T,
@ -78,10 +78,10 @@ class ConstantList(Generic[T], Sequence):
...
def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
raise Exception("Cannot set item in a constant list")
raise TypeError("Cannot set item in a constant list")
def __delitem__(self, item):
raise Exception("Cannot delete item from a constant list")
raise TypeError("Cannot delete item from a constant list")
def __iter__(self):
return iter(self._x)

View File

@ -337,7 +337,9 @@ class InputBatch:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
self.num_logprobs[req_id] = (self.vocab_size
if sampling_params.logprobs == -1
else sampling_params.logprobs)
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[
req_id] = sampling_params.prompt_logprobs