Merge remote-tracking branch 'origin/main' into refactor-fp8-linear

This commit is contained in:
vllmellm 2025-11-01 09:59:43 +00:00
commit 8e8218ebac
85 changed files with 2179 additions and 961 deletions

View File

@ -441,7 +441,7 @@ steps:
--ignore=lora/test_llm_with_multi_loras.py \
--ignore=lora/test_olmoe_tp.py \
--ignore=lora/test_deepseekv2_tp.py \
--ignore=lora/test_gptoss.py \
--ignore=lora/test_gptoss_tp.py \
--ignore=lora/test_qwen3moe_tp.py
parallelism: 4
@ -1217,6 +1217,8 @@ steps:
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py
- pytest -v -s -x lora/test_gptoss_tp.py
- label: Weight Loading Multiple GPU Test # 33min
timeout_in_minutes: 45

View File

@ -340,6 +340,16 @@ steps:
commands:
- pytest -v -s v1/attention
- label: V1 Test attention (B200) # 10min
timeout_in_minutes: 30
gpu: b200
source_file_dependencies:
- vllm/v1/attention
- tests/v1/attention
commands:
- export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
- pytest -v -s v1/attention
- label: V1 Test others (CPU) # 5 mins
source_file_dependencies:
- vllm/
@ -417,7 +427,7 @@ steps:
--ignore=lora/test_llm_with_multi_loras.py \
--ignore=lora/test_olmoe_tp.py \
--ignore=lora/test_deepseekv2_tp.py \
--ignore=lora/test_gptoss.py \
--ignore=lora/test_gptoss_tp.py \
--ignore=lora/test_qwen3moe_tp.py
parallelism: 4
@ -1119,6 +1129,7 @@ steps:
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py
- pytest -v -s -x lora/test_gptoss_tp.py
- label: Weight Loading Multiple GPU Test # 33min

View File

@ -1429,8 +1429,6 @@ async def main() -> None:
random.seed(args.seed)
np.random.seed(args.seed)
if not os.path.exists(args.model):
raise OSError(f"Path does not exist: {args.model}")
logger.info("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(args.model)

View File

@ -4,7 +4,7 @@ This doc serves as a collection of handy tips for optimizing your vLLM on TPU wo
## Get started
Looking for setup and installation instructions? Find them [here](../getting_started/installation/google_tpu.md).
Looking for setup and installation instructions? Find them [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/).
### TPU workload sizing

View File

@ -0,0 +1,133 @@
# Batch Invariance
!!! note
Batch invariance is currently in beta. Some features are still under active development.
Track progress and planned improvements at <https://github.com/vllm-project/vllm/issues/27433>
This document shows how to enable batch invariance in vLLM. Batch invariance ensures that the output of a model is deterministic and independent of the batch size or the order of requests in a batch.
## Motivation
Batch invariance is crucial for several use cases:
- **Framework debugging**: Deterministic outputs make it easier to debug issues in the inference framework, as the same input will always produce the same output regardless of batching.
- **Model debugging**: Helps identify issues in model implementations by ensuring consistent behavior across different batch configurations.
- **Reinforcement Learning (RL)**: RL training often requires deterministic rollouts for reproducibility and stable training.
- **Large-scale inference systems**: Systems that use vLLM as a component benefit from deterministic behavior for testing, validation, and consistency guarantees.
## Hardware Requirements
Batch invariance currently requires NVIDIA GPUs with compute capability 9.0 or higher:
- **H-series**: H100, H200
- **B-series**: B100, B200
## Enabling Batch Invariance
Batch invariance can be enabled by setting the `VLLM_BATCH_INVARIANT` environment variable to `1`:
```bash
export VLLM_BATCH_INVARIANT=1
```
### Online Inference (Server Mode)
To start a vLLM server with batch invariance enabled:
```bash
VLLM_BATCH_INVARIANT=1 vllm serve meta-llama/Llama-3.1-8B-Instruct
```
Then use the OpenAI-compatible client:
```python
from openai import OpenAI
client = OpenAI(
api_key="EMPTY",
base_url="http://localhost:8000/v1",
)
# These requests will produce deterministic outputs
# regardless of batch size or order
response = client.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
prompt="The future of AI is",
max_tokens=100,
temperature=0.7,
seed=42,
)
print(response.choices[0].text)
```
### Offline Inference
For offline batch inference with batch invariance:
```python
import os
os.environ["VLLM_BATCH_INVARIANT"] = "1"
from vllm import LLM, SamplingParams
prompts = [
"The future of AI is",
"Machine learning enables",
"Deep learning models can",
]
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.95,
max_tokens=100,
seed=42,
)
llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
tensor_parallel_size=1,
)
# Outputs will be deterministic regardless of batch size
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Generated: {generated_text!r}\n")
```
## Tested Models
Batch invariance has been tested and verified on the following models:
- **DeepSeek series**: `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-V3-0324`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`
- **Qwen3 (Dense)**: `Qwen/Qwen3-1.7B`, `Qwen/Qwen3-8B`
- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct`
- **Llama 3**: `meta-llama/Llama-3.1-8B-Instruct`, `meta-llama/Llama-3.2-1B-Instruct`
Other models may also work, but these have been explicitly validated. If you encounter issues with a specific model, please report them on the [GitHub issue tracker](https://github.com/vllm-project/vllm/issues/new/choose).
## Implementation Details
When batch invariance is enabled, vLLM:
1. Uses deterministic kernel implementations for attention and other operations
2. Ensures consistent numerical behavior across different batch sizes
3. Disables certain optimizations that may introduce non-determinism (such as custom all-reduce operations in tensor parallel mode)
!!! note
Enabling batch invariance may impact performance compared to the default non-deterministic mode. This trade-off is intentional to guarantee reproducibility.
## Future Improvements
The batch invariance feature is under active development. Planned improvements include:
- Support for additional GPU architectures
- Expanded model coverage
- Performance optimizations
- Additional testing and validation
For the latest status and to contribute ideas, see the [tracking issue](https://github.com/vllm-project/vllm/issues/27433).

View File

@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
- Default: 5600
- **Required for both prefiller and decoder instances**
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use ports 5600 and 5601 on that node).
- Used for the initial NIXL handshake between the prefiller and the decoder
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication

View File

@ -2,4 +2,4 @@ nav:
- README.md
- gpu.md
- cpu.md
- google_tpu.md
- TPU: https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/

View File

@ -11,7 +11,6 @@ vLLM supports the following hardware platforms:
- [ARM AArch64](cpu.md#arm-aarch64)
- [Apple silicon](cpu.md#apple-silicon)
- [IBM Z (S390X)](cpu.md#ibm-z-s390x)
- [Google TPU](google_tpu.md)
## Hardware Plugins
@ -20,6 +19,7 @@ The backends below live **outside** the main `vllm` repository and follow the
| Accelerator | PyPI / package | Repository |
|-------------|----------------|------------|
| Google TPU | `tpu-inference` | <https://github.com/vllm-project/tpu-inference> |
| Ascend NPU | `vllm-ascend` | <https://github.com/vllm-project/vllm-ascend> |
| Intel Gaudi (HPU) | N/A, install from source | <https://github.com/vllm-project/vllm-gaudi> |
| MetaX MACA GPU | N/A, install from source | <https://github.com/MetaX-MACA/vLLM-metax> |

View File

@ -1,193 +0,0 @@
# Google TPU
Tensor Processing Units (TPUs) are Google's custom-developed application-specific
integrated circuits (ASICs) used to accelerate machine learning workloads. TPUs
are available in different versions each with different hardware specifications.
For more information about TPUs, see [TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm).
For more information on the TPU versions supported with vLLM, see:
- [TPU v6e](https://cloud.google.com/tpu/docs/v6e)
- [TPU v5e](https://cloud.google.com/tpu/docs/v5e)
- [TPU v5p](https://cloud.google.com/tpu/docs/v5p)
- [TPU v4](https://cloud.google.com/tpu/docs/v4)
These TPU versions allow you to configure the physical arrangements of the TPU
chips. This can improve throughput and networking performance. For more
information see:
- [TPU v6e topologies](https://cloud.google.com/tpu/docs/v6e#configurations)
- [TPU v5e topologies](https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config)
- [TPU v5p topologies](https://cloud.google.com/tpu/docs/v5p#tpu-v5p-config)
- [TPU v4 topologies](https://cloud.google.com/tpu/docs/v4#tpu-v4-config)
In order for you to use Cloud TPUs you need to have TPU quota granted to your
Google Cloud Platform project. TPU quotas specify how many TPUs you can use in a
GCP project and are specified in terms of TPU version, the number of TPUs you
want to use, and quota type. For more information, see [TPU quota](https://cloud.google.com/tpu/docs/quota#tpu_quota).
For TPU pricing information, see [Cloud TPU pricing](https://cloud.google.com/tpu/pricing).
You may need additional persistent storage for your TPU VMs. For more
information, see [Storage options for Cloud TPU data](https://cloud.google.com/tpu/docs/storage-options).
!!! warning
There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
## Requirements
- Google Cloud TPU VM
- TPU versions: v6e, v5e, v5p, v4
- Python: 3.11 or newer
### Provision Cloud TPUs
You can provision Cloud TPUs using the [Cloud TPU API](https://cloud.google.com/tpu/docs/reference/rest)
or the [queued resources](https://cloud.google.com/tpu/docs/queued-resources)
API (preferred). This section shows how to create TPUs using the queued resource API. For
more information about using the Cloud TPU API, see [Create a Cloud TPU using the Create Node API](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#create-node-api).
Queued resources enable you to request Cloud TPU resources in a queued manner.
When you request queued resources, the request is added to a queue maintained by
the Cloud TPU service. When the requested resource becomes available, it's
assigned to your Google Cloud project for your immediate exclusive use.
!!! note
In all of the following commands, replace the ALL CAPS parameter names with
appropriate values. See the parameter descriptions table for more information.
### Provision Cloud TPUs with GKE
For more information about using TPUs with GKE, see:
- [About TPUs in GKE](https://cloud.google.com/kubernetes-engine/docs/concepts/tpus)
- [Deploy TPU workloads in GKE Standard](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus)
- [Plan for TPUs in GKE](https://cloud.google.com/kubernetes-engine/docs/concepts/plan-tpus)
## Configure a new environment
### Provision a Cloud TPU with the queued resource API
Create a TPU v5e with 4 TPU chips:
```bash
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
--node-id TPU_NAME \
--project PROJECT_ID \
--zone ZONE \
--accelerator-type ACCELERATOR_TYPE \
--runtime-version RUNTIME_VERSION \
--service-account SERVICE_ACCOUNT
```
| Parameter name | Description |
|--------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| QUEUED_RESOURCE_ID | The user-assigned ID of the queued resource request. |
| TPU_NAME | The user-assigned name of the TPU which is created when the queued resource request is allocated. |
| PROJECT_ID | Your Google Cloud project |
| ZONE | The GCP zone where you want to create your Cloud TPU. The value you use depends on the version of TPUs you are using. For more information, see [TPU regions and zones] |
| ACCELERATOR_TYPE | The TPU version you want to use. Specify the TPU version, for example `v5litepod-4` specifies a v5e TPU with 4 cores, `v6e-1` specifies a v6e TPU with 1 core. For more information, see [TPU versions]. |
| RUNTIME_VERSION | The TPU VM runtime version to use. For example, use `v2-alpha-tpuv6e` for a VM loaded with one or more v6e TPU(s). |
| SERVICE_ACCOUNT | The email address for your service account. You can find it in the IAM Cloud Console under *Service Accounts*. For example: `tpu-service-account@<your_project_ID>.iam.gserviceaccount.com` |
Connect to your TPU VM using SSH:
```bash
gcloud compute tpus tpu-vm ssh TPU_NAME --project PROJECT_ID --zone ZONE
```
!!! note
When configuring `RUNTIME_VERSION` ("TPU software version") on GCP, ensure it matches the TPU generation you've selected by referencing the [TPU VM images] compatibility matrix. Using an incompatible version may prevent vLLM from running correctly.
[TPU versions]: https://cloud.google.com/tpu/docs/runtimes
[TPU VM images]: https://cloud.google.com/tpu/docs/runtimes
[TPU regions and zones]: https://cloud.google.com/tpu/docs/regions-zones
## Set up using Python
### Pre-built wheels
Currently, there are no pre-built TPU wheels.
### Build wheel from source
Install Miniconda:
```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
source ~/.bashrc
```
Create and activate a Conda environment for vLLM:
```bash
conda create -n vllm python=3.12 -y
conda activate vllm
```
Clone the vLLM repository and go to the vLLM directory:
```bash
git clone https://github.com/vllm-project/vllm.git && cd vllm
```
Uninstall the existing `torch` and `torch_xla` packages:
```bash
pip uninstall torch torch-xla -y
```
Install build dependencies:
```bash
pip install -r requirements/tpu.txt
sudo apt-get install --no-install-recommends --yes libopenblas-base libopenmpi-dev libomp-dev
```
Run the setup script:
```bash
VLLM_TARGET_DEVICE="tpu" python -m pip install -e .
```
## Set up using Docker
### Pre-built images
See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image, making sure to substitute the image name `vllm/vllm-openai` with `vllm/vllm-tpu`.
### Build image from source
You can use [docker/Dockerfile.tpu](../../../docker/Dockerfile.tpu) to build a Docker image with TPU support.
```bash
docker build -f docker/Dockerfile.tpu -t vllm-tpu .
```
Run the Docker image with the following command:
```bash
# Make sure to add `--privileged --net host --shm-size=16G`.
docker run --privileged --net host --shm-size=16G -it vllm-tpu
```
!!! note
Since TPU relies on XLA which requires static shapes, vLLM bucketizes the
possible input shapes and compiles an XLA graph for each shape. The
compilation time may take 20~30 minutes in the first run. However, the
compilation time reduces to ~5 minutes afterwards because the XLA graphs are
cached in the disk (in `VLLM_XLA_CACHE_PATH` or `~/.cache/vllm/xla_cache` by default).
!!! tip
If you encounter the following error:
```console
from torch._C import * # noqa: F403
ImportError: libopenblas.so.0: cannot open shared object file: No such
file or directory
```
Install OpenBLAS with the following command:
```bash
sudo apt-get install --no-install-recommends --yes libopenblas-base libopenmpi-dev libomp-dev
```

View File

@ -63,6 +63,17 @@ This guide will help you quickly get started with vLLM to perform:
rocm/vllm-dev:nightly
```
=== "Google TPU"
To run vLLM on Google TPUs, you need to install the `vllm-tpu` package.
```bash
uv pip install vllm-tpu
```
!!! note
For more detailed instructions, including Docker, installing from source, and troubleshooting, please refer to the [vLLM on TPU documentation](https://docs.vllm.ai/projects/tpu/en/latest/).
!!! note
For more detail and non-CUDA platforms, please refer [here](installation/README.md) for specific instructions on how to install vLLM.

View File

@ -7,7 +7,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.56.0
transformers >= 4.56.0, < 5
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.

View File

@ -29,7 +29,7 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
mteb>=1.38.11, <2 # required for mteb test
transformers==4.56.2
transformers==4.57.1
tokenizers==0.22.0
schemathesis>=3.39.15 # Required for openai schema test.
# quantization

View File

@ -37,7 +37,7 @@ datamodel_code_generator # required for minicpm3 test
# TODO: Use lm-eval[api]==0.4.10 once released
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
mteb[bm25s]>=1.38.11, <2 # required for mteb test
transformers==4.56.2
transformers==4.57.1
tokenizers==0.22.0
schemathesis>=3.39.15 # Required for openai schema test.
# quantization

View File

@ -1196,7 +1196,7 @@ tqdm==4.66.6
# transformers
tqdm-multiprocess==0.0.11
# via lm-eval
transformers==4.56.2
transformers==4.57.1
# via
# -r requirements/test.in
# genai-perf

View File

@ -6,6 +6,9 @@ from copy import deepcopy
from tblib import pickling_support
# Import fixture
from tests.v1.entrypoints.conftest import sample_json_schema # noqa
# ruff: noqa
# Install support for pickling exceptions so that we can nicely propagate

View File

@ -237,7 +237,7 @@ def deepseekv2_lora_files():
@pytest.fixture(scope="session")
def gptoss20b_lora_files():
return snapshot_download(repo_id="LevinZheng/gpt-oss-20b-lora-adapter")
return snapshot_download(repo_id="jeeejeee/gpt-oss-20b-lora-adapter-text2sql")
@pytest.fixture(scope="session")

View File

@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# NOTE To avoid overloading the CI pipeline, this test script will
# not be triggered on CI and is primarily intended for local testing
# and verification.
import vllm
from vllm.lora.request import LoRARequest

View File

@ -1,52 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "openai/gpt-oss-20b"
PROMPT_TEMPLATE = "<begin▁of▁sentence>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    """Generate a completion for a fixed prompt, optionally through a LoRA adapter.

    When ``lora_id`` is falsy (0), no LoRA request is attached and the base
    model is used. Returns the stripped generated text for each prompt.
    """
    request_prompts = [
        PROMPT_TEMPLATE.format(context="Who are you?"),
    ]
    params = vllm.SamplingParams(temperature=0, max_tokens=64)
    lora_request = None
    if lora_id:
        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
    results = llm.generate(request_prompts, params, lora_request=lora_request)

    # Collect and echo each completion (useful when reading CI logs).
    texts: list[str] = []
    for result in results:
        text = result.outputs[0].text.strip()
        texts.append(text)
        print(f"Prompt: {result.prompt!r}, Generated text: {text!r}")
    return texts
# FIXME: Load gpt-oss adapter
def test_gptoss20b_lora(gptoss20b_lora_files):
    """Smoke test: the gpt-oss-20b LoRA adapter steers the model's reply."""
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_loras=4,
        trust_remote_code=True,
    )
    expected_prefixes = [
        "I am an AI language model developed by OpenAI. "
        "I am here to help you with any questions or "
        "tasks you may have."
    ]
    generated = do_sample(llm, gptoss20b_lora_files, lora_id=1)
    print(generated)
    for i, prefix in enumerate(expected_prefixes):
        assert generated[i].startswith(prefix)

View File

@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "openai/gpt-oss-20b"
PROMPT_TEMPLATE = """<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-10-29
Reasoning: medium
# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>user<|message|>I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
"
##Instruction:
farm contains tables such as city, farm, farm_competition, competition_record. Table city has columns such as City_ID, Official_Name, Status, Area_km_2, Population, Census_Ranking. City_ID is the primary key.
Table farm has columns such as Farm_ID, Year, Total_Horses, Working_Horses, Total_Cattle, Oxen, Bulls, Cows, Pigs, Sheep_and_Goats. Farm_ID is the primary key.
Table farm_competition has columns such as Competition_ID, Year, Theme, Host_city_ID, Hosts. Competition_ID is the primary key.
Table competition_record has columns such as Competition_ID, Farm_ID, Rank. Competition_ID is the primary key.
The Host_city_ID of farm_competition is the foreign key of City_ID of city.
The Farm_ID of competition_record is the foreign key of Farm_ID of farm.
The Competition_ID of competition_record is the foreign key of Competition_ID of farm_competition.
###Input:
{context}
###Response:<|end|><|start|>assistant<|channel|>final<|message|>""" # noqa: E501
EXPECTED_LORA_OUTPUT = [
"SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;",
"SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;",
"SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
"SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
]
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
    """Run the text2sql prompts through *llm* with the given LoRA adapter and
    assert each completion starts with the expected SQL statement."""
    contexts = [
        "What is the average number of working horses of farms with more than 5000 total number of horses?",  # noqa: E501
        "Give the average number of working horses on farms with more than 5000 total horses.",  # noqa: E501
        "What are the maximum and minimum number of cows across all farms.",
        "Return the maximum and minimum number of cows across all farms.",
    ]
    prompts = [PROMPT_TEMPLATE.format(context=c) for c in contexts]
    params = vllm.SamplingParams(temperature=0, max_tokens=64)
    lora_request = None
    if lora_id:
        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
    completions = llm.generate(prompts, params, lora_request=lora_request)

    # Collect and echo the outputs for easier debugging of CI logs.
    texts: list[str] = []
    for completion in completions:
        text = completion.outputs[0].text.strip()
        texts.append(text)
        print(f"Prompt: {completion.prompt!r}, Generated text: {text!r}")

    for i, expected in enumerate(EXPECTED_LORA_OUTPUT):
        assert texts[i].startswith(expected)
def test_gpt_oss_lora(gptoss20b_lora_files):
    """Single-GPU LoRA test for gpt-oss-20b, exercising two adapter slots."""
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=4,
        max_lora_rank=8,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )
    # Same adapter weights loaded into two distinct LoRA slots.
    for slot in (1, 2):
        generate_and_test(llm, gptoss20b_lora_files, lora_id=slot)
@multi_gpu_test(num_gpus=2)
def test_gpt_oss_lora_tp2(gptoss20b_lora_files):
    """Tensor-parallel (TP=2) LoRA test for gpt-oss-20b, two adapter slots."""
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        tensor_parallel_size=2,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )
    for slot in (1, 2):
        generate_and_test(llm, gptoss20b_lora_files, lora_id=slot)

View File

@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# NOTE To avoid overloading the CI pipeline, this test script will not
# be triggered on CI and is primarily intended for local testing and verification.
import vllm
from vllm.lora.request import LoRARequest

View File

@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from typing import NamedTuple
import pytest
from PIL.Image import Image
from transformers import AutoProcessor
from vllm import LLM, EngineArgs, SamplingParams
from vllm.multimodal.utils import encode_image_base64
MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"
QUESTION = "What is the content of each image?"
class ModelRequestData(NamedTuple):
    """Everything needed to issue one multimodal generation request."""

    # Engine construction arguments for the target model.
    engine_args: EngineArgs
    # Fully rendered prompt string (e.g. after apply_chat_template).
    prompt: str
    # PIL images referenced by the prompt's image placeholders.
    image_data: list[Image]
    # Extra stop token ids; None means use the model defaults.
    stop_token_ids: list[int] | None = None
    # Custom chat template overriding the processor's default, if any.
    chat_template: str | None = None
    # Sampling parameters; None means the caller supplies its own.
    sampling_params: SamplingParams | None = None
@pytest.mark.core_model
@pytest.mark.parametrize("question", [QUESTION])
def test_keye_vl(
    image_assets,
    question: str,
):
    """End-to-end smoke test: Keye-VL answers a multi-image question."""
    pil_images = [asset.pil_image for asset in image_assets]
    data_urls = [
        f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in pil_images
    ]

    engine_args = EngineArgs(
        model=MODEL_NAME,
        trust_remote_code=True,
        max_model_len=8192,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(data_urls)},
    )

    # One image placeholder per URL, followed by the text question.
    content = [{"type": "image", "image": url} for url in data_urls]
    content.append({"type": "text", "text": question})
    messages = [{"role": "user", "content": content}]

    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    llm = LLM(**(asdict(engine_args) | {"seed": 42}))
    sampling_params = SamplingParams(
        temperature=0.0, max_tokens=256, stop_token_ids=None
    )
    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": pil_images},
        },
        sampling_params=sampling_params,
    )

    print("-" * 50)
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
        assert len(generated_text) > 10, (
            f"Generated text is too short: {generated_text}"
        )
        print("-" * 50)

View File

@ -186,6 +186,8 @@ def create_reduced_config(
if "text_config" in config_dict:
original_text_layers = config_dict["text_config"]["num_hidden_layers"]
config_dict["text_config"]["num_hidden_layers"] = text_layers
original_layer_types = config_dict["text_config"]["layer_types"]
config_dict["text_config"]["layer_types"] = original_layer_types[:text_layers]
print(f"Reduced text layers from {original_text_layers} to {text_layers}")
original_num_experts = config_dict["text_config"]["num_local_experts"]

View File

@ -882,27 +882,27 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
_TRANSFORMERS_BACKEND_MODELS = {
"TransformersEmbeddingModel": _HfExamplesInfo(
"BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"
"BAAI/bge-base-en-v1.5", min_transformers_version="5.0.0"
),
"TransformersForSequenceClassification": _HfExamplesInfo(
"papluca/xlm-roberta-base-language-detection",
min_transformers_version="4.57.0.dev0",
min_transformers_version="5.0.0",
),
"TransformersForCausalLM": _HfExamplesInfo(
"hmellor/Ilama-3.2-1B", trust_remote_code=True
),
"TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMoEForCausalLM": _HfExamplesInfo(
"allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"
"allenai/OLMoE-1B-7B-0924", min_transformers_version="5.0.0"
),
"TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="5.0.0"
),
"TransformersMoEEmbeddingModel": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
),
"TransformersMoEForSequenceClassification": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
),
"TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"),
"TransformersMultiModalForSequenceClassification": _HfExamplesInfo(

View File

@ -82,7 +82,7 @@ def test_models(
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
required = Version("5.0.0")
if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
pytest.skip(
"MoE models with the Transformers backend require "

View File

@ -14,16 +14,19 @@ import torch
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
try_get_attention_backend,
)
from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import _Backend, backend_to_class_str
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
@ -31,17 +34,46 @@ BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA,
_Backend.FLASHMLA,
_Backend.FLASH_ATTN_MLA,
_Backend.FLASHINFER_MLA,
_Backend.TRITON_MLA,
]
# Remove CUTLASS_MLA from the list if not using sm100
# Remove sm100 backends from the list if not using sm100
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)
# Remove FLASH_ATTN_MLA from the list if not supported
if not flash_attn_supports_mla():
BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)
# Remove FLASHMLA from the list if not supported
if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
SPEC_DECODE_BACKENDS = []
for backend in BACKENDS_TO_TEST:
builder_cls, _ = try_get_attention_backend(backend)
query_len_support = getattr(
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
)
if query_len_support != QueryLenSupport.SINGLE_ONLY:
SPEC_DECODE_BACKENDS.append(backend)
BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST:
backend_class_str = backend_to_class_str(backend)
backend_class = resolve_obj_by_qualname(backend_class_str)
supported_sizes = backend_class.get_supported_kernel_block_size()
if supported_sizes:
default_size = supported_sizes[0]
block_size = (
default_size if isinstance(default_size, int) else default_size.base
)
else:
block_size = 16
BACKEND_BLOCK_SIZES[backend] = block_size
torch.manual_seed(42)
@ -236,6 +268,26 @@ class MockAttentionLayer:
self._q_scale = torch.tensor(1.0, device=device)
self._k_scale = torch.tensor(1.0, device=device)
self._v_scale = torch.tensor(1.0, device=device)
self._prob_scale = torch.tensor(1.0, device=device)
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
def forward(self, *_args, **_kwargs):
raise NotImplementedError
class MockMLAAttentionLayer(AttentionLayerBase):
"""A mock MLA attention layer for populating static_forward_context."""
def __init__(self, impl):
self.impl = impl
def get_attn_backend(self):
raise NotImplementedError
def get_kv_cache_spec(self, vllm_config):
raise NotImplementedError
def run_attention_backend(
@ -262,13 +314,6 @@ def run_attention_backend(
# Set the current vllm config so that get_current_vllm_config() works
# in the backend implementations
with set_current_vllm_config(vllm_config):
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
# Instantiate MLA implementation
num_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
@ -302,6 +347,19 @@ def run_attention_backend(
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
impl.process_weights_after_loading(act_dtype)
# Populate static_forward_context with mock attention layers
for layer_name in layer_names:
vllm_config.compilation_config.static_forward_context[layer_name] = (
MockMLAAttentionLayer(impl)
)
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
# Create mock layer and output buffer
mock_layer = MockAttentionLayer(device)
num_tokens = query.shape[0]
@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
"""
from vllm.v1.attention.backends.mla.common import QueryLenSupport
batch_spec = BATCH_SPECS[batch_spec_name]
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
block_size = 16
unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
default_block_size = unique_block_sizes[0]
required_blocks = sum(
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
(seq_len + default_block_size - 1) // default_block_size
for seq_len in batch_spec.seq_lens
)
# Add 1 for null block at index 0, and some buffer
num_gpu_blocks = required_blocks + 1 + 100
@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
model_name=model,
max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=num_gpu_blocks,
block_size=block_size,
block_size=default_block_size,
)
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device = torch.device("cuda:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
# 1. Setup
batch_size = batch_spec.batch_size
seq_lens = batch_spec.seq_lens
@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
)
head_size = vllm_config.model_config.get_head_size()
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
block_size = vllm_config.cache_config.block_size
kv_lora_rank = 512
qk_rope_head_dim = 64
qk_nope_head_dim = 128
@ -598,33 +652,83 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
)
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)
# Create metadata using original batch spec
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device
)
# 3. Create metadata and KV caches for each block size
# Group backends by block size and test each group
metadata_per_block_size = {}
kv_cache_per_block_size = {}
# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache = create_and_prepopulate_kv_cache(
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=block_size,
head_size=head_size,
dtype=dtype,
device=device,
num_blocks=vllm_config.cache_config.num_gpu_blocks,
common_attn_metadata=common_attn_metadata,
randomize_blocks=True,
)
for block_size in unique_block_sizes:
# Create metadata for this block size
common_attn_metadata = create_common_attn_metadata(
batch_spec, block_size, device
)
# Pad block table to meet requirement:
# block_num % (128 / block_size) == 0
required_divisor = int(128 / block_size)
current_block_num = common_attn_metadata.block_table_tensor.shape[1]
if current_block_num % required_divisor != 0:
# Pad to next multiple of required_divisor
padded_block_num = (
(current_block_num + required_divisor - 1) // required_divisor
) * required_divisor
padding_cols = padded_block_num - current_block_num
padding = torch.zeros(
(common_attn_metadata.block_table_tensor.shape[0], padding_cols),
dtype=torch.int32,
device=device,
)
common_attn_metadata.block_table_tensor = torch.cat(
[common_attn_metadata.block_table_tensor, padding], dim=1
)
metadata_per_block_size[block_size] = common_attn_metadata
# Create KV cache for this block size
required_blocks_for_size = sum(
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
)
num_blocks_for_size = required_blocks_for_size + 1 + 100
kv_cache = create_and_prepopulate_kv_cache(
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=block_size,
head_size=head_size,
dtype=dtype,
device=device,
num_blocks=num_blocks_for_size,
common_attn_metadata=common_attn_metadata,
randomize_blocks=True,
)
kv_cache_per_block_size[block_size] = kv_cache
# 4. Run vLLM backends and compare
failures = []
for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
# Skip backends that don't support spec decode for spec decode tests
if is_spec_decode_test and backend_name not in spec_decode_backends:
if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
continue
# Get the appropriate block_size, metadata, and cache for this backend
block_size = BACKEND_BLOCK_SIZES[backend_name]
common_attn_metadata = metadata_per_block_size[block_size]
kv_cache = kv_cache_per_block_size[block_size]
# Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec = FullAttentionSpec(
block_size=block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
),
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(),
)
backend_output = run_attention_backend(
backend_name,
kv_cache_spec,
backend_kv_cache_spec,
["placeholder"],
vllm_config,
device,
@ -644,32 +748,48 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
expected_output = sdpa_outputs[backend_name]
# Check shape and dtype consistency
assert backend_output.shape == expected_output.shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {expected_output.shape}"
)
assert backend_output.dtype == expected_output.dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {expected_output.dtype}"
)
try:
assert backend_output.shape == expected_output.shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {expected_output.shape}"
)
assert backend_output.dtype == expected_output.dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {expected_output.dtype}"
)
assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values"
)
assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values"
)
# Check numerical similarity
rtol = 1e-2
atol = 5e-1
# Check numerical similarity
rtol = 1e-2
atol = 5e-1
max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
max_rel_diff = torch.max(
torch.abs(backend_output - expected_output) / torch.abs(expected_output)
).item()
all_close = torch.allclose(
backend_output, expected_output, rtol=rtol, atol=atol
)
max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
max_rel_diff = torch.max(
torch.abs(backend_output - expected_output) / torch.abs(expected_output)
).item()
all_close = torch.allclose(
backend_output, expected_output, rtol=rtol, atol=atol
)
assert all_close, (
f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
)
assert all_close, (
f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
)
except AssertionError as e:
failures.append(str(e))
# Report all failures at once
if failures:
# Create a summary for the single-line failure message
backend_names = []
for f in failures:
if "[_Backend." in f:
backend_name = f.split("[")[1].split("]")[0]
backend_names.append(backend_name)
summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
detailed_msg = "\n".join(failures)
pytest.fail(f"{summary}\n{detailed_msg}")

View File

@ -285,7 +285,17 @@ full_cg_backend_configs = {
name="CutlassMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS": "1", # TODO: remove this when hang issue is fixed
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(10, 0),
),
# FlashInfer MLA on Blackwell
"FlashInferMLA": BackendConfig(
name="FlashInferMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",

View File

@ -337,8 +337,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
@ -385,8 +383,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
@ -431,8 +427,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
@ -472,8 +466,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
@ -1988,7 +1980,6 @@ def test_schedule_skip_tokenizer_init():
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.grammar_bitmask is None
def test_schedule_skip_tokenizer_init_structured_output_request():

View File

@ -7,6 +7,7 @@ import torch._dynamo.config as dynamo_config
from vllm import SamplingParams
from vllm.logprobs import Logprob
from vllm.sampling_params import StructuredOutputsParams
from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal
@ -15,9 +16,12 @@ MODEL = "Qwen/Qwen3-0.6B"
@dynamo_config.patch(cache_size_limit=16)
def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
def test_preempt_and_async_scheduling_e2e(
sample_json_schema, monkeypatch: pytest.MonkeyPatch
):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor, and various sampling parameters."""
uni/multiproc executor, and various sampling parameters
including structured outputs."""
first_prompt = (
"The following numbers of the sequence "
@ -35,6 +39,12 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
dict(bad_words=["the", " the"]),
dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0),
dict(structured_outputs=StructuredOutputsParams(json=sample_json_schema)),
dict(
structured_outputs=StructuredOutputsParams(json=sample_json_schema),
logprobs=2,
presence_penalty=-1.0,
),
]
default_params = dict(

View File

@ -248,7 +248,7 @@ def test_engine_core_concurrent_batches():
self,
scheduler_output,
non_block=False,
) -> Future[ModelRunnerOutput]:
) -> Future[ModelRunnerOutput | None]:
"""Make execute_model non-blocking."""
# DummyExecutor used only for testing async case.
@ -263,6 +263,23 @@ def test_engine_core_concurrent_batches():
# Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute)
def sample_tokens(
self, grammar_output, non_block=False
) -> Future[ModelRunnerOutput]:
"""Make sample_tokens non-blocking."""
# DummyExecutor used only for testing async case.
assert non_block
def _execute():
output = self.collective_rpc("sample_tokens", args=(grammar_output,))
# Make a copy because output[0] may be reused
# by the next batch.
return copy.deepcopy(output[0])
# Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute)
@property
def max_concurrent_batches(self) -> int:
return 2

View File

@ -31,7 +31,9 @@ class CustomMultiprocExecutor(MultiprocExecutor):
# Drop marker to show that this was run
with open(".marker", "w"):
...
return super().collective_rpc(method, timeout, args, kwargs)
return super().collective_rpc(
method, timeout, args, kwargs, non_block, unique_reply_rank
)
CustomMultiprocExecutorAsync = CustomMultiprocExecutor

View File

@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for KV cache offloading configuration."""
import pytest
from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
pytestmark = pytest.mark.cpu_test
@pytest.mark.parametrize(
"kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
[
("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
# bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30) / 4),
("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
# size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
(None, None, 1, 1, None, None),
],
)
def test_kv_connector(
kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
):
kv_transfer_config = (
KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
if expected_backend is not None
else None
)
vllm_config = VllmConfig(
cache_config=CacheConfig(
kv_offloading_backend=kv_offloading_backend,
kv_offloading_size=kv_offloading_size,
),
kv_transfer_config=kv_transfer_config,
parallel_config=ParallelConfig(
tensor_parallel_size=tp, pipeline_parallel_size=pp
),
)
# No KV transfer config expected
if expected_backend is None:
assert vllm_config.kv_transfer_config is expected_backend
return
kv_transfer_config = vllm_config.kv_transfer_config
kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config
assert kv_transfer_config.kv_connector == expected_backend
assert kv_transfer_config.kv_role == "kv_both"
if kv_offloading_backend == "native":
assert kv_connector_extra_config["kv_bytes_per_rank"] == expected_bytes
assert kv_connector_extra_config["num_cpu_blocks"] == 0
# Existing config should be preserved
assert kv_connector_extra_config["existing_key"] == "existing_value"
elif kv_offloading_backend == "lmcache":
assert kv_connector_extra_config["lmcache.local_cpu"] is True
assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes
# Existing config should be replaced
assert "existing_key" not in kv_connector_extra_config

View File

@ -26,8 +26,6 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
kv_connector_metadata=SharedStorageConnectorMetadata(),
)

View File

@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlAgentMetadata,
NixlConnector,
NixlConnectorMetadata,
NixlConnectorScheduler,
NixlConnectorWorker,
NixlKVConnectorStats,
)
@ -283,6 +284,92 @@ def test_prompt_less_than_block_size():
assert len(scheduler_output.scheduled_new_reqs) == 0
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_kv_transfer_handshake(dist_init):
"""Unit test for basic NixlConnector interface functionality."""
# Test setup, we creates a scheduler that contains a NixlConnector
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
# all workers of the instance.
vllm_config = create_vllm_config()
# in case the test runs on non-GPU machine
vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
scheduler = create_scheduler(vllm_config)
# Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
prefill_connector.register_kv_caches(kv_caches)
# Simulate EngineCore initialization that would
# gather connector metadata from all workers, the scheduler connector
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
# where the first key is the dp_rank, the second key is the tp_rank.
metadata = {0: prefill_connector.get_handshake_metadata()}
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata(metadata)
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
)
assert delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
with patch.object(
decode_connector.connector_worker, "add_remote_agent"
) as mock_add_remote_agent:
mock_add_remote_agent.return_type = "remote_agent"
decode_connector.connector_worker._nixl_handshake(
kv_connector_metadata["remote_host"],
kv_connector_metadata["remote_port"],
kv_connector_metadata["tp_size"],
kv_connector_metadata["remote_engine_id"],
)
received_metadata = mock_add_remote_agent.call_args.args
assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size
assert metadata[0] == received_metadata[0]
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown()
class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"
@ -313,6 +400,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
@ -559,6 +647,7 @@ class TestNixlHandshake:
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
@ -611,6 +700,7 @@ class TestNixlHandshake:
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
# prefill TP=1, decode TP=2, remote block_lens is double to local
block_lens=[i * 2 for i in worker.block_len_per_layer],
@ -891,9 +981,7 @@ def test_scheduler_kv_connector_stats_aggregation():
scheduled_encoder_inputs={},
num_common_prefix_blocks=[0],
finished_req_ids=set(),
free_encoder_mm_hashes=set(),
structured_output_request_ids={},
grammar_bitmask=None,
free_encoder_mm_hashes=[],
)
engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output)
@ -1005,6 +1093,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared!
assert "0" not in req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port
llm.llm_engine.engine_core.shutdown()
def test_register_kv_caches(dist_init):
@ -1177,13 +1267,15 @@ def test_shutdown_cleans_up_resources(dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()
scheduler = NixlConnectorScheduler(
vllm_config, vllm_config.kv_transfer_config.engine_id
)
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper
with (
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
@ -1204,8 +1296,12 @@ def test_shutdown_cleans_up_resources(dist_init):
worker.shutdown()
mock_exec.shutdown.assert_called_with(wait=False)
mock_event.set.assert_called_once()
mock_listener.join.assert_called_once_with(timeout=1.0)
# Same sequence on scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
mock_listener.join.assert_called_once()
mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2

View File

@ -92,8 +92,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -171,8 +169,6 @@ def test_update_states_request_finished(model_runner):
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
@ -201,8 +197,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
@ -230,8 +224,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
@ -261,8 +253,6 @@ def test_update_states_no_changes(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
@ -296,8 +286,6 @@ def test_update_states_request_unscheduled(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)

View File

@ -6,6 +6,7 @@ import pytest
import torch
from vllm.attention import Attention
from vllm.attention.backends.abstract import MultipleOf
from vllm.config import (
CacheConfig,
ModelConfig,
@ -34,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import AttentionGroup
BLOCK_SIZE = 16
NUM_BLOCKS = 10
@ -150,8 +152,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -181,6 +181,57 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
).all()
def _make_mock_backend_for_kernel_block_size(
supported_sizes: list[int | MultipleOf],
):
class _MockBackend:
@staticmethod
def get_supported_kernel_block_size():
return supported_sizes
return _MockBackend()
def _make_kv_cache_spec() -> FullAttentionSpec:
return FullAttentionSpec(block_size=1, num_kv_heads=1, head_size=1, dtype="float16")
def test_select_common_block_size_prefers_manager_block_size():
backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
selected_size = GPUModelRunner.select_common_block_size(128, attn_groups)
assert selected_size == 128
def test_select_common_block_size_uses_largest_shared_int():
backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
selected_size = GPUModelRunner.select_common_block_size(256, attn_groups)
assert selected_size == 64
def test_select_common_block_size_no_valid_option():
backend_a = _make_mock_backend_for_kernel_block_size([64])
backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
with pytest.raises(ValueError):
GPUModelRunner.select_common_block_size(48, attn_groups)
def test_update_states_new_request(model_runner, dist_init):
req_id = "req_0"
@ -216,8 +267,6 @@ def test_update_states_request_finished(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
@ -248,8 +297,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
@ -277,8 +324,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
@ -370,8 +415,6 @@ def test_update_states_no_changes(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
@ -407,8 +450,6 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner._update_states(scheduler_output)

View File

@ -270,21 +270,23 @@ class ipex_ops:
@staticmethod
def flash_attn_varlen_func(
out: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
seqused_k: torch.Tensor, # we don't support this in ipex kernel
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
causal: bool,
block_table: torch.Tensor,
alibi_slopes: torch.Tensor | None,
softmax_scale: float | None = None,
causal: bool = False,
out: torch.Tensor | None = None,
block_table: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
window_size: list[int] | None = None,
softcap: float | None = 0.0,
seqused_k: torch.Tensor | None = None,
cu_seqlens_k: torch.Tensor | None = None,
# passed in qwen vl
dropout_p: float = 0.0,
# The following parameters are not used in ipex kernel currently,
# we keep API compatible to CUDA's.
scheduler_metadata=None,
@ -295,31 +297,63 @@ class ipex_ops:
num_splits=0,
s_aux: torch.Tensor | None = None,
):
if out is None:
out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
real_window_size: tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
q.contiguous(),
k,
v,
cu_seqlens_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
block_table,
alibi_slopes,
softcap=softcap,
window_size_left=real_window_size[0],
window_size_right=real_window_size[1],
k_scale=1.0,
v_scale=1.0,
)
if block_table is None:
assert cu_seqlens_k is not None, (
"cu_seqlens_k can't be None when calling varlen_attention."
)
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
ipex_ops.varlen_attention(
q.contiguous(),
k.contiguous(),
v.contiguous(),
out,
cu_seqlens_q,
cu_seqlens_k,
None,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale,
False,
causal,
False,
None,
real_window_size[0],
real_window_size[1],
-1,
)
return out
else:
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
q.contiguous(),
k,
v,
cu_seqlens_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
block_table,
alibi_slopes,
sink=s_aux,
softcap=softcap,
window_size_left=real_window_size[0],
window_size_right=real_window_size[1],
k_scale=1.0,
v_scale=1.0,
)
@staticmethod
def get_scheduler_metadata(

View File

@ -123,6 +123,11 @@ def maybe_get_vit_flash_attn_backend(
):
attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True
elif current_platform.is_xpu():
assert attn_backend == _Backend.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend."
)
use_upstream_fa = False
else:
return _Backend.TORCH_SDPA, None
@ -133,7 +138,7 @@ def maybe_get_vit_flash_attn_backend(
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
@ -521,22 +526,18 @@ class MultiHeadAttention(nn.Module):
# If vllm native fa is selected, we use it directly.
use_upstream_fa = False
if current_platform.is_xpu():
# currently, only torch_sdpa is supported on xpu
self.attn_backend = _Backend.TORCH_SDPA
else:
self.attn_backend = (
backend
if backend
in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.PALLAS,
_Backend.ROCM_AITER_FA,
_Backend.FLASH_ATTN,
}
else _Backend.TORCH_SDPA
)
self.attn_backend = (
backend
if backend
in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.PALLAS,
_Backend.ROCM_AITER_FA,
_Backend.FLASH_ATTN,
}
else _Backend.TORCH_SDPA
)
self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(

View File

@ -70,7 +70,7 @@ def flash_attn_maxseqlen_wrapper(
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,

View File

@ -24,6 +24,7 @@ BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"]
@config
@ -128,6 +129,17 @@ class CacheConfig:
gpu_memory_utilization. Note that kv_cache_memory_bytes
(when not-None) ignores gpu_memory_utilization"""
kv_offloading_size: float | None = None
"""Size of the KV cache offloading buffer in GiB. When TP > 1, this is
the total buffer size summed across all TP ranks. By default, this is set
to None, which means no KV offloading is enabled. When set with
kv_offloading_backend, vLLM will enable KV cache offloading to CPU"""
kv_offloading_backend: KVOffloadingBackend | None = None
"""The backend to use for KV cache offloading. Supported backends include
'native' (vLLM native CPU offloading), 'lmcache' This option must be used
together with kv_offloading_size."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from dataclasses import InitVar, field
from collections.abc import Callable
from dataclasses import InitVar
from typing import Any, Literal
from pydantic import SkipValidation, model_validator
from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
@ -31,28 +32,28 @@ class SchedulerConfig:
runner_type: RunnerType = "generate"
"""The runner type to launch for the model."""
max_num_batched_tokens: SkipValidation[int] = None # type: ignore
max_num_batched_tokens: int = Field(default=None, ge=1)
"""Maximum number of tokens to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
max_num_seqs: SkipValidation[int] = None # type: ignore
max_num_seqs: int = Field(default=None, ge=1)
"""Maximum number of sequences to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
max_model_len: SkipValidation[int] = None # type: ignore
max_model_len: int = Field(default=None, ge=1)
"""Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
max_num_partial_prefills: int = 1
max_num_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of sequences that can be
partially prefilled concurrently."""
max_long_partial_prefills: int = 1
max_long_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of prompts longer than
long_prefill_token_threshold that will be prefilled concurrently. Setting
this less than max_num_partial_prefills will allow shorter prompts to jump
@ -62,7 +63,7 @@ class SchedulerConfig:
"""For chunked prefill, a request is considered long if the prompt is
longer than this number of tokens."""
num_lookahead_slots: int = 0
num_lookahead_slots: int = Field(default=0, ge=0)
"""The number of slots to allocate per sequence per
step, beyond the known token ids. This is used in speculative
decoding to store KV activations of tokens which may or may not be
@ -71,7 +72,7 @@ class SchedulerConfig:
NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then."""
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
enable_chunked_prefill: bool = Field(default=None)
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
@ -86,14 +87,14 @@ class SchedulerConfig:
"""
# TODO (ywang96): Make this configurable.
max_num_encoder_input_tokens: int = field(init=False)
max_num_encoder_input_tokens: int = Field(init=False)
"""Multimodal encoder compute budget, only used in V1.
NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger."""
# TODO (ywang96): Make this configurable.
encoder_cache_size: int = field(init=False)
encoder_cache_size: int = Field(init=False)
"""Multimodal encoder cache size, only used in V1.
NOTE: This is not currently configurable. It will be overridden by
@ -106,7 +107,7 @@ class SchedulerConfig:
- "priority" means requests are handled based on given priority (lower
value means earlier handling) and time of arrival deciding any ties)."""
chunked_prefill_enabled: bool = field(init=False)
chunked_prefill_enabled: bool = Field(init=False)
"""True if chunked prefill is enabled."""
disable_chunked_mm_input: bool = False
@ -155,6 +156,20 @@ class SchedulerConfig:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator(
"max_num_batched_tokens",
"max_num_seqs",
"max_model_len",
"enable_chunked_prefill",
mode="wrap",
)
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
if value is None:
return value
return handler(value)
def __post_init__(self, is_encoder_decoder: bool) -> None:
if self.max_model_len is None:
self.max_model_len = 8192
@ -260,19 +275,7 @@ class SchedulerConfig:
self.max_num_seqs * self.max_model_len,
)
if self.num_lookahead_slots < 0:
raise ValueError(
"num_lookahead_slots "
f"({self.num_lookahead_slots}) must be greater than or "
"equal to 0."
)
if self.max_num_partial_prefills < 1:
raise ValueError(
f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
"must be greater than or equal to 1."
)
elif self.max_num_partial_prefills > 1:
if self.max_num_partial_prefills > 1:
if not self.chunked_prefill_enabled:
raise ValueError(
"Chunked prefill must be enabled to set "
@ -286,13 +289,10 @@ class SchedulerConfig:
f"than the max_model_len ({self.max_model_len})."
)
if (self.max_long_partial_prefills < 1) or (
self.max_long_partial_prefills > self.max_num_partial_prefills
):
if self.max_long_partial_prefills > self.max_num_partial_prefills:
raise ValueError(
f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
"must be greater than or equal to 1 and less than or equal to "
f"max_num_partial_prefills ({self.max_num_partial_prefills})."
f"{self.max_long_partial_prefills=} must be less than or equal to "
f"{self.max_num_partial_prefills=}."
)
return self

View File

@ -78,10 +78,6 @@ class SpeculativeConfig:
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
disable_logprobs: bool = True
"""If set to True, token log probabilities are not returned during
speculative decoding. If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams."""
# Draft model configuration
quantization: me_quant.QuantizationMethods | None = None
@ -126,12 +122,6 @@ class SpeculativeConfig:
"""The configuration of the target model."""
target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
"""The parallel configuration for the target model."""
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
"""Whether vLLM is configured to use chunked prefill or not. Used for
raising an error since it's not yet compatible with speculative decode."""
disable_log_stats: SkipValidation[bool] = None # type: ignore
"""Whether to disable the periodic printing of stage times in speculative
decoding."""
# params generated in the post-init stage
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore

View File

@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import Any, Literal
from typing import Any, Literal, Self
from pydantic import model_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
@ -56,7 +57,8 @@ class StructuredOutputsConfig:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):
@model_validator(mode="after")
def _validate_structured_output_config(self) -> Self:
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
raise ValueError(
"disable_any_whitespace is only supported for "
@ -67,3 +69,4 @@ class StructuredOutputsConfig:
"disable_additional_properties is only supported "
"for the guidance backend."
)
return self

View File

@ -289,6 +289,48 @@ class VllmConfig:
return replace(self, model_config=model_config)
def _post_init_kv_transfer_config(self) -> None:
"""Update KVTransferConfig based on top-level configs in VllmConfig.
Right now, this function reads the offloading settings from
CacheConfig and configures the KVTransferConfig accordingly.
"""
if (kv_offloading_backend := self.cache_config.kv_offloading_backend) is None:
return
# If no KVTransferConfig is provided, create a default one.
if self.kv_transfer_config is None:
self.kv_transfer_config = KVTransferConfig()
if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
raise ValueError(
"You must set kv_offloading_size when kv_offloading_backend is set."
)
num_kv_ranks = (
self.parallel_config.tensor_parallel_size
* self.parallel_config.pipeline_parallel_size
)
if kv_offloading_backend == "native":
self.kv_transfer_config.kv_connector = "OffloadingConnector"
kv_bytes_per_rank = kv_offloading_size * (1 << 30) / num_kv_ranks
# NOTE(ApostaC): the actual calculation for num_cpu_blocks should be
# done after the model's KV cache is initialized
self.kv_transfer_config.kv_connector_extra_config.update(
{"kv_bytes_per_rank": kv_bytes_per_rank, "num_cpu_blocks": 0}
)
elif kv_offloading_backend == "lmcache":
self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
kv_gb_per_rank = kv_offloading_size / num_kv_ranks
self.kv_transfer_config.kv_connector_extra_config = {
"lmcache.local_cpu": True,
"lmcache.max_local_cpu_size": kv_gb_per_rank,
}
# This is the same for all backends
self.kv_transfer_config.kv_role = "kv_both"
def __post_init__(self):
"""Verify configs are valid & consistent with each other."""
@ -646,6 +688,9 @@ class VllmConfig:
if "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8")
# Handle the KV connector configs
self._post_init_kv_transfer_config()
def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when
# enable sequence parallelism

View File

@ -6,7 +6,7 @@ KV cache helper for store.
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, Literal
import torch
@ -138,8 +138,11 @@ class KVOutputAggregator:
return cls(connector.get_finished_count() or world_size)
def aggregate(
self, outputs: list[ModelRunnerOutput], output_rank: int = 0
) -> ModelRunnerOutput:
self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
) -> ModelRunnerOutput | None:
if not outputs[output_rank]:
return None
# Aggregate kv_connector_output from all workers
def update_finished_set(
@ -161,6 +164,7 @@ class KVOutputAggregator:
aggregated_kv_connector_stats = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
assert model_runner_output is not None
kv_output = model_runner_output.kv_connector_output
if not kv_output:
continue
@ -204,6 +208,7 @@ class KVOutputAggregator:
# select output of the worker specified by output_rank
output = outputs[output_rank]
assert output is not None
output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
@ -215,13 +220,16 @@ class KVOutputAggregator:
return output
def async_aggregate(
self, output_futures: Sequence[Future[ModelRunnerOutput]], output_rank: int = 0
) -> Future[ModelRunnerOutput]:
self,
output_futures: Sequence[Future[ModelRunnerOutput | None]],
output_rank: int = 0,
) -> Future[ModelRunnerOutput | None]:
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future: Future[ModelRunnerOutput] = Future()
result_future: Future[ModelRunnerOutput | None] = Future()
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
remaining = len(output_futures)
def make_callback(idx):
def callback(fut):
@ -236,12 +244,10 @@ class KVOutputAggregator:
result_future.set_exception(e)
# this check assumes io_thread_pool uses a single thread
if all(outputs):
result_future.set_result(
self.aggregate(
cast(list[ModelRunnerOutput], outputs), output_rank
)
)
nonlocal remaining
remaining -= 1
if not remaining:
result_future.set_result(self.aggregate(outputs, output_rank))
return callback

View File

@ -122,6 +122,15 @@ class KVConnectorRole(enum.Enum):
WORKER = 1
class KVConnectorHandshakeMetadata(ABC): # noqa: B024
"""
Metadata used for out of band connector handshake between
P/D workers. This needs to serializeable.
"""
pass
class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
@ -320,6 +329,18 @@ class KVConnectorBase_V1(ABC):
"""
return None
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.
This metadata is used for out-of-band connector handshake
between P/D workers.
Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
None if no handshake metadata is available.
"""
return None
# ==============================
# Scheduler-side methods
# ==============================
@ -477,6 +498,17 @@ class KVConnectorBase_V1(ABC):
"""
return None
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (KVConnectorHandshakeMetadata): the handshake metadata to set.
"""
return None
@classmethod
def build_prom_metrics(
cls,

View File

@ -27,6 +27,7 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
KVConnectorBase_V1,
KVConnectorHandshakeMetadata,
KVConnectorMetadata,
KVConnectorRole,
)
@ -93,15 +94,12 @@ _NIXL_SUPPORTED_DEVICE = {
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
class NixlAgentMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True,
):
@dataclass
class NixlAgentMetadata(KVConnectorHandshakeMetadata):
engine_id: str
agent_metadata: bytes
kv_caches_base_addr: list[int]
device_id: int
num_blocks: int
block_lens: list[int]
attn_backend_name: str
@ -223,6 +221,18 @@ class NixlConnector(KVConnectorBase_V1):
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (dict): the handshake metadata to set.
"""
assert self.connector_scheduler is not None
self.connector_scheduler.set_xfer_handshake_metadata(metadata)
############################################################
# Worker Side Methods
############################################################
@ -299,6 +309,21 @@ class NixlConnector(KVConnectorBase_V1):
def shutdown(self):
if self.connector_worker is not None:
self.connector_worker.shutdown()
if self.connector_scheduler is not None:
self.connector_scheduler.shutdown()
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.
This metadata is used for out-of-band connector handshake
between P/D workers.
Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
None if no handshake metadata is available.
"""
assert self.connector_worker is not None
return self.connector_worker.xfer_handshake_metadata
class NixlConnectorScheduler:
@ -312,12 +337,16 @@ class NixlConnectorScheduler:
self.side_channel_port = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
+ vllm_config.parallel_config.data_parallel_rank
* vllm_config.parallel_config.tensor_parallel_size
)
assert vllm_config.kv_transfer_config is not None
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
logger.info("Initializing NIXL Scheduler %s", engine_id)
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: threading.Thread | None = None
self._encoded_xfer_handshake_metadata: dict[int, Any] = {}
self._stop_event = threading.Event()
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
@ -330,6 +359,89 @@ class NixlConnectorScheduler:
# remote prefill or aborted.
self._reqs_not_processed: set[ReqId] = set()
def shutdown(self):
self._stop_event.set()
if self._nixl_handshake_listener_t is not None:
self._nixl_handshake_listener_t.join()
self._nixl_handshake_listener_t = None
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (dict): the handshake metadata to set.
"""
encoded_data: dict[int, bytes] = {}
encoder = msgspec.msgpack.Encoder()
for tp_rank, rank_metadata in metadata.items():
if not isinstance(rank_metadata, NixlAgentMetadata):
raise ValueError(
"NixlConnectorScheduler expects NixlAgentMetadata for "
"handshake metadata."
)
encoded_data[tp_rank] = encoder.encode(rank_metadata)
logger.debug(
"Tp rank %d: encoded NixlAgentMetadata size: %s bytes",
tp_rank,
str(len(encoded_data[tp_rank])),
)
self._encoded_xfer_handshake_metadata = encoded_data
# Only start the listener when we have metadata to serve.
if self._nixl_handshake_listener_t is None:
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(
encoded_data,
ready_event,
self._stop_event,
self.side_channel_port,
),
daemon=True,
name="nixl_handshake_listener",
)
self._nixl_handshake_listener_t.start()
ready_event.wait() # Wait for listener ZMQ socket to be ready.
@staticmethod
def _nixl_handshake_listener(
encoded_data: dict[int, Any],
ready_event: threading.Event,
stop_event: threading.Event,
port: int,
):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach via HTTP endpoint soon.
# Listen for new requests for metadata.
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
path = make_zmq_path("tcp", host, port)
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
sock.setsockopt(zmq.RCVTIMEO, 1000)
ready_event.set()
while True:
try:
identity, _, msg = sock.recv_multipart()
except zmq.Again:
if stop_event.is_set():
break
continue
# Decode the message which contains (GET_META_MSG, rank)
msg, target_tp_rank = msgspec.msgpack.decode(msg)
logger.debug(
"Received message for tp rank %s",
target_tp_rank,
)
if msg != GET_META_MSG:
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
@ -537,8 +649,6 @@ class NixlConnectorScheduler:
class NixlConnectorWorker:
"""Implementation of Worker side methods"""
_POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms
@dataclass
class TpKVTopology:
"""
@ -651,16 +761,6 @@ class NixlConnectorWorker:
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
# NIXL handshake port.
# NOTE(rob): Within a DP group, each DP rank gets its own
# base port (which is sent in the KVTransferParams).
# Each TP rank listens/queries on the base_port + tp_rank.
self.side_channel_port: int = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
+ vllm_config.parallel_config.data_parallel_rank
* vllm_config.parallel_config.tensor_parallel_size
)
# Metadata.
self.engine_id: EngineId = engine_id
self.tp_rank = get_tensor_model_parallel_rank()
@ -706,6 +806,7 @@ class NixlConnectorWorker:
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
self.device_id: int = 0
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
@ -736,9 +837,8 @@ class NixlConnectorWorker:
# requests that skipped transfer (handshake or transfer failures)
self._failed_recv_reqs: set[ReqId] = set()
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: threading.Thread | None = None
self._nixl_handshake_listener_stop_event: threading.Event | None = None
# Handshake metadata of this worker for NIXL transfers.
self.xfer_handshake_metadata: NixlAgentMetadata | None = None
# Background thread for initializing new NIXL handshakes.
self._handshake_initiation_executor = ThreadPoolExecutor(
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
@ -790,42 +890,6 @@ class NixlConnectorWorker:
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
)
@staticmethod
def _nixl_handshake_listener(
metadata: NixlAgentMetadata,
ready_event: threading.Event,
stop_event: threading.Event,
base_port: int,
tp_rank: int,
):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach via HTTP endpoint soon.
encoder = msgspec.msgpack.Encoder()
encoded_data = encoder.encode(metadata)
size_in_bytes = len(encoded_data)
logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes))
# Listen for new requests for metadata.
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
path = make_zmq_path("tcp", host, base_port + tp_rank)
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)
while not stop_event.is_set():
events = dict(
poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000)
)
if sock not in events:
continue
identity, _, msg = sock.recv_multipart()
if msg != GET_META_MSG:
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data))
def _nixl_handshake(
self,
host: str,
@ -844,16 +908,17 @@ class NixlConnectorWorker:
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
path = make_zmq_path("tcp", host, port + p_remote_rank)
path = make_zmq_path("tcp", host, port)
logger.debug(
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank
"Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
)
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank))
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(GET_META_MSG)
sock.send(msg)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
@ -1042,6 +1107,10 @@ class NixlConnectorWorker:
assert tensor_size_bytes == curr_tensor_size_bytes, (
"All kv cache tensors must have the same size"
)
# Need to make sure the device ID is non-negative for NIXL,
# Torch uses -1 to indicate CPU tensors while NIXL uses explicit
# memory type.
self.device_id = max(cache.get_device(), 0)
caches_data.append(
(base_addr, curr_tensor_size_bytes, self.device_id, "")
)
@ -1139,10 +1208,11 @@ class NixlConnectorWorker:
assert len(self.block_window_per_layer) == self.num_layers
# After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata(
self.xfer_handshake_metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
device_id=self.device_id,
num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
@ -1150,22 +1220,6 @@ class NixlConnectorWorker:
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout,
)
ready_event, stop_event = threading.Event(), threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(
metadata,
ready_event,
stop_event,
self.side_channel_port,
self.tp_rank,
),
daemon=True,
name="nixl_handshake_listener",
)
self._nixl_handshake_listener_t.start()
self._nixl_handshake_listener_stop_event = stop_event
ready_event.wait() # Wait for listener ZMQ socket to be ready.
def add_remote_agent(
self,
@ -1267,7 +1321,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, remote_tp_rank))
blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
if self._use_flashinfer:
# With FlashInfer index V separately to allow head splitting.
@ -1275,7 +1329,9 @@ class NixlConnectorWorker:
block_offset = block_id * nixl_agent_meta.block_lens[i]
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
blocks_data.append(
(v_addr, kv_block_len, nixl_agent_meta.device_id)
)
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and local rank %s",
@ -1843,14 +1899,6 @@ class NixlConnectorWorker:
def shutdown(self):
"""Shutdown the connector worker."""
self._handshake_initiation_executor.shutdown(wait=False)
if self._nixl_handshake_listener_stop_event is not None:
self._nixl_handshake_listener_stop_event.set()
self._nixl_handshake_listener_stop_event = None
if self._nixl_handshake_listener_t is not None:
# Generous timeout to allow the thread to exit
self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10)
assert not self._nixl_handshake_listener_t.is_alive()
self._nixl_handshake_listener_t = None
for handles in self._recving_transfers.values():
for handle, _ in handles:
self.nixl_wrapper.release_xfer_handle(handle)

View File

@ -54,7 +54,13 @@ from vllm.config import (
VllmConfig,
get_attr_docs,
)
from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo
from vllm.config.cache import (
BlockSize,
CacheDType,
KVOffloadingBackend,
MambaDType,
PrefixCachingHashAlgo,
)
from vllm.config.device import Device
from vllm.config.model import (
ConvertOption,
@ -553,6 +559,11 @@ class EngineArgs:
kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
kv_offloading_size: float | None = CacheConfig.kv_offloading_size
kv_offloading_backend: KVOffloadingBackend | None = (
CacheConfig.kv_offloading_backend
)
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
@ -896,6 +907,12 @@ class EngineArgs:
cache_group.add_argument(
"--mamba-block-size", **cache_kwargs["mamba_block_size"]
)
cache_group.add_argument(
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
)
cache_group.add_argument(
"--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
)
# Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig)
@ -1246,8 +1263,6 @@ class EngineArgs:
self,
target_model_config: ModelConfig,
target_parallel_config: ParallelConfig,
enable_chunked_prefill: bool,
disable_log_stats: bool,
) -> SpeculativeConfig | None:
"""Initializes and returns a SpeculativeConfig object based on
`speculative_config`.
@ -1267,8 +1282,6 @@ class EngineArgs:
{
"target_model_config": target_model_config,
"target_parallel_config": target_parallel_config,
"enable_chunked_prefill": enable_chunked_prefill,
"disable_log_stats": disable_log_stats,
}
)
return SpeculativeConfig(**self.speculative_config)
@ -1391,6 +1404,8 @@ class EngineArgs:
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size,
kv_offloading_size=self.kv_offloading_size,
kv_offloading_backend=self.kv_offloading_backend,
)
ray_runtime_env = None
@ -1561,8 +1576,6 @@ class EngineArgs:
speculative_config = self.create_speculative_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
enable_chunked_prefill=self.enable_chunked_prefill,
disable_log_stats=self.disable_log_stats,
)
# make sure num_lookahead_slots is set appropriately depending on
@ -1813,7 +1826,7 @@ class EngineArgs:
incremental_prefill_supported = (
pooling_type is not None
and pooling_type.lower() == "last"
and is_causal
and bool(is_causal)
)
action = "Enabling" if incremental_prefill_supported else "Disabling"

View File

@ -241,6 +241,7 @@ async def build_async_engine_client_from_engine_args(
)
# Don't keep the dummy data in memory
assert async_llm is not None
await async_llm.reset_mm_cache()
yield async_llm

View File

@ -345,22 +345,7 @@ class OpenAIServing:
if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
else:
processed_inputs = processor.input_preprocessor._prompt_to_llm_inputs(
prompt
)
if processed_inputs["type"] == "embeds":
raise NotImplementedError
# This is a workaround to fix multimodal beam search; this is a
# bandaid fix for 2 small problems:
# 1. Multi_modal_data on the processed_inputs currently resolves to
# `None`.
# 2. preprocessing above expands the multimodal placeholders. However,
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
prompt_text: str | None
prompt_token_ids: list[int]
multi_modal_data: MultiModalDataDict | None
@ -373,9 +358,16 @@ class OpenAIServing:
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
mm_processor_kwargs: dict[str, Any] | None = processed_inputs.get(
"mm_processor_kwargs"
) # type: ignore
mm_processor_kwargs: dict[str, Any] | None = None
# This is a workaround to fix multimodal beam search; this is a
# bandaid fix for 2 small problems:
# 1. Multi_modal_data on the processed_inputs currently resolves to
# `None`.
# 2. preprocessing above expands the multimodal placeholders. However,
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
tokenized_length = len(prompt_token_ids)

View File

@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import re
import uuid
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,

View File

@ -15,9 +15,7 @@ from vllm.distributed.parallel_state import (
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
_get_config_dtype_str,
mxfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
modular_marlin_fused_moe,
@ -26,13 +24,16 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
try_get_optimal_moe_config,
)
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
class FusedMoEWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: FusedMoE) -> None:
super().__init__()
self.base_layer = base_layer
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = base_layer.w2_weight.device
@ -42,17 +43,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
moe_state_dict = {}
top_k = self.base_layer.top_k
if self.base_layer.quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
quant_config = self.base_layer.quant_config
else:
quant_config = mxfp4_w4a16_moe_quant_config(
w1_bias=self.base_layer.w13_bias,
w2_bias=self.base_layer.w2_bias,
w1_scale=self.base_layer.w13_weight_scale,
w2_scale=self.base_layer.w2_weight_scale,
)
self.base_layer.ensure_moe_quant_config_init()
quant_config = self.base_layer.quant_method.moe_quant_config
m_fused_moe_fn = (
modular_triton_fused_moe(
@ -69,7 +61,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
moe_state_dict["expert_map"] = kwargs["expert_map"]
moe_state_dict["apply_router_weight_on_input"] = kwargs[
"apply_router_weight_on_input"
@ -86,7 +77,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
curr_topk_ids = moe_state_dict["topk_ids"]
global_num_experts = moe_state_dict["global_num_experts"]
expert_map = moe_state_dict["expert_map"]
config_dtype = _get_config_dtype_str(
@ -118,7 +109,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
curr_topk_ids,
num_tokens,
config["BLOCK_SIZE_M"],
global_num_experts,
self.base_layer.local_num_experts,
max_loras,
expert_map,
)
@ -236,14 +227,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
) -> None:
"""Initializes lora matrices."""
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
self.w1_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.hidden_size,
),
@ -253,7 +240,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w1_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
@ -264,7 +251,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w2_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition,
),
@ -274,7 +261,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w2_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
self.base_layer.hidden_size,
lora_config.max_lora_rank,
),
@ -285,7 +272,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w3_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.hidden_size,
),
@ -295,7 +282,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w3_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
@ -308,7 +295,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked = []
self.lora_b_stacked = []
for lora_id in range(max_loras):
for experts_id in range(self.base_layer.global_num_experts):
for experts_id in range(self.base_layer.local_num_experts):
# gate_proj,down_proj,up_proj
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])

View File

@ -88,14 +88,17 @@ def _fused_moe_lora_kernel(
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n
pid_sk = pid % SPLIT_K
pid_m_n = pid // SPLIT_K
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
group_id = pid_m_n // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
@ -113,7 +116,7 @@ def _fused_moe_lora_kernel(
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_ind = stride_tl * lora_idx + offs_token_id
@ -131,7 +134,8 @@ def _fused_moe_lora_kernel(
cur_b_ptr
+ lora_idx * stride_bl
+ expert_id * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
)
# accumulator

View File

@ -56,6 +56,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
ep_size: int = 1,
tp_rank: int = 0,
tp_size: int = 1,
use_dp: bool = False,
):
super().__init__(quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
@ -67,6 +68,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.tp_rank = tp_rank
self.tp_size = tp_size
self.out_dtype = out_dtype
self.use_dp = use_dp
@property
def activation_formats(
@ -117,7 +119,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
workspace1 = (M, K)
workspace2 = (0,)
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
# For TP, the quantization is fused with fused_moe call.
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
# The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape)
@ -214,6 +217,7 @@ def flashinfer_cutlass_moe_fp4(
FlashInferExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
use_dp=False,
),
)

View File

@ -170,6 +170,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input
)
if not self.use_dp:
return a1, None, None, topk_ids, topk_weights
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
@ -179,14 +181,13 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
quant_config.block_shape,
is_fp4_scale_swizzled=not self.use_dp,
)
if self.use_dp:
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
[topk_weights, topk_ids, a1q, a1q_scale],
dim=0,
sizes=get_local_sizes(),
)
if quant_config.quant_dtype == "nvfp4":
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
[topk_weights, topk_ids, a1q, a1q_scale],
dim=0,
sizes=get_local_sizes(),
)
if quant_config.quant_dtype == "nvfp4":
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights

View File

@ -672,8 +672,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.fused_experts is not None:
if self.moe.has_bias:
raise ValueError("FusedMoEModularKernel does not support bias.")
result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,

View File

@ -40,18 +40,36 @@ logger = init_logger(__name__)
def kda_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
q_proj_states: torch.Tensor,
k_proj_states: torch.Tensor,
v_proj_states: torch.Tensor,
g1: torch.Tensor,
g2: torch.Tensor,
beta: torch.Tensor,
core_attn_out: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states, output=output)
self._forward(
q_proj_states=q_proj_states,
k_proj_states=k_proj_states,
v_proj_states=v_proj_states,
g1=g1,
g2=g2,
beta=beta,
core_attn_out=core_attn_out,
)
def kda_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
q_proj_states: torch.Tensor,
k_proj_states: torch.Tensor,
v_proj_states: torch.Tensor,
g1: torch.Tensor,
g2: torch.Tensor,
beta: torch.Tensor,
core_attn_out: torch.Tensor,
layer_name: str,
) -> None:
return
@ -60,7 +78,7 @@ def kda_attention_fake(
direct_register_custom_op(
op_name="kda_attention",
op_func=kda_attention,
mutates_args=["output"],
mutates_args=["core_attn_out"],
fake_impl=kda_attention_fake,
)
@ -242,36 +260,54 @@ class KimiDeltaAttention(nn.Module, MambaBase):
positions: torch.Tensor,
output: torch.Tensor,
) -> None:
return torch.ops.vllm.kda_attention(
hidden_states,
output,
num_tokens = hidden_states.size(0)
q = self.q_proj(hidden_states)[0]
k = self.k_proj(hidden_states)[0]
v = self.v_proj(hidden_states)[0]
beta = self.b_proj(hidden_states)[0].float().sigmoid()
g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
beta = beta.unsqueeze(0)
g1 = g1.unsqueeze(0)
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
core_attn_out = torch.zeros(
(1, num_tokens, self.local_num_heads, self.head_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops.vllm.kda_attention(
q,
k,
v,
g1,
g2,
beta,
core_attn_out,
self.prefix,
)
core_attn_out = self.o_norm(core_attn_out, g2)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
output[:] = self.o_proj(core_attn_out)[0]
def _forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
q_proj_states: torch.Tensor,
k_proj_states: torch.Tensor,
v_proj_states: torch.Tensor,
g1: torch.Tensor,
g2: torch.Tensor,
beta: torch.Tensor,
core_attn_out: torch.Tensor,
) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is None:
# V1 profile run
# Mimic the memory allocation in the real run
q = torch.empty_like(hidden_states)
k = torch.empty_like(hidden_states)
v = torch.empty_like(hidden_states)
g = hidden_states.new_empty(
hidden_states.size(0),
self.local_num_heads,
self.head_dim,
dtype=torch.float32,
)
beta = torch.empty(
hidden_states.size(0), self.local_num_heads, dtype=torch.float32
)
core_attn_out = torch.empty_like(hidden_states)
# # V1 profile run
return
assert isinstance(attn_metadata, dict)
@ -288,10 +324,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)
q_proj_states = self.q_proj(hidden_states)[0]
k_proj_states = self.k_proj(hidden_states)[0]
v_proj_states = self.v_proj(hidden_states)[0]
q_conv_weights = self.q_conv1d.weight.view(
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
)
@ -374,14 +406,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
)
beta = self.b_proj(hidden_states)[0].float().sigmoid()
g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
if attn_metadata.num_prefills > 0:
zero_idx = non_spec_state_indices_tensor[~has_initial_state]
recurrent_state[zero_idx] = 0
@ -393,7 +417,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
q=q,
k=k,
v=v,
g=g,
g=g1,
beta=beta,
initial_state=initial_state,
output_final_state=True,
@ -410,17 +434,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
q=q,
k=k,
v=v,
g=g,
g=g1,
beta=beta,
initial_state=recurrent_state,
use_qk_l2norm_in_kernel=True,
cu_seqlens=non_spec_query_start_loc,
ssm_state_indices=non_spec_state_indices_tensor,
)
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
core_attn_out = self.o_norm(core_attn_out_non_spec, g)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
output[:] = self.o_proj(core_attn_out)[0]
assert core_attn_out_non_spec.shape == core_attn_out.shape
core_attn_out[:] = core_attn_out_non_spec

View File

@ -1769,29 +1769,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
)
assert self.moe_quant_config is not None
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
# only (no EP).

View File

@ -79,6 +79,7 @@ def select_nvfp4_gemm_impl(
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
use_dp=moe.moe_parallel_config.dp_size > 1,
)
# native cutlass experts currently don't support DP; TP case won't call this

View File

@ -26,6 +26,7 @@
# limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
import itertools
import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
@ -36,7 +37,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from transformers import BatchFeature, PretrainedConfig
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
from transformers.models.glm4v.image_processing_glm4v import (
Glm4vImageProcessor,
@ -89,6 +90,7 @@ from ..layers.activation import SiluAndMul
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
SupportsPP,
)
@ -1386,7 +1388,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
dummy_inputs=Glm4vDummyInputsBuilder,
)
class Glm4vForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
merge_by_field_config = True
@ -1613,6 +1615,149 @@ class Glm4vForConditionalGeneration(
multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: "PretrainedConfig",
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for GLM4V."""
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = []
if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_frame_num,
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
for t_idx in range(llm_grid_t):
t_index = (
torch.tensor(t_idx)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta
def forward(
self,
input_ids: torch.Tensor,

View File

@ -17,7 +17,9 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
@ -56,12 +58,14 @@ from vllm.multimodal.processing import (
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
SupportsPP,
)
@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
elif current_platform.is_rocm():
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
@ -398,18 +405,28 @@ class KeyeSiglipAttention(nn.Module):
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa=False,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
}
def forward(
self,
hidden_states: torch.Tensor,
@ -457,15 +474,10 @@ class KeyeSiglipAttention(nn.Module):
self.head_dim,
)
if self.attn_backend == _Backend.FLASH_ATTN:
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
output = self.flash_attn_varlen_func(
q,
k,
v,
@ -1542,7 +1554,7 @@ class BaseKeyeModule(nn.Module):
dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
def _build_projector(
self,
@ -1611,3 +1623,142 @@ class KeyeForConditionalGeneration(
return tuple(
self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
)
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
seq_len: int | None = None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
video_grid_thw = video_grid_thw[0]
"""Get mrope input positions and delta value (Keye series)."""
def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
"""
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
if grid_thw.numel() == 0:
return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1]
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()
video_grid_thw = split_thw(video_grid_thw)
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
image_nums = len(image_grid_thw)
frame_nums = len(video_grid_thw)
llm_pos_ids_list: list = []
st = 0
remain_images, remain_frames = image_nums, frame_nums
image_index, video_index = 0, 0
for _ in range(image_nums + frame_nums):
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_frames > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_frames -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
t_index = (
(
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
)
.long()
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta

View File

@ -22,7 +22,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
@ -61,7 +60,7 @@ class KimiMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QKVParallelLinear | None = None,
quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
@ -155,6 +154,7 @@ class KimiMoE(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -340,7 +340,7 @@ class KimiDecoderLayer(nn.Module):
self.block_sparse_moe = KimiMoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
prefix=f"{prefix}.block_sparse_moe",
)
self.mlp = self.block_sparse_moe
else:

View File

@ -49,7 +49,7 @@ from functools import cached_property
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN, PytorchGELUTanh
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_flash_attn_2_available
@ -651,7 +651,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
"num_heads": config.num_attention_heads,
"hidden_dim": config.hidden_size,
"mlp_dim": config.intermediate_size,
"activation": PytorchGELUTanh(),
"activation": ACT2FN["gelu_pytorch_tanh"],
"attn_bias": True,
"attn_implementation": config._attn_implementation,
},

View File

@ -364,6 +364,8 @@ class Qwen2_5_VisionAttention(nn.Module):
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
self.use_upstream_fa = True
if current_platform.is_xpu():
self.use_upstream_fa = False
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
@ -856,10 +858,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
):
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

View File

@ -34,7 +34,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import AutoConfig, BatchFeature, PretrainedConfig
from transformers import BatchFeature, PretrainedConfig
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLConfig,
@ -789,10 +789,7 @@ class Qwen2VisionTransformer(nn.Module):
self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
):
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
@ -1654,9 +1651,7 @@ class Tarsier2Processor(Qwen2VLProcessor):
class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
def get_hf_config(self) -> Qwen2VLConfig:
model_path = self.ctx.model_config.model
original_config = AutoConfig.from_pretrained(model_path)
config_dict = original_config.to_dict()
correct_config = Qwen2VLConfig.from_dict(config_dict)
correct_config = Qwen2VLConfig.from_pretrained(model_path)
return correct_config

View File

@ -115,6 +115,12 @@ class XPUPlatform(Platform):
device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from vllm.attention.backends.registry import _Backend
return _Backend.FLASH_ATTN
@classmethod
def inference_mode(cls):
return torch.no_grad()

View File

@ -896,6 +896,8 @@ def get_kernel_options(
return kernel_options
else:
preferred_block = 32 if query.dtype == torch.float32 else 64
block_lower_bound = 16
block_m_candidate = ensure_divisible(preferred_block, block_m)
block_n_candidate = ensure_divisible(preferred_block, block_n)
@ -910,6 +912,9 @@ def get_kernel_options(
max(1, block_n_candidate // 2), block_n
)
block_m_candidate = max(block_m_candidate, block_lower_bound)
block_n_candidate = max(block_n_candidate, block_lower_bound)
kernel_options["BLOCK_M"] = block_m_candidate
kernel_options["BLOCK_N"] = block_n_candidate

View File

@ -6,7 +6,7 @@ from typing import ClassVar
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
@ -40,6 +40,10 @@ class FlashInferMLABackend(MLACommonBackend):
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
return FlashInferMLAMetadataBuilder
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
return [32, 64]
g_fi_workspace = torch.zeros(
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,

View File

@ -15,8 +15,12 @@ class AsyncScheduler(Scheduler):
scheduler_output: SchedulerOutput,
) -> None:
super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
if (
request.num_computed_tokens
== request.num_tokens + request.num_output_placeholders
@ -25,6 +29,10 @@ class AsyncScheduler(Scheduler):
# TODO(woosuk): Support speculative decoding.
request.num_output_placeholders += 1
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
)
def _update_request_with_output(
self,
request: Request,

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@ -40,6 +40,12 @@ class SchedulerInterface(ABC):
"""
raise NotImplementedError
@abstractmethod
def get_grammar_bitmask(
self, scheduler_output: "SchedulerOutput"
) -> "GrammarOutput | None":
raise NotImplementedError
@abstractmethod
def update_from_output(
self,

View File

@ -181,12 +181,17 @@ class SchedulerOutput:
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
# ids of structured outputs requests included in the bitmask, in the
# same order as the corresponding stacked rows of the bitmask.
# There may be more than one row per request in the case of speculative decoding.
structured_output_request_ids: list[str]
# the bitmask for the whole batch
grammar_bitmask: "npt.NDArray[np.int32] | None"
# Whether the scheduled requests have all the output tokens they
# need to perform grammar bitmask computation.
pending_structured_output_tokens: bool = False
# KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None
@dataclass
class GrammarOutput:
    """Grammar bitmask data for the structured-output requests of one step."""

    # ids of structured output requests.
    structured_output_request_ids: list[str]
    # Bitmask rows ordered as structured_output_request_ids. Kept as an
    # np.ndarray rather than a tensor since it serializes more efficiently.
    grammar_bitmask: "npt.NDArray[np.int32]"

View File

@ -5,7 +5,7 @@ import itertools
import time
from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
from typing import Any
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@ -24,7 +24,12 @@ from vllm.v1.core.encoder_cache_manager import (
)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.core.sched.output import (
CachedRequestData,
GrammarOutput,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
@ -35,10 +40,6 @@ from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
logger = init_logger(__name__)
@ -619,9 +620,6 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
)
# Record the request ids that were scheduled in this step.
self.prev_step_scheduled_req_ids.clear()
@ -641,8 +639,6 @@ class Scheduler(SchedulerInterface):
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
# NOTE(Kuntai): this function is designed for multiple purposes:
@ -872,9 +868,8 @@ class Scheduler(SchedulerInterface):
def get_grammar_bitmask(
self,
scheduled_request_ids: Iterable[str],
scheduled_spec_decode_tokens: dict[str, list[int]],
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
scheduler_output: SchedulerOutput,
) -> GrammarOutput | None:
# Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order.
# PERF: in case of chunked prefill,
@ -883,18 +878,18 @@ class Scheduler(SchedulerInterface):
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids = [
req_id
for req_id in scheduled_request_ids
for req_id in scheduler_output.num_scheduled_tokens
if (req := self.requests.get(req_id)) and req.use_structured_output
]
if not structured_output_request_ids:
return structured_output_request_ids, None
return None
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
scheduler_output.scheduled_spec_decode_tokens,
)
return structured_output_request_ids, bitmask
return GrammarOutput(structured_output_request_ids, bitmask)
def update_from_output(
self,

View File

@ -12,7 +12,7 @@ from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
import msgspec
import zmq
@ -163,6 +163,27 @@ class EngineCore:
vllm_config, mm_registry
)
# If a KV connector is initialized for scheduler, we want to collect
# handshake metadata from all workers so the connector in the scheduler
# will have the full context
kv_connector = self.scheduler.get_kv_connector()
if kv_connector is not None:
# Collect and store KV connector xfer metadata from workers
# (after KV cache registration)
xfer_handshake_metadata = (
self.model_executor.get_kv_connector_handshake_metadata()
)
if xfer_handshake_metadata:
# xfer_handshake_metadata is list of dicts from workers
# Each dict already has structure {tp_rank: metadata}
# Merge all worker dicts into a single dict
content: dict[int, Any] = {}
for worker_dict in xfer_handshake_metadata:
if worker_dict is not None:
content.update(worker_dict)
kv_connector.set_xfer_handshake_metadata(content)
# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
# schedule and execute batches, and is required by pipeline parallelism
@ -178,7 +199,7 @@ class EngineCore:
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
if (
self.vllm_config.cache_config.enable_prefix_caching
or self.scheduler.get_kv_connector() is not None
or kv_connector is not None
):
caching_hash_fn = get_hash_fn_by_name(
vllm_config.cache_config.prefix_caching_hash_algo
@ -313,9 +334,12 @@ class EngineCore:
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = self.model_executor.execute_model(scheduler_output)
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
@ -355,20 +379,47 @@ class EngineCore:
assert len(batch_queue) < self.batch_queue_size
model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
batch_queue.appendleft((future, scheduler_output))
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
model_executed = scheduler_output.total_num_scheduled_tokens > 0
if (
model_executed
and len(batch_queue) < self.batch_queue_size
and not batch_queue[-1][0].done()
):
# Don't block on next worker response unless the queue is full
# or there are no more requests to schedule.
return None, True
if scheduler_output.pending_structured_output_tokens:
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
else:
# We aren't waiting for any tokens, get any grammar output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
if exec_result is None:
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
else:
# No sampling required (e.g. all requests finished).
future = cast(Future[ModelRunnerOutput], exec_future)
# Add this step's future to the queue.
batch_queue.appendleft((future, scheduler_output))
if (
model_executed
and len(batch_queue) < self.batch_queue_size
and not batch_queue[-1][0].done()
):
# Don't block on next worker response unless the queue is full
# or there are no more requests to schedule.
return None, True
elif not batch_queue:
# Queue is empty. We should not reach here since this method should
@ -384,6 +435,19 @@ class EngineCore:
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
batch_queue.appendleft((future, deferred_scheduler_output))
return engine_core_outputs, model_executed
def shutdown(self):

View File

@ -9,11 +9,14 @@ from typing import TYPE_CHECKING, Literal, TypeVar, overload
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@ -177,30 +180,51 @@ class Executor(ABC):
):
raise NotImplementedError
def get_kv_connector_handshake_metadata(
    self,
) -> list[dict[int, KVConnectorHandshakeMetadata]]:
    """Collect KV-connector handshake metadata from all workers.

    Returns:
        One dict per worker, each keyed by tp_rank; the caller is expected
        to merge them into a single mapping.
    """
    return self.collective_rpc("get_kv_connector_handshake_metadata")
@overload
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: Literal[False] = False,
) -> ModelRunnerOutput:
self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False
) -> ModelRunnerOutput | None:
pass
@overload
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: Literal[True] = True,
) -> Future[ModelRunnerOutput]:
self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True
) -> Future[ModelRunnerOutput | None]:
pass
def execute_model(
self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
output = self.collective_rpc( # type: ignore[call-overload]
"execute_model", args=(scheduler_output,), non_block=non_block
)
return output[0]
@overload
def sample_tokens(
    self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False
) -> ModelRunnerOutput:
    pass

@overload
def sample_tokens(
    self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True
) -> Future[ModelRunnerOutput]:
    pass

def sample_tokens(
    self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
    """Sample tokens on all workers for the step begun by execute_model().

    Args:
        grammar_output: Structured-output grammar bitmask for this step,
            or None when no request uses structured output.
        non_block: If True, return a Future instead of blocking on the RPC.

    Returns:
        The first worker's ModelRunnerOutput (or a Future of it).
    """
    output = self.collective_rpc(  # type: ignore[call-overload]
        "sample_tokens", args=(grammar_output,), non_block=non_block
    )
    return output[0]
def execute_dummy_batch(self) -> None:
    """Trigger a dummy batch execution on every worker via collective RPC."""
    self.collective_rpc("execute_dummy_batch")

View File

@ -46,7 +46,7 @@ from vllm.utils.system_utils import (
get_mp_context,
set_process_title,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
@ -132,15 +132,12 @@ class MultiprocExecutor(Executor):
uw.death_writer.close()
self._ensure_worker_termination([uw.proc for uw in unready_workers])
# For pipeline parallel, we use a thread pool for asynchronous
# execute_model.
if self.max_concurrent_batches > 1:
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
# _async_aggregate_workers_output also assumes a single IO thread
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io"
)
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue.
# _async_aggregate_workers_output also assumes a single IO thread.
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io"
)
self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None
@ -180,15 +177,27 @@ class MultiprocExecutor(Executor):
self.failure_callback = callback
def execute_model( # type: ignore[override]
self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
return self._execute_with_aggregation(
"execute_model", scheduler_output, non_block=non_block
)
def sample_tokens(  # type: ignore[override]
    self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
    """Sample tokens on the workers, deferring to _execute_with_aggregation
    so that per-worker outputs are aggregated when a KV connector is active.

    Args:
        grammar_output: Structured-output grammar bitmask, if applicable.
        non_block: If True, return a Future rather than blocking.
    """
    return self._execute_with_aggregation(  # type: ignore[return-value]
        "sample_tokens", grammar_output, non_block=non_block
    )
def _execute_with_aggregation(
self, method: str, *args, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if not self.has_connector:
# get output only from a single worker (output_rank)
(output,) = self.collective_rpc(
"execute_model",
args=(scheduler_output,),
method,
args=args,
unique_reply_rank=self.output_rank,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
@ -197,8 +206,8 @@ class MultiprocExecutor(Executor):
# get output from all workers
outputs = self.collective_rpc(
"execute_model",
args=(scheduler_output,),
method,
args=args,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
)

View File

@ -19,7 +19,7 @@ from vllm.utils.network_utils import (
get_ip,
get_open_port,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (
@ -41,6 +41,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
COMPLETED_NONE_FUTURE.set_result(None)
@dataclass
class RayWorkerMetaData:
@ -96,6 +99,8 @@ class RayDistributedExecutor(Executor):
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
self.scheduler_output: SchedulerOutput | None = None
@property
def max_concurrent_batches(self) -> int:
"""Ray distributed executor supports pipeline parallelism,
@ -381,22 +386,46 @@ class RayDistributedExecutor(Executor):
self.shutdown()
def execute_model( # type: ignore[override]
self, scheduler_output: SchedulerOutput, non_block: bool = False
self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if self.scheduler_output is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
self.scheduler_output = scheduler_output
return COMPLETED_NONE_FUTURE if non_block else None
def sample_tokens( # type: ignore[override]
self,
grammar_output: "GrammarOutput | None",
non_block: bool = False,
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
"""Execute the model on the Ray workers.
The scheduler output to use should have been provided in
a prior call to execute_model().
Args:
scheduler_output: The scheduler output to execute.
grammar_output: The structured outputs grammar bitmask, if applicable.
non_block: If True, the method will return a Future.
Returns:
The model runner output.
"""
scheduler_output = self.scheduler_output
if scheduler_output is None:
return None # noqa
self.scheduler_output = None
# Build the compiled DAG for the first time.
if self.forward_dag is None: # type: ignore
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
refs = self.forward_dag.execute(scheduler_output) # type: ignore
refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
if not self.has_connector:
# Get output only from a single worker (output_rank)

View File

@ -19,7 +19,7 @@ from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)
@ -82,36 +82,41 @@ try:
def execute_model_ray(
self,
scheduler_output: Union[
"SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
],
execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
| tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
) -> Union[
"ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
"ModelRunnerOutput",
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
]:
# This method is used by Ray Compiled Graph to execute the model,
# and it needs a special logic of self.setup_device_if_necessary()
self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized"
if isinstance(scheduler_output, tuple):
scheduler_output, intermediate_tensors = scheduler_output
if len(execute_model_input) == 3:
scheduler_output, grammar_output, intermediate_tensors = (
execute_model_input
)
else:
scheduler_output, intermediate_tensors = scheduler_output, None
scheduler_output, grammar_output = execute_model_input
intermediate_tensors = None
assert self.worker.model_runner is not None
output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if isinstance(output, IntermediateTensors):
output = scheduler_output, output
output = scheduler_output, grammar_output, output
elif not get_pp_group().is_last_rank:
# Case where there are no scheduled requests
# but may still be finished requests.
assert not output or not output.req_ids
output = scheduler_output, None
# Ensure outputs crossing Ray compiled DAG are serializable.
# AsyncModelRunnerOutput holds CUDA events and cannot be
# pickled.
if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output()
output = scheduler_output, grammar_output, None
elif output is None:
output = self.worker.model_runner.sample_tokens(grammar_output)
# Ensure outputs crossing Ray compiled DAG are serializable.
# AsyncModelRunnerOutput holds CUDA events and cannot be
# pickled.
if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output()
return output
def override_env_vars(self, vars: dict[str, str]):

View File

@ -16,6 +16,7 @@ from diskcache import Cache
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.import_utils import LazyLoader
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
if TYPE_CHECKING:
import outlines_core as oc
@ -24,7 +25,6 @@ if TYPE_CHECKING:
import xgrammar as xgr
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
@ -47,6 +47,7 @@ CACHE = None
def apply_grammar_bitmask(
scheduler_output: SchedulerOutput,
grammar_output: GrammarOutput,
input_batch: InputBatch,
logits: torch.Tensor,
) -> None:
@ -58,9 +59,9 @@ def apply_grammar_bitmask(
input_batch (InputBatch): The input of model runner.
logits (torch.Tensor): The output logits of model forward.
"""
grammar_bitmask = scheduler_output.grammar_bitmask
if grammar_bitmask is None:
return
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = grammar_output.grammar_bitmask
# We receive the structured output bitmask from the scheduler,
# compacted to contain bitmasks only for structured output requests.
@ -79,7 +80,7 @@ def apply_grammar_bitmask(
cumulative_offset += len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
)
if req_id in scheduler_output.structured_output_request_ids:
if req_id in grammar_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index
out_indices = []
@ -91,7 +92,7 @@ def apply_grammar_bitmask(
dtype=grammar_bitmask.dtype,
)
cumulative_index = 0
for req_id in scheduler_output.structured_output_request_ids:
for req_id in grammar_output.structured_output_request_ids:
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
)
@ -101,22 +102,28 @@ def apply_grammar_bitmask(
sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask
# Copy async to device as tensor.
grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
logits.device, non_blocking=True
)
# If the length of out indices and the logits have the same shape
# we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits.
skip_out_indices = len(out_indices) == logits.shape[0]
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
index_tensor = None
if not skip_out_indices:
# xgrammar expects a python list of indices but it will actually work with
# a tensor. If we copy the tensor ourselves here we can do it in a non_blocking
# manner and there should be no cpu sync within xgrammar.
index_tensor = torch.tensor(
out_indices, dtype=torch.int32, device="cpu", pin_memory=True
)
index_tensor = index_tensor.to(logits.device, non_blocking=True)
xgr.apply_token_bitmask_inplace(
logits,
grammar_bitmask.to(logits.device, non_blocking=True),
indices=out_indices if not skip_out_indices else None,
)
xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
class OutlinesVocabulary:

View File

@ -204,7 +204,7 @@ class InputBatch:
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32)
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

View File

@ -109,6 +109,7 @@ from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
AsyncModelRunnerOutput,
DraftTokenIds,
KVConnectorOutput,
LogprobsLists,
LogprobsTensors,
ModelRunnerOutput,
@ -150,7 +151,7 @@ from .utils import (
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
logger = init_logger(__name__)
@ -218,6 +219,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
return output
class ExecuteModelState(NamedTuple):
    """Ephemeral cached state transferred between execute_model() and
    sample_tokens(), after execute_model() returns None."""

    # Scheduler output for the step whose sampling was deferred.
    scheduler_output: "SchedulerOutput"
    # Logits produced by the forward pass; the grammar bitmask (if any)
    # is applied to these before sampling.
    logits: torch.Tensor
    spec_decode_metadata: SpecDecodeMetadata | None
    spec_decode_common_attn_metadata: CommonAttentionMetadata | None
    hidden_states: torch.Tensor
    sample_hidden_states: torch.Tensor
    aux_hidden_states: list[torch.Tensor] | None
    kv_connector_output: KVConnectorOutput | None
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__(
self,
@ -509,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory,
)
# Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
@ -2113,7 +2131,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_input_tokens: int, # Padded
intermediate_tensors: IntermediateTensors | None = None,
) -> tuple[
int,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor,
@ -2207,7 +2224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_kwargs.update(encoder_inputs)
return (
num_scheduled_tokens,
input_ids,
inputs_embeds,
positions,
@ -2425,13 +2441,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
) -> ModelRunnerOutput | IntermediateTensors | None:
if self.execute_model_state is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with record_function_or_nullcontext("Preprocess"):
with self.synchronize_input_prep():
# Update persistent batch states.
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
@ -2471,7 +2493,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
(
num_scheduled_tokens,
input_ids,
inputs_embeds,
positions,
@ -2559,6 +2580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Rare case.
assert not self.is_pooling_model
sample_hidden_states = hidden_states[logits_indices]
if not get_pp_group().is_last_rank:
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
@ -2572,7 +2594,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
logits = None
else:
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
model_output_broadcast_data = {}
@ -2585,9 +2606,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if scheduler_output.structured_output_request_ids:
apply_grammar_bitmask(scheduler_output, self.input_batch, logits)
self.execute_model_state = ExecuteModelState(
scheduler_output,
logits,
spec_decode_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
)
return None
@torch.inference_mode
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used.
return None # noqa
# Unpack ephemeral state.
(
scheduler_output,
logits,
spec_decode_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
) = self.execute_model_state
# Clear ephemeral state.
self.execute_model_state = None
# Apply structured output bitmasks if present.
if grammar_output is not None:
apply_grammar_bitmask(
scheduler_output, grammar_output, self.input_batch, logits
)
with record_function_or_nullcontext("Sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
@ -2646,7 +2703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampler_output,
logits,
hidden_states,
num_scheduled_tokens,
scheduler_output.total_num_scheduled_tokens,
spec_decode_metadata,
)
@ -3978,6 +4035,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def create_attn_groups(
attn_backends_map: dict[AttentionGroupKey, list[str]],
kv_cache_group_id: int,
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
@ -3987,6 +4045,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_spec,
self.vllm_config,
self.device,
kv_cache_group_id,
num_metadata_builders=1
if not self.parallel_config.enable_dbo
else 2,
@ -4005,8 +4064,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Resolve cudagraph_mode before actually initialize metadata_builders
self._check_and_update_cudagraph_mode(attention_backend_set)
for attn_backends_map in attention_backend_maps:
self.attn_groups.append(create_attn_groups(attn_backends_map))
for i, attn_backend_map in enumerate(attention_backend_maps):
self.attn_groups.append(create_attn_groups(attn_backend_map, i))
# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()
@ -4149,89 +4208,88 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
group.get_metadata_builder().reorder_batch_threshold
for group in self._attn_group_iterator()
]
# If there are no attention groups (attention-free model) or no backend
# reports a threshold, leave reordering disabled.
if len(reorder_batch_thresholds) == 0:
self.reorder_batch_threshold = None
return
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
def _find_compatible_block_sizes(
self,
kv_manager_block_size: int,
backend_cls: type[AttentionBackend],
return_all: bool = False,
) -> list[int]:
"""
Find compatible block sizes for a backend.
Args:
kv_manager_block_size: Physical block size of KV cache
backend_cls: Attention backend class
return_all: Return all compatible sizes if True, max size if False
Returns:
Compatible block size(s) based on return_all parameter
Raises:
ValueError: If no compatible block size found
"""
supported_block_size = backend_cls.get_supported_kernel_block_size()
compatible_sizes = []
for block_size in supported_block_size:
if isinstance(block_size, int):
if kv_manager_block_size % block_size == 0:
compatible_sizes.append(block_size)
elif (
isinstance(block_size, MultipleOf)
and kv_manager_block_size % block_size.base == 0
):
compatible_sizes.append(kv_manager_block_size)
if not compatible_sizes:
raise ValueError(f"No compatible block size for {kv_manager_block_size}")
return compatible_sizes if return_all else [max(compatible_sizes)]
def _select_common_block_size(
self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
@staticmethod
def select_common_block_size(
kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int:
"""
Select common block size for all backends.
Select a block size that is supported by all backends and is a factor of
kv_manager_block_size.
If kv_manager_block_size is supported by all backends, return it directly.
Otherwise, return the max supported size.
Args:
kv_manager_block_size: Block size of KV cache
attn_groups: List of attention groups
Returns:
Block size supported by all backends,
prioritizing cache_config.block_size
The selected block size
Raises:
ValueError: If no common block size found
ValueError: If no valid block size found
"""
all_backend_supports = []
for attn_group in attn_groups:
compatible_sizes = self._find_compatible_block_sizes(
kv_manager_block_size, attn_group.backend, return_all=True
)
supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
all_backend_supports.append(set(supported_sizes))
def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int
) -> bool:
"""
Check if the block size is supported by all backends.
"""
for backend in backends:
is_supported = False
for supported_size in backend.get_supported_kernel_block_size():
if isinstance(supported_size, int):
if block_size == supported_size:
is_supported = True
elif isinstance(supported_size, MultipleOf):
if block_size % supported_size.base == 0:
is_supported = True
else:
raise ValueError(f"Unknown supported size: {supported_size}")
if not is_supported:
return False
return True
common_supported_sizes = set.intersection(*all_backend_supports)
backends = [group.backend for group in attn_groups]
if not common_supported_sizes:
error_msg = f"No common block size for {kv_manager_block_size}. "
for i, attn_group in enumerate(attn_groups):
supported = all_backend_supports[i]
error_msg += (
f"Backend {attn_group.backend} supports: {sorted(supported)}. "
)
raise ValueError(error_msg)
# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly
if block_size_is_supported(backends, kv_manager_block_size):
return kv_manager_block_size
if self.cache_config.block_size in common_supported_sizes:
return self.cache_config.block_size
# Case 2: otherwise, the block_size must be an `int`-format supported size of
# at least one backend. Iterate over all `int`-format supported sizes in
# descending order and return the first one that is supported by all backends.
# Simple proof:
# If the supported size b is in MultipleOf(x_i) format for all attention
# backends i, and b a factor of kv_manager_block_size, then
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
# return kv_manager_block_size in case 1.
all_int_supported_sizes = set(
supported_size
for backend in backends
for supported_size in backend.get_supported_kernel_block_size()
if isinstance(supported_size, int)
)
return max(common_supported_sizes)
for supported_size in sorted(all_int_supported_sizes, reverse=True):
if kv_manager_block_size % supported_size != 0:
continue
if block_size_is_supported(backends, supported_size):
return supported_size
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
def may_reinitialize_input_batch(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> None:
"""
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
@ -4239,6 +4297,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Args:
kv_cache_config: The KV cache configuration.
kernel_block_sizes: The kernel block sizes for each KV cache group.
"""
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
@ -4246,9 +4305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
]
# Generate kernel_block_sizes that matches each block_size
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
self.cache_config.block_size
]:
@ -4349,7 +4405,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# all backends in the group.
attn_groups = self.attn_groups[kv_cache_group_id]
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
selected_kernel_size = self._select_common_block_size(
selected_kernel_size = self.select_common_block_size(
kv_manager_block_size, attn_groups
)
kernel_block_sizes.append(selected_kernel_size)
@ -4367,6 +4423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
kernel_block_sizes: list[int],
) -> dict[str, torch.Tensor]:
"""
Reshape the KV cache tensors to the desired shape and dtype.
@ -4375,6 +4432,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape.
kernel_block_sizes: The kernel block sizes for each KV cache group.
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
@ -4384,6 +4442,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
if group.kv_cache_group_id == len(kernel_block_sizes):
# There may be a last group for layers without kv cache.
continue
kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
@ -4392,24 +4454,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_manager_block_size = kv_cache_spec.block_size
kernel_size_list = self._find_compatible_block_sizes(
kv_manager_block_size, attn_backend, return_all=False
num_blocks_per_kv_block = (
kv_cache_spec.block_size // kernel_block_size
)
kernel_size = kernel_size_list[0]
num_blocks_per_kv_block = kv_manager_block_size // kernel_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_size,
kernel_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype,
)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@ -4492,13 +4551,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> dict[str, torch.Tensor]:
"""
Initialize the memory buffer for KV cache.
Args:
kv_cache_config: The KV cache config
kernel_block_sizes: The kernel block sizes for each KV cache group.
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
@ -4507,7 +4568,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
# Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors(
kv_cache_config, kv_cache_raw_tensors
kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
)
# Set up cross-layer KV cache sharing
@ -4566,9 +4627,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
# The kernel block size for all KV cache groups. For example, if
# kv_cache_manager uses block_size 256 for a given group, but the attention
# backends for that group only supports block_size 64, we will return
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
# tokens each.
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
# Reinitialize need to after initialize_attn_backend
self.may_reinitialize_input_batch(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
kv_caches = self.initialize_kv_cache_tensors(
kv_cache_config, kernel_block_sizes
)
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)

View File

@ -6,6 +6,7 @@ import copy
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any
import torch
@ -19,7 +20,11 @@ from vllm.distributed import (
init_distributed_environment,
set_custom_all_reduce,
)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
get_kv_transfer_group,
has_kv_transfer_group,
)
from vllm.distributed.parallel_state import (
get_pp_group,
get_tp_group,
@ -33,6 +38,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
from vllm.v1.core.sched.output import GrammarOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
@ -348,6 +354,21 @@ class Worker(WorkerBase):
return int(self.available_kv_cache_memory_bytes)
def get_kv_connector_handshake_metadata(self) -> dict | None:
"""Get KV connector metadata from this worker if available."""
if not has_kv_transfer_group():
return None
connector = get_kv_transfer_group()
# Return None for connectors that don't need to exchange handshake
# metadata across workers.
if (metadata := connector.get_handshake_metadata()) is None:
return None
tp_rank = get_tp_group().rank_in_group
return {tp_rank: metadata}
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
@ -489,11 +510,16 @@ class Worker(WorkerBase):
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks()
@torch.inference_mode()
def sample_tokens(
self, grammar_output: "GrammarOutput"
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@ -512,13 +538,13 @@ class Worker(WorkerBase):
)
output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
if isinstance(output, (ModelRunnerOutput, NoneType)):
return output
assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config
assert (
parallel_config.distributed_executor_backend != ("external_launcher")
parallel_config.distributed_executor_backend != "external_launcher"
and not get_pp_group().is_last_rank
)

View File

@ -139,7 +139,7 @@ class InputBatch:
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32)
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

View File

@ -92,7 +92,7 @@ from .utils import (
)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
logger = init_logger(__name__)
@ -372,6 +372,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
self.sample_from_logits_func = self.sample_from_logits
# For passing scheduler_output between successive
# execute_model() and sample_tokens() calls.
self.scheduler_output: SchedulerOutput | None = None
self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
@ -1078,7 +1083,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput:
) -> ModelRunnerOutput | None:
if self.scheduler_output is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
# Update cached state
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
@ -1088,14 +1098,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
mm_embed_inputs = None
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
else:
mm_embed_inputs = None
torch_xla.sync(wait=False)
self.scheduler_output = scheduler_output
self.mm_embed_inputs = mm_embed_inputs
return None
@torch.no_grad()
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput:
if self.scheduler_output is None:
# Nothing to do (PP non-final rank case), output isn't used.
return None # noqa
scheduler_output = self.scheduler_output
mm_embed_inputs = self.mm_embed_inputs
self.scheduler_output = None
self.mm_embed_inputs = None
# Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution.
start_index = 0
@ -1131,9 +1157,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
self.input_batch, padded_num_reqs, self.device
)
if scheduler_output.grammar_bitmask is not None:
if grammar_output is not None:
require_struct_decoding, grammar_bitmask_padded, arange = (
self.prepare_structured_decoding_input(logits, scheduler_output)
self.prepare_structured_decoding_input(logits, grammar_output)
)
logits = self.structured_decode(
require_struct_decoding, grammar_bitmask_padded, logits, arange
@ -1954,10 +1980,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.model.get_input_embeddings(*args, **kwargs)
def prepare_structured_decoding_input(
self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
self, logits: torch.Tensor, grammar_output: "GrammarOutput"
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
grammar_bitmask = scheduler_output.grammar_bitmask
assert grammar_bitmask is not None
grammar_bitmask = grammar_output.grammar_bitmask
num_reqs, _ = logits.shape
# Reset pre-allocated tensors
@ -1965,7 +1990,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.require_structured_out_cpu.zero_()
cumulative_mask_idx = 0
for req_id in scheduler_output.structured_output_request_ids:
for req_id in grammar_output.structured_output_request_ids:
if req_id not in self.input_batch.req_id_to_index:
continue
batch_index = self.input_batch.req_id_to_index[req_id]

View File

@ -17,7 +17,6 @@ from vllm.distributed import (
)
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
has_kv_transfer_group,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -27,7 +26,7 @@ from vllm.platforms.tpu import USE_TPU_INFERENCE
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
@ -255,13 +254,13 @@ class TPUWorker:
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
return int(tpu_kv_cache_bytes)
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
def execute_model(
self,
scheduler_output: "SchedulerOutput",
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
output = self.model_runner.execute_model(scheduler_output)
# every worker's output is needed when kv_transfer_group is set up
return output if self.is_driver_worker or has_kv_transfer_group() else None
return self.model_runner.execute_model(scheduler_output)
def profile(self, is_start: bool = True):
if self.rank < 1:

View File

@ -140,6 +140,7 @@ class AttentionGroup:
metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str]
kv_cache_spec: KVCacheSpec
kv_cache_group_id: int
@staticmethod
def create_with_metadata_builders(
@ -148,13 +149,16 @@ class AttentionGroup:
kv_cache_spec: KVCacheSpec,
vllm_config: VllmConfig,
device: torch.device,
kv_cache_group_id: int,
num_metadata_builders: int = 1,
) -> "AttentionGroup":
metadata_builders = [
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
for _ in range(num_metadata_builders)
]
return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec)
return AttentionGroup(
backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
)
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
assert len(self.metadata_builders) > ubatch_id

View File

@ -20,10 +20,12 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.serial_utils import run_method
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput
else:
SchedulerOutput = object
GrammarOutput = object
AsyncModelRunnerOutput = object
ModelRunnerOutput = object
logger = init_logger(__name__)
@ -122,7 +124,21 @@ class WorkerBase:
"""Load model onto target device."""
raise NotImplementedError
def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
def execute_model(
self, scheduler_output: SchedulerOutput
) -> ModelRunnerOutput | None:
"""If this method returns None, sample_tokens should be called immediately after
to obtain the ModelRunnerOutput.
Note that this design may be changed in future if/when structured outputs
parallelism is re-architected.
"""
raise NotImplementedError
def sample_tokens(
self, grammar_output: GrammarOutput
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
"""Should be called immediately after execute_model iff it returned None."""
raise NotImplementedError
def get_cache_block_size_bytes(self) -> int:
@ -344,7 +360,7 @@ class WorkerWrapperBase:
scheduler_output: SchedulerOutput,
*args,
**kwargs,
) -> ModelRunnerOutput:
) -> ModelRunnerOutput | None:
self._apply_mm_cache(scheduler_output)
assert self.worker is not None