Merge remote-tracking branch 'origin/main' into refactor-fp8-linear

This commit is contained in:
vllmellm 2025-11-01 09:59:43 +00:00
commit 8e8218ebac
85 changed files with 2179 additions and 961 deletions

View File

@ -441,7 +441,7 @@ steps:
--ignore=lora/test_llm_with_multi_loras.py \ --ignore=lora/test_llm_with_multi_loras.py \
--ignore=lora/test_olmoe_tp.py \ --ignore=lora/test_olmoe_tp.py \
--ignore=lora/test_deepseekv2_tp.py \ --ignore=lora/test_deepseekv2_tp.py \
--ignore=lora/test_gptoss.py \ --ignore=lora/test_gptoss_tp.py \
--ignore=lora/test_qwen3moe_tp.py --ignore=lora/test_qwen3moe_tp.py
parallelism: 4 parallelism: 4
@ -1217,6 +1217,8 @@ steps:
- pytest -v -s -x lora/test_llama_tp.py - pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py - pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py - pytest -v -s -x lora/test_olmoe_tp.py
- pytest -v -s -x lora/test_gptoss_tp.py
- label: Weight Loading Multiple GPU Test # 33min - label: Weight Loading Multiple GPU Test # 33min
timeout_in_minutes: 45 timeout_in_minutes: 45

View File

@ -340,6 +340,16 @@ steps:
commands: commands:
- pytest -v -s v1/attention - pytest -v -s v1/attention
- label: V1 Test attention (B200) # 10min
timeout_in_minutes: 30
gpu: b200
source_file_dependencies:
- vllm/v1/attention
- tests/v1/attention
commands:
- export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
- pytest -v -s v1/attention
- label: V1 Test others (CPU) # 5 mins - label: V1 Test others (CPU) # 5 mins
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
@ -417,7 +427,7 @@ steps:
--ignore=lora/test_llm_with_multi_loras.py \ --ignore=lora/test_llm_with_multi_loras.py \
--ignore=lora/test_olmoe_tp.py \ --ignore=lora/test_olmoe_tp.py \
--ignore=lora/test_deepseekv2_tp.py \ --ignore=lora/test_deepseekv2_tp.py \
--ignore=lora/test_gptoss.py \ --ignore=lora/test_gptoss_tp.py \
--ignore=lora/test_qwen3moe_tp.py --ignore=lora/test_qwen3moe_tp.py
parallelism: 4 parallelism: 4
@ -1119,6 +1129,7 @@ steps:
- pytest -v -s -x lora/test_llama_tp.py - pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py - pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py - pytest -v -s -x lora/test_olmoe_tp.py
- pytest -v -s -x lora/test_gptoss_tp.py
- label: Weight Loading Multiple GPU Test # 33min - label: Weight Loading Multiple GPU Test # 33min

View File

@ -1429,8 +1429,6 @@ async def main() -> None:
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
if not os.path.exists(args.model):
raise OSError(f"Path does not exist: {args.model}")
logger.info("Loading tokenizer") logger.info("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(args.model) tokenizer = AutoTokenizer.from_pretrained(args.model)

View File

@ -4,7 +4,7 @@ This doc serves as a collection of handy tips for optimizing your vLLM on TPU wo
## Get started ## Get started
Looking for setup and installation instructions? Find them [here](../getting_started/installation/google_tpu.md). Looking for setup and installation instructions? Find them [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/).
### TPU workload sizing ### TPU workload sizing

View File

@ -0,0 +1,133 @@
# Batch Invariance
!!! note
Batch invariance is currently in beta. Some features are still under active development.
Track progress and planned improvements at <https://github.com/vllm-project/vllm/issues/27433>
This document shows how to enable batch invariance in vLLM. Batch invariance ensures that the output of a model is deterministic and independent of the batch size or the order of requests in a batch.
## Motivation
Batch invariance is crucial for several use cases:
- **Framework debugging**: Deterministic outputs make it easier to debug issues in the inference framework, as the same input will always produce the same output regardless of batching.
- **Model debugging**: Helps identify issues in model implementations by ensuring consistent behavior across different batch configurations.
- **Reinforcement Learning (RL)**: RL training often requires deterministic rollouts for reproducibility and stable training.
- **Large-scale inference systems**: Systems that use vLLM as a component benefit from deterministic behavior for testing, validation, and consistency guarantees.
## Hardware Requirements
Batch invariance currently requires NVIDIA GPUs with compute capability 9.0 or higher:
- **H-series**: H100, H200
- **B-series**: B100, B200
## Enabling Batch Invariance
Batch invariance can be enabled by setting the `VLLM_BATCH_INVARIANT` environment variable to `1`:
```bash
export VLLM_BATCH_INVARIANT=1
```
### Online Inference (Server Mode)
To start a vLLM server with batch invariance enabled:
```bash
VLLM_BATCH_INVARIANT=1 vllm serve meta-llama/Llama-3.1-8B-Instruct
```
Then use the OpenAI-compatible client:
```python
from openai import OpenAI
client = OpenAI(
api_key="EMPTY",
base_url="http://localhost:8000/v1",
)
# These requests will produce deterministic outputs
# regardless of batch size or order
response = client.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
prompt="The future of AI is",
max_tokens=100,
temperature=0.7,
seed=42,
)
print(response.choices[0].text)
```
### Offline Inference
For offline batch inference with batch invariance:
```python
import os
os.environ["VLLM_BATCH_INVARIANT"] = "1"
from vllm import LLM, SamplingParams
prompts = [
"The future of AI is",
"Machine learning enables",
"Deep learning models can",
]
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.95,
max_tokens=100,
seed=42,
)
llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
tensor_parallel_size=1,
)
# Outputs will be deterministic regardless of batch size
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Generated: {generated_text!r}\n")
```
## Tested Models
Batch invariance has been tested and verified on the following models:
- **DeepSeek series**: `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-V3-0324`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`
- **Qwen3 (Dense)**: `Qwen/Qwen3-1.7B`, `Qwen/Qwen3-8B`
- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct`
- **Llama 3**: `meta-llama/Llama-3.1-8B-Instruct`, `meta-llama/Llama-3.2-1B-Instruct`
Other models may also work, but these have been explicitly validated. If you encounter issues with a specific model, please report them on the [GitHub issue tracker](https://github.com/vllm-project/vllm/issues/new/choose).
## Implementation Details
When batch invariance is enabled, vLLM:
1. Uses deterministic kernel implementations for attention and other operations
2. Ensures consistent numerical behavior across different batch sizes
3. Disables certain optimizations that may introduce non-determinism (such as custom all-reduce operations in tensor parallel mode)
!!! note
Enabling batch invariance may impact performance compared to the default non-deterministic mode. This trade-off is intentional to guarantee reproducibility.
## Future Improvements
The batch invariance feature is under active development. Planned improvements include:
- Support for additional GPU architectures
- Expanded model coverage
- Performance optimizations
- Additional testing and validation
For the latest status and to contribute ideas, see the [tracking issue](https://github.com/vllm-project/vllm/issues/27433).

View File

@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
- Default: 5600 - Default: 5600
- **Required for both prefiller and decoder instances** - **Required for both prefiller and decoder instances**
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine - Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node). - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use port 5600, 5601 on that node).
- Used for the initial NIXL handshake between the prefiller and the decoder - Used for the initial NIXL handshake between the prefiller and the decoder
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication - `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication

View File

@ -2,4 +2,4 @@ nav:
- README.md - README.md
- gpu.md - gpu.md
- cpu.md - cpu.md
- google_tpu.md - TPU: https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/

View File

@ -11,7 +11,6 @@ vLLM supports the following hardware platforms:
- [ARM AArch64](cpu.md#arm-aarch64) - [ARM AArch64](cpu.md#arm-aarch64)
- [Apple silicon](cpu.md#apple-silicon) - [Apple silicon](cpu.md#apple-silicon)
- [IBM Z (S390X)](cpu.md#ibm-z-s390x) - [IBM Z (S390X)](cpu.md#ibm-z-s390x)
- [Google TPU](google_tpu.md)
## Hardware Plugins ## Hardware Plugins
@ -20,6 +19,7 @@ The backends below live **outside** the main `vllm` repository and follow the
| Accelerator | PyPI / package | Repository | | Accelerator | PyPI / package | Repository |
|-------------|----------------|------------| |-------------|----------------|------------|
| Google TPU | `tpu-inference` | <https://github.com/vllm-project/tpu-inference> |
| Ascend NPU | `vllm-ascend` | <https://github.com/vllm-project/vllm-ascend> | | Ascend NPU | `vllm-ascend` | <https://github.com/vllm-project/vllm-ascend> |
| Intel Gaudi (HPU) | N/A, install from source | <https://github.com/vllm-project/vllm-gaudi> | | Intel Gaudi (HPU) | N/A, install from source | <https://github.com/vllm-project/vllm-gaudi> |
| MetaX MACA GPU | N/A, install from source | <https://github.com/MetaX-MACA/vLLM-metax> | | MetaX MACA GPU | N/A, install from source | <https://github.com/MetaX-MACA/vLLM-metax> |

View File

@ -1,193 +0,0 @@
# Google TPU
Tensor Processing Units (TPUs) are Google's custom-developed application-specific
integrated circuits (ASICs) used to accelerate machine learning workloads. TPUs
are available in different versions each with different hardware specifications.
For more information about TPUs, see [TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm).
For more information on the TPU versions supported with vLLM, see:
- [TPU v6e](https://cloud.google.com/tpu/docs/v6e)
- [TPU v5e](https://cloud.google.com/tpu/docs/v5e)
- [TPU v5p](https://cloud.google.com/tpu/docs/v5p)
- [TPU v4](https://cloud.google.com/tpu/docs/v4)
These TPU versions allow you to configure the physical arrangements of the TPU
chips. This can improve throughput and networking performance. For more
information see:
- [TPU v6e topologies](https://cloud.google.com/tpu/docs/v6e#configurations)
- [TPU v5e topologies](https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config)
- [TPU v5p topologies](https://cloud.google.com/tpu/docs/v5p#tpu-v5p-config)
- [TPU v4 topologies](https://cloud.google.com/tpu/docs/v4#tpu-v4-config)
In order for you to use Cloud TPUs you need to have TPU quota granted to your
Google Cloud Platform project. TPU quotas specify how many TPUs you can use in a
GCP project and are specified in terms of TPU version, the number of TPUs you
want to use, and quota type. For more information, see [TPU quota](https://cloud.google.com/tpu/docs/quota#tpu_quota).
For TPU pricing information, see [Cloud TPU pricing](https://cloud.google.com/tpu/pricing).
You may need additional persistent storage for your TPU VMs. For more
information, see [Storage options for Cloud TPU data](https://cloud.google.com/tpu/docs/storage-options).
!!! warning
There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
## Requirements
- Google Cloud TPU VM
- TPU versions: v6e, v5e, v5p, v4
- Python: 3.11 or newer
### Provision Cloud TPUs
You can provision Cloud TPUs using the [Cloud TPU API](https://cloud.google.com/tpu/docs/reference/rest)
or the [queued resources](https://cloud.google.com/tpu/docs/queued-resources)
API (preferred). This section shows how to create TPUs using the queued resource API. For
more information about using the Cloud TPU API, see [Create a Cloud TPU using the Create Node API](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#create-node-api).
Queued resources enable you to request Cloud TPU resources in a queued manner.
When you request queued resources, the request is added to a queue maintained by
the Cloud TPU service. When the requested resource becomes available, it's
assigned to your Google Cloud project for your immediate exclusive use.
!!! note
In all of the following commands, replace the ALL CAPS parameter names with
appropriate values. See the parameter descriptions table for more information.
### Provision Cloud TPUs with GKE
For more information about using TPUs with GKE, see:
- [About TPUs in GKE](https://cloud.google.com/kubernetes-engine/docs/concepts/tpus)
- [Deploy TPU workloads in GKE Standard](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus)
- [Plan for TPUs in GKE](https://cloud.google.com/kubernetes-engine/docs/concepts/plan-tpus)
## Configure a new environment
### Provision a Cloud TPU with the queued resource API
Create a TPU v5e with 4 TPU chips:
```bash
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
--node-id TPU_NAME \
--project PROJECT_ID \
--zone ZONE \
--accelerator-type ACCELERATOR_TYPE \
--runtime-version RUNTIME_VERSION \
--service-account SERVICE_ACCOUNT
```
| Parameter name | Description |
|--------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| QUEUED_RESOURCE_ID | The user-assigned ID of the queued resource request. |
| TPU_NAME | The user-assigned name of the TPU which is created when the queued resource request is allocated. |
| PROJECT_ID | Your Google Cloud project |
| ZONE | The GCP zone where you want to create your Cloud TPU. The value you use depends on the version of TPUs you are using. For more information, see [TPU regions and zones] |
| ACCELERATOR_TYPE | The TPU version you want to use. Specify the TPU version, for example `v5litepod-4` specifies a v5e TPU with 4 cores, `v6e-1` specifies a v6e TPU with 1 core. For more information, see [TPU versions]. |
| RUNTIME_VERSION | The TPU VM runtime version to use. For example, use `v2-alpha-tpuv6e` for a VM loaded with one or more v6e TPU(s). |
| SERVICE_ACCOUNT | The email address for your service account. You can find it in the IAM Cloud Console under *Service Accounts*. For example: `tpu-service-account@<your_project_ID>.iam.gserviceaccount.com` |
Connect to your TPU VM using SSH:
```bash
gcloud compute tpus tpu-vm ssh TPU_NAME --project PROJECT_ID --zone ZONE
```
!!! note
When configuring `RUNTIME_VERSION` ("TPU software version") on GCP, ensure it matches the TPU generation you've selected by referencing the [TPU VM images] compatibility matrix. Using an incompatible version may prevent vLLM from running correctly.
[TPU versions]: https://cloud.google.com/tpu/docs/runtimes
[TPU VM images]: https://cloud.google.com/tpu/docs/runtimes
[TPU regions and zones]: https://cloud.google.com/tpu/docs/regions-zones
## Set up using Python
### Pre-built wheels
Currently, there are no pre-built TPU wheels.
### Build wheel from source
Install Miniconda:
```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
source ~/.bashrc
```
Create and activate a Conda environment for vLLM:
```bash
conda create -n vllm python=3.12 -y
conda activate vllm
```
Clone the vLLM repository and go to the vLLM directory:
```bash
git clone https://github.com/vllm-project/vllm.git && cd vllm
```
Uninstall the existing `torch` and `torch_xla` packages:
```bash
pip uninstall torch torch-xla -y
```
Install build dependencies:
```bash
pip install -r requirements/tpu.txt
sudo apt-get install --no-install-recommends --yes libopenblas-base libopenmpi-dev libomp-dev
```
Run the setup script:
```bash
VLLM_TARGET_DEVICE="tpu" python -m pip install -e .
```
## Set up using Docker
### Pre-built images
See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image, making sure to substitute the image name `vllm/vllm-openai` with `vllm/vllm-tpu`.
### Build image from source
You can use [docker/Dockerfile.tpu](../../../docker/Dockerfile.tpu) to build a Docker image with TPU support.
```bash
docker build -f docker/Dockerfile.tpu -t vllm-tpu .
```
Run the Docker image with the following command:
```bash
# Make sure to add `--privileged --net host --shm-size=16G`.
docker run --privileged --net host --shm-size=16G -it vllm-tpu
```
!!! note
Since TPU relies on XLA which requires static shapes, vLLM bucketizes the
possible input shapes and compiles an XLA graph for each shape. The
compilation time may take 20~30 minutes in the first run. However, the
compilation time reduces to ~5 minutes afterwards because the XLA graphs are
cached in the disk (in `VLLM_XLA_CACHE_PATH` or `~/.cache/vllm/xla_cache` by default).
!!! tip
If you encounter the following error:
```console
from torch._C import * # noqa: F403
ImportError: libopenblas.so.0: cannot open shared object file: No such
file or directory
```
Install OpenBLAS with the following command:
```bash
sudo apt-get install --no-install-recommends --yes libopenblas-base libopenmpi-dev libomp-dev
```

View File

@ -63,6 +63,17 @@ This guide will help you quickly get started with vLLM to perform:
rocm/vllm-dev:nightly rocm/vllm-dev:nightly
``` ```
=== "Google TPU"
To run vLLM on Google TPUs, you need to install the `vllm-tpu` package.
```bash
uv pip install vllm-tpu
```
!!! note
For more detailed instructions, including Docker, installing from source, and troubleshooting, please refer to the [vLLM on TPU documentation](https://docs.vllm.ai/projects/tpu/en/latest/).
!!! note !!! note
For more detail and non-CUDA platforms, please refer [here](installation/README.md) for specific instructions on how to install vLLM. For more detail and non-CUDA platforms, please refer [here](installation/README.md) for specific instructions on how to install vLLM.

View File

@ -7,7 +7,7 @@ requests >= 2.26.0
tqdm tqdm
blake3 blake3
py-cpuinfo py-cpuinfo
transformers >= 4.56.0 transformers >= 4.56.0, < 5
tokenizers >= 0.21.1 # Required for fast incremental detokenization. tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf # Required by LlamaTokenizer. protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.

View File

@ -29,7 +29,7 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test datamodel_code_generator # required for minicpm3 test
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
mteb>=1.38.11, <2 # required for mteb test mteb>=1.38.11, <2 # required for mteb test
transformers==4.56.2 transformers==4.57.1
tokenizers==0.22.0 tokenizers==0.22.0
schemathesis>=3.39.15 # Required for openai schema test. schemathesis>=3.39.15 # Required for openai schema test.
# quantization # quantization

View File

@ -37,7 +37,7 @@ datamodel_code_generator # required for minicpm3 test
# TODO: Use lm-eval[api]==0.4.10 once released # TODO: Use lm-eval[api]==0.4.10 once released
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
mteb[bm25s]>=1.38.11, <2 # required for mteb test mteb[bm25s]>=1.38.11, <2 # required for mteb test
transformers==4.56.2 transformers==4.57.1
tokenizers==0.22.0 tokenizers==0.22.0
schemathesis>=3.39.15 # Required for openai schema test. schemathesis>=3.39.15 # Required for openai schema test.
# quantization # quantization

View File

@ -1196,7 +1196,7 @@ tqdm==4.66.6
# transformers # transformers
tqdm-multiprocess==0.0.11 tqdm-multiprocess==0.0.11
# via lm-eval # via lm-eval
transformers==4.56.2 transformers==4.57.1
# via # via
# -r requirements/test.in # -r requirements/test.in
# genai-perf # genai-perf

View File

@ -6,6 +6,9 @@ from copy import deepcopy
from tblib import pickling_support from tblib import pickling_support
# Import fixture
from tests.v1.entrypoints.conftest import sample_json_schema # noqa
# ruff: noqa # ruff: noqa
# Install support for pickling exceptions so that we can nicely propagate # Install support for pickling exceptions so that we can nicely propagate

View File

@ -237,7 +237,7 @@ def deepseekv2_lora_files():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def gptoss20b_lora_files(): def gptoss20b_lora_files():
return snapshot_download(repo_id="LevinZheng/gpt-oss-20b-lora-adapter") return snapshot_download(repo_id="jeeejeee/gpt-oss-20b-lora-adapter-text2sql")
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# NOTE To avoid overloading the CI pipeline, this test script will
# not be triggered on CI and is primarily intended for local testing
# and verification.
import vllm import vllm
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest

View File

@ -1,52 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "openai/gpt-oss-20b"
PROMPT_TEMPLATE = "<begin▁of▁sentence>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
prompts = [
PROMPT_TEMPLATE.format(context="Who are you?"),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
# FIXME: Load gpt-oss adapter
def test_gptoss20b_lora(gptoss20b_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_loras=4,
trust_remote_code=True,
)
expected_lora_output = [
"I am an AI language model developed by OpenAI. "
"I am here to help you with any questions or "
"tasks you may have."
]
output1 = do_sample(llm, gptoss20b_lora_files, lora_id=1)
print(output1)
for i in range(len(expected_lora_output)):
assert output1[i].startswith(expected_lora_output[i])

View File

@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "openai/gpt-oss-20b"
PROMPT_TEMPLATE = """<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-10-29
Reasoning: medium
# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>user<|message|>I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
"
##Instruction:
farm contains tables such as city, farm, farm_competition, competition_record. Table city has columns such as City_ID, Official_Name, Status, Area_km_2, Population, Census_Ranking. City_ID is the primary key.
Table farm has columns such as Farm_ID, Year, Total_Horses, Working_Horses, Total_Cattle, Oxen, Bulls, Cows, Pigs, Sheep_and_Goats. Farm_ID is the primary key.
Table farm_competition has columns such as Competition_ID, Year, Theme, Host_city_ID, Hosts. Competition_ID is the primary key.
Table competition_record has columns such as Competition_ID, Farm_ID, Rank. Competition_ID is the primary key.
The Host_city_ID of farm_competition is the foreign key of City_ID of city.
The Farm_ID of competition_record is the foreign key of Farm_ID of farm.
The Competition_ID of competition_record is the foreign key of Competition_ID of farm_competition.
###Input:
{context}
###Response:<|end|><|start|>assistant<|channel|>final<|message|>""" # noqa: E501
EXPECTED_LORA_OUTPUT = [
"SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;",
"SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;",
"SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
"SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
]
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
prompts = [
PROMPT_TEMPLATE.format(
context="What is the average number of working horses of farms with more than 5000 total number of horses?" # noqa: E501
), # noqa: E501
PROMPT_TEMPLATE.format(
context="Give the average number of working horses on farms with more than 5000 total horses." # noqa: E501
), # noqa: E501
PROMPT_TEMPLATE.format(
context="What are the maximum and minimum number of cows across all farms."
),
PROMPT_TEMPLATE.format(
context="Return the maximum and minimum number of cows across all farms."
),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
def test_gpt_oss_lora(gptoss20b_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=8,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=2)
def test_gpt_oss_lora_tp2(gptoss20b_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=2,
max_lora_rank=8,
tensor_parallel_size=2,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)

View File

@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# NOTE To avoid overloading the CI pipeline, this test script will not
# be triggered on CI and is primarily intended for local testing and verification.
import vllm import vllm
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest

View File

@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from typing import NamedTuple
import pytest
from PIL.Image import Image
from transformers import AutoProcessor
from vllm import LLM, EngineArgs, SamplingParams
from vllm.multimodal.utils import encode_image_base64
MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"
QUESTION = "What is the content of each image?"
class ModelRequestData(NamedTuple):
engine_args: EngineArgs
prompt: str
image_data: list[Image]
stop_token_ids: list[int] | None = None
chat_template: str | None = None
sampling_params: SamplingParams | None = None
@pytest.mark.core_model
@pytest.mark.parametrize("question", [QUESTION])
def test_keye_vl(
image_assets,
question: str,
):
images = [asset.pil_image for asset in image_assets]
image_urls = [
f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images
]
engine_args = EngineArgs(
model=MODEL_NAME,
trust_remote_code=True,
max_model_len=8192,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
engine_args = asdict(engine_args) | {"seed": 42}
llm = LLM(**engine_args)
sampling_params = SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=None
)
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": {"image": images},
},
sampling_params=sampling_params,
)
print("-" * 50)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
assert len(generated_text) > 10, (
f"Generated text is too short: {generated_text}"
)
print("-" * 50)

View File

@ -186,6 +186,8 @@ def create_reduced_config(
if "text_config" in config_dict: if "text_config" in config_dict:
original_text_layers = config_dict["text_config"]["num_hidden_layers"] original_text_layers = config_dict["text_config"]["num_hidden_layers"]
config_dict["text_config"]["num_hidden_layers"] = text_layers config_dict["text_config"]["num_hidden_layers"] = text_layers
original_layer_types = config_dict["text_config"]["layer_types"]
config_dict["text_config"]["layer_types"] = original_layer_types[:text_layers]
print(f"Reduced text layers from {original_text_layers} to {text_layers}") print(f"Reduced text layers from {original_text_layers} to {text_layers}")
original_num_experts = config_dict["text_config"]["num_local_experts"] original_num_experts = config_dict["text_config"]["num_local_experts"]

View File

@ -882,27 +882,27 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
_TRANSFORMERS_BACKEND_MODELS = { _TRANSFORMERS_BACKEND_MODELS = {
"TransformersEmbeddingModel": _HfExamplesInfo( "TransformersEmbeddingModel": _HfExamplesInfo(
"BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0" "BAAI/bge-base-en-v1.5", min_transformers_version="5.0.0"
), ),
"TransformersForSequenceClassification": _HfExamplesInfo( "TransformersForSequenceClassification": _HfExamplesInfo(
"papluca/xlm-roberta-base-language-detection", "papluca/xlm-roberta-base-language-detection",
min_transformers_version="4.57.0.dev0", min_transformers_version="5.0.0",
), ),
"TransformersForCausalLM": _HfExamplesInfo( "TransformersForCausalLM": _HfExamplesInfo(
"hmellor/Ilama-3.2-1B", trust_remote_code=True "hmellor/Ilama-3.2-1B", trust_remote_code=True
), ),
"TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMoEForCausalLM": _HfExamplesInfo( "TransformersMoEForCausalLM": _HfExamplesInfo(
"allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0" "allenai/OLMoE-1B-7B-0924", min_transformers_version="5.0.0"
), ),
"TransformersMultiModalMoEForCausalLM": _HfExamplesInfo( "TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0" "Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="5.0.0"
), ),
"TransformersMoEEmbeddingModel": _HfExamplesInfo( "TransformersMoEEmbeddingModel": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0" "Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
), ),
"TransformersMoEForSequenceClassification": _HfExamplesInfo( "TransformersMoEForSequenceClassification": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0" "Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
), ),
"TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"), "TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"),
"TransformersMultiModalForSequenceClassification": _HfExamplesInfo( "TransformersMultiModalForSequenceClassification": _HfExamplesInfo(

View File

@ -82,7 +82,7 @@ def test_models(
from packaging.version import Version from packaging.version import Version
installed = Version(transformers.__version__) installed = Version(transformers.__version__)
required = Version("4.57.0.dev0") required = Version("5.0.0")
if model == "allenai/OLMoE-1B-7B-0924" and installed < required: if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
pytest.skip( pytest.skip(
"MoE models with the Transformers backend require " "MoE models with the Transformers backend require "

View File

@ -14,16 +14,19 @@ import torch
from tests.v1.attention.utils import ( from tests.v1.attention.utils import (
BatchSpec, BatchSpec,
create_common_attn_metadata, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config, create_vllm_config,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import _Backend, backend_to_class_str
from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.config.vllm import set_current_vllm_config from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
@ -31,17 +34,46 @@ BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA, _Backend.CUTLASS_MLA,
_Backend.FLASHMLA, _Backend.FLASHMLA,
_Backend.FLASH_ATTN_MLA, _Backend.FLASH_ATTN_MLA,
_Backend.FLASHINFER_MLA,
_Backend.TRITON_MLA, _Backend.TRITON_MLA,
] ]
# Remove CUTLASS_MLA from the list if not using sm100 # Remove sm100 backends from the list if not using sm100
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)
# Remove FLASH_ATTN_MLA from the list if not supported
if not flash_attn_supports_mla():
BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)
# Remove FLASHMLA from the list if not supported # Remove FLASHMLA from the list if not supported
if not is_flashmla_dense_supported()[0]: if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA) BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
SPEC_DECODE_BACKENDS = []
for backend in BACKENDS_TO_TEST:
builder_cls, _ = try_get_attention_backend(backend)
query_len_support = getattr(
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
)
if query_len_support != QueryLenSupport.SINGLE_ONLY:
SPEC_DECODE_BACKENDS.append(backend)
BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST:
backend_class_str = backend_to_class_str(backend)
backend_class = resolve_obj_by_qualname(backend_class_str)
supported_sizes = backend_class.get_supported_kernel_block_size()
if supported_sizes:
default_size = supported_sizes[0]
block_size = (
default_size if isinstance(default_size, int) else default_size.base
)
else:
block_size = 16
BACKEND_BLOCK_SIZES[backend] = block_size
torch.manual_seed(42) torch.manual_seed(42)
@ -236,6 +268,26 @@ class MockAttentionLayer:
self._q_scale = torch.tensor(1.0, device=device) self._q_scale = torch.tensor(1.0, device=device)
self._k_scale = torch.tensor(1.0, device=device) self._k_scale = torch.tensor(1.0, device=device)
self._v_scale = torch.tensor(1.0, device=device) self._v_scale = torch.tensor(1.0, device=device)
self._prob_scale = torch.tensor(1.0, device=device)
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
def forward(self, *_args, **_kwargs):
raise NotImplementedError
class MockMLAAttentionLayer(AttentionLayerBase):
"""A mock MLA attention layer for populating static_forward_context."""
def __init__(self, impl):
self.impl = impl
def get_attn_backend(self):
raise NotImplementedError
def get_kv_cache_spec(self, vllm_config):
raise NotImplementedError
def run_attention_backend( def run_attention_backend(
@ -262,13 +314,6 @@ def run_attention_backend(
# Set the current vllm config so that get_current_vllm_config() works # Set the current vllm config so that get_current_vllm_config() works
# in the backend implementations # in the backend implementations
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
# Instantiate MLA implementation # Instantiate MLA implementation
num_heads = vllm_config.model_config.get_num_attention_heads( num_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config vllm_config.parallel_config
@ -302,6 +347,19 @@ def run_attention_backend(
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
impl.process_weights_after_loading(act_dtype) impl.process_weights_after_loading(act_dtype)
# Populate static_forward_context with mock attention layers
for layer_name in layer_names:
vllm_config.compilation_config.static_forward_context[layer_name] = (
MockMLAAttentionLayer(impl)
)
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
# Create mock layer and output buffer # Create mock layer and output buffer
mock_layer = MockAttentionLayer(device) mock_layer = MockAttentionLayer(device)
num_tokens = query.shape[0] num_tokens = query.shape[0]
@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
simulated paged KV cache. simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output. 5. Comparing the vLLM backend's output to the ground-truth SDPA output.
""" """
from vllm.v1.attention.backends.mla.common import QueryLenSupport
batch_spec = BATCH_SPECS[batch_spec_name] batch_spec = BATCH_SPECS[batch_spec_name]
is_spec_decode_test = batch_spec_name.startswith("spec_decode") is_spec_decode_test = batch_spec_name.startswith("spec_decode")
spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA} unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
default_block_size = unique_block_sizes[0]
block_size = 16
required_blocks = sum( required_blocks = sum(
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens (seq_len + default_block_size - 1) // default_block_size
for seq_len in batch_spec.seq_lens
) )
# Add 1 for null block at index 0, and some buffer # Add 1 for null block at index 0, and some buffer
num_gpu_blocks = required_blocks + 1 + 100 num_gpu_blocks = required_blocks + 1 + 100
@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
model_name=model, model_name=model,
max_model_len=max(batch_spec.seq_lens), max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=num_gpu_blocks, num_gpu_blocks=num_gpu_blocks,
block_size=block_size, block_size=default_block_size,
) )
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device = torch.device("cuda:0") device = torch.device("cuda:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
# 1. Setup # 1. Setup
batch_size = batch_spec.batch_size batch_size = batch_spec.batch_size
seq_lens = batch_spec.seq_lens seq_lens = batch_spec.seq_lens
@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
) )
head_size = vllm_config.model_config.get_head_size() head_size = vllm_config.model_config.get_head_size()
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
block_size = vllm_config.cache_config.block_size
kv_lora_rank = 512 kv_lora_rank = 512
qk_rope_head_dim = 64 qk_rope_head_dim = 64
qk_nope_head_dim = 128 qk_nope_head_dim = 128
@ -598,33 +652,83 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
) )
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False) mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)
# Create metadata using original batch spec # 3. Create metadata and KV caches for each block size
common_attn_metadata = create_common_attn_metadata( # Group backends by block size and test each group
batch_spec, vllm_config.cache_config.block_size, device metadata_per_block_size = {}
) kv_cache_per_block_size = {}
# 3. Simulate Paged KV Cache and a realistic slot_mapping for block_size in unique_block_sizes:
kv_cache = create_and_prepopulate_kv_cache( # Create metadata for this block size
kv_c_contexts=kv_c_contexts, common_attn_metadata = create_common_attn_metadata(
k_pe_contexts=k_pe_contexts, batch_spec, block_size, device
block_size=block_size, )
head_size=head_size,
dtype=dtype, # Pad block table to meet requirement:
device=device, # block_num % (128 / block_size) == 0
num_blocks=vllm_config.cache_config.num_gpu_blocks, required_divisor = int(128 / block_size)
common_attn_metadata=common_attn_metadata, current_block_num = common_attn_metadata.block_table_tensor.shape[1]
randomize_blocks=True, if current_block_num % required_divisor != 0:
) # Pad to next multiple of required_divisor
padded_block_num = (
(current_block_num + required_divisor - 1) // required_divisor
) * required_divisor
padding_cols = padded_block_num - current_block_num
padding = torch.zeros(
(common_attn_metadata.block_table_tensor.shape[0], padding_cols),
dtype=torch.int32,
device=device,
)
common_attn_metadata.block_table_tensor = torch.cat(
[common_attn_metadata.block_table_tensor, padding], dim=1
)
metadata_per_block_size[block_size] = common_attn_metadata
# Create KV cache for this block size
required_blocks_for_size = sum(
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
)
num_blocks_for_size = required_blocks_for_size + 1 + 100
kv_cache = create_and_prepopulate_kv_cache(
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=block_size,
head_size=head_size,
dtype=dtype,
device=device,
num_blocks=num_blocks_for_size,
common_attn_metadata=common_attn_metadata,
randomize_blocks=True,
)
kv_cache_per_block_size[block_size] = kv_cache
# 4. Run vLLM backends and compare # 4. Run vLLM backends and compare
failures = []
for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST): for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
# Skip backends that don't support spec decode for spec decode tests # Skip backends that don't support spec decode for spec decode tests
if is_spec_decode_test and backend_name not in spec_decode_backends: if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
continue continue
# Get the appropriate block_size, metadata, and cache for this backend
block_size = BACKEND_BLOCK_SIZES[backend_name]
common_attn_metadata = metadata_per_block_size[block_size]
kv_cache = kv_cache_per_block_size[block_size]
# Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec = FullAttentionSpec(
block_size=block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
),
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(),
)
backend_output = run_attention_backend( backend_output = run_attention_backend(
backend_name, backend_name,
kv_cache_spec, backend_kv_cache_spec,
["placeholder"], ["placeholder"],
vllm_config, vllm_config,
device, device,
@ -644,32 +748,48 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
expected_output = sdpa_outputs[backend_name] expected_output = sdpa_outputs[backend_name]
# Check shape and dtype consistency # Check shape and dtype consistency
assert backend_output.shape == expected_output.shape, ( try:
f"[{backend_name}] shape {backend_output.shape} != " assert backend_output.shape == expected_output.shape, (
f"SDPA shape {expected_output.shape}" f"[{backend_name}] shape {backend_output.shape} != "
) f"SDPA shape {expected_output.shape}"
assert backend_output.dtype == expected_output.dtype, ( )
f"[{backend_name}] dtype {backend_output.dtype} != " assert backend_output.dtype == expected_output.dtype, (
f"SDPA dtype {expected_output.dtype}" f"[{backend_name}] dtype {backend_output.dtype} != "
) f"SDPA dtype {expected_output.dtype}"
)
assert torch.isfinite(backend_output).all(), ( assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values" f"[{backend_name}] produced non-finite values"
) )
# Check numerical similarity # Check numerical similarity
rtol = 1e-2 rtol = 1e-2
atol = 5e-1 atol = 5e-1
max_diff = torch.max(torch.abs(backend_output - expected_output)).item() max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
max_rel_diff = torch.max( max_rel_diff = torch.max(
torch.abs(backend_output - expected_output) / torch.abs(expected_output) torch.abs(backend_output - expected_output) / torch.abs(expected_output)
).item() ).item()
all_close = torch.allclose( all_close = torch.allclose(
backend_output, expected_output, rtol=rtol, atol=atol backend_output, expected_output, rtol=rtol, atol=atol
) )
assert all_close, ( assert all_close, (
f"[{backend_name}] output differs from SDPA baseline. " f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})" f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
) )
except AssertionError as e:
failures.append(str(e))
# Report all failures at once
if failures:
# Create a summary for the single-line failure message
backend_names = []
for f in failures:
if "[_Backend." in f:
backend_name = f.split("[")[1].split("]")[0]
backend_names.append(backend_name)
summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
detailed_msg = "\n".join(failures)
pytest.fail(f"{summary}\n{detailed_msg}")

View File

@ -285,7 +285,17 @@ full_cg_backend_configs = {
name="CutlassMLA", name="CutlassMLA",
env_vars={ env_vars={
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA", "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS": "1", # TODO: remove this when hang issue is fixed },
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(10, 0),
),
# FlashInfer MLA on Blackwell
"FlashInferMLA": BackendConfig(
name="FlashInferMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",

View File

@ -337,8 +337,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
@ -385,8 +383,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
@ -431,8 +427,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
@ -472,8 +466,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
@ -1988,7 +1980,6 @@ def test_schedule_skip_tokenizer_init():
scheduler.add_request(request) scheduler.add_request(request)
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests) assert len(output.scheduled_new_reqs) == len(requests)
assert output.grammar_bitmask is None
def test_schedule_skip_tokenizer_init_structured_output_request(): def test_schedule_skip_tokenizer_init_structured_output_request():

View File

@ -7,6 +7,7 @@ import torch._dynamo.config as dynamo_config
from vllm import SamplingParams from vllm import SamplingParams
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.sampling_params import StructuredOutputsParams
from ...conftest import VllmRunner from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal from ...models.utils import check_outputs_equal
@ -15,9 +16,12 @@ MODEL = "Qwen/Qwen3-0.6B"
@dynamo_config.patch(cache_size_limit=16) @dynamo_config.patch(cache_size_limit=16)
def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch): def test_preempt_and_async_scheduling_e2e(
sample_json_schema, monkeypatch: pytest.MonkeyPatch
):
"""Test consistency of combos of async scheduling, preemption, """Test consistency of combos of async scheduling, preemption,
uni/multiproc executor, and various sampling parameters.""" uni/multiproc executor, and various sampling parameters
including structured outputs."""
first_prompt = ( first_prompt = (
"The following numbers of the sequence " "The following numbers of the sequence "
@ -35,6 +39,12 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
dict(bad_words=["the", " the"]), dict(bad_words=["the", " the"]),
dict(logprobs=2), dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0), dict(logprobs=2, presence_penalty=-1.0),
dict(structured_outputs=StructuredOutputsParams(json=sample_json_schema)),
dict(
structured_outputs=StructuredOutputsParams(json=sample_json_schema),
logprobs=2,
presence_penalty=-1.0,
),
] ]
default_params = dict( default_params = dict(

View File

@ -248,7 +248,7 @@ def test_engine_core_concurrent_batches():
self, self,
scheduler_output, scheduler_output,
non_block=False, non_block=False,
) -> Future[ModelRunnerOutput]: ) -> Future[ModelRunnerOutput | None]:
"""Make execute_model non-blocking.""" """Make execute_model non-blocking."""
# DummyExecutor used only for testing async case. # DummyExecutor used only for testing async case.
@ -263,6 +263,23 @@ def test_engine_core_concurrent_batches():
# Use the thread pool instead of creating a new thread # Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute) return self.thread_pool.submit(_execute)
def sample_tokens(
self, grammar_output, non_block=False
) -> Future[ModelRunnerOutput]:
"""Make sample_tokens non-blocking."""
# DummyExecutor used only for testing async case.
assert non_block
def _execute():
output = self.collective_rpc("sample_tokens", args=(grammar_output,))
# Make a copy because output[0] may be reused
# by the next batch.
return copy.deepcopy(output[0])
# Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute)
@property @property
def max_concurrent_batches(self) -> int: def max_concurrent_batches(self) -> int:
return 2 return 2

View File

@ -31,7 +31,9 @@ class CustomMultiprocExecutor(MultiprocExecutor):
# Drop marker to show that this was run # Drop marker to show that this was run
with open(".marker", "w"): with open(".marker", "w"):
... ...
return super().collective_rpc(method, timeout, args, kwargs) return super().collective_rpc(
method, timeout, args, kwargs, non_block, unique_reply_rank
)
CustomMultiprocExecutorAsync = CustomMultiprocExecutor CustomMultiprocExecutorAsync = CustomMultiprocExecutor

View File

@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for KV cache offloading configuration."""
import pytest
from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
pytestmark = pytest.mark.cpu_test
@pytest.mark.parametrize(
"kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
[
("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
# bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30) / 4),
("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
# size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
(None, None, 1, 1, None, None),
],
)
def test_kv_connector(
kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
):
kv_transfer_config = (
KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
if expected_backend is not None
else None
)
vllm_config = VllmConfig(
cache_config=CacheConfig(
kv_offloading_backend=kv_offloading_backend,
kv_offloading_size=kv_offloading_size,
),
kv_transfer_config=kv_transfer_config,
parallel_config=ParallelConfig(
tensor_parallel_size=tp, pipeline_parallel_size=pp
),
)
# No KV transfer config expected
if expected_backend is None:
assert vllm_config.kv_transfer_config is expected_backend
return
kv_transfer_config = vllm_config.kv_transfer_config
kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config
assert kv_transfer_config.kv_connector == expected_backend
assert kv_transfer_config.kv_role == "kv_both"
if kv_offloading_backend == "native":
assert kv_connector_extra_config["kv_bytes_per_rank"] == expected_bytes
assert kv_connector_extra_config["num_cpu_blocks"] == 0
# Existing config should be preserved
assert kv_connector_extra_config["existing_key"] == "existing_value"
elif kv_offloading_backend == "lmcache":
assert kv_connector_extra_config["lmcache.local_cpu"] is True
assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes
# Existing config should be replaced
assert "existing_key" not in kv_connector_extra_config

View File

@ -26,8 +26,6 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
kv_connector_metadata=SharedStorageConnectorMetadata(), kv_connector_metadata=SharedStorageConnectorMetadata(),
) )

View File

@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlAgentMetadata, NixlAgentMetadata,
NixlConnector, NixlConnector,
NixlConnectorMetadata, NixlConnectorMetadata,
NixlConnectorScheduler,
NixlConnectorWorker, NixlConnectorWorker,
NixlKVConnectorStats, NixlKVConnectorStats,
) )
@ -283,6 +284,92 @@ def test_prompt_less_than_block_size():
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_kv_transfer_handshake(dist_init):
"""Unit test for basic NixlConnector interface functionality."""
# Test setup, we creates a scheduler that contains a NixlConnector
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
# all workers of the instance.
vllm_config = create_vllm_config()
# in case the test runs on non-GPU machine
vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
scheduler = create_scheduler(vllm_config)
# Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
prefill_connector.register_kv_caches(kv_caches)
# Simulate EngineCore initialization that would
# gather connector metadata from all workers, the scheduler connector
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
# where the first key is the dp_rank, the second key is the tp_rank.
metadata = {0: prefill_connector.get_handshake_metadata()}
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata(metadata)
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
)
assert delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
with patch.object(
decode_connector.connector_worker, "add_remote_agent"
) as mock_add_remote_agent:
mock_add_remote_agent.return_type = "remote_agent"
decode_connector.connector_worker._nixl_handshake(
kv_connector_metadata["remote_host"],
kv_connector_metadata["remote_port"],
kv_connector_metadata["tp_size"],
kv_connector_metadata["remote_engine_id"],
)
received_metadata = mock_add_remote_agent.call_args.args
assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size
assert metadata[0] == received_metadata[0]
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown()
class FakeNixlConnectorWorker(NixlConnectorWorker): class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine" REMOTE_ENGINE_ID = "remote_engine"
@ -313,6 +400,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
engine_id=self.REMOTE_ENGINE_ID, engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA, agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0], kv_caches_base_addr=[0],
device_id=0,
num_blocks=1, num_blocks=1,
block_lens=self.block_len_per_layer, block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name, attn_backend_name=self.backend_name,
@ -559,6 +647,7 @@ class TestNixlHandshake:
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA, agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0], kv_caches_base_addr=[0],
device_id=0,
num_blocks=1, num_blocks=1,
block_lens=worker.block_len_per_layer, block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name, attn_backend_name=worker.backend_name,
@ -611,6 +700,7 @@ class TestNixlHandshake:
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA, agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0], kv_caches_base_addr=[0],
device_id=0,
num_blocks=1, num_blocks=1,
# prefill TP=1, decode TP=2, remote block_lens is double to local # prefill TP=1, decode TP=2, remote block_lens is double to local
block_lens=[i * 2 for i in worker.block_len_per_layer], block_lens=[i * 2 for i in worker.block_len_per_layer],
@ -891,9 +981,7 @@ def test_scheduler_kv_connector_stats_aggregation():
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=[0], num_common_prefix_blocks=[0],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=set(), free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
) )
engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output) engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output)
@ -1005,6 +1093,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params) _ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared! # Request-0 times out and is cleared!
assert "0" not in req_to_blocks assert "0" not in req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port
llm.llm_engine.engine_core.shutdown()
def test_register_kv_caches(dist_init): def test_register_kv_caches(dist_init):
@ -1177,13 +1267,15 @@ def test_shutdown_cleans_up_resources(dist_init):
"""Test that shutdown() properly cleans up all resources.""" """Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
scheduler = NixlConnectorScheduler(
vllm_config, vllm_config.kv_transfer_config.engine_id
)
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id) worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper nixl_wrapper = worker.nixl_wrapper
with ( with (
patch.object(worker, "_handshake_initiation_executor") as mock_exec, patch.object(worker, "_handshake_initiation_executor") as mock_exec,
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener, patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer, patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist, patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent, patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
@ -1204,8 +1296,12 @@ def test_shutdown_cleans_up_resources(dist_init):
worker.shutdown() worker.shutdown()
mock_exec.shutdown.assert_called_with(wait=False) mock_exec.shutdown.assert_called_with(wait=False)
mock_event.set.assert_called_once()
mock_listener.join.assert_called_once_with(timeout=1.0) # Same sequence on scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
mock_listener.join.assert_called_once()
mock_rel_xfer.assert_called_once_with(123) mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2 assert mock_rel_dlist.call_count == 2

View File

@ -92,8 +92,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
@ -171,8 +169,6 @@ def test_update_states_request_finished(model_runner):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids={req_id}, finished_req_ids={req_id},
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
@ -201,8 +197,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
@ -230,8 +224,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
@ -261,8 +253,6 @@ def test_update_states_no_changes(model_runner):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
@ -296,8 +286,6 @@ def test_update_states_request_unscheduled(model_runner):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)

View File

@ -6,6 +6,7 @@ import pytest
import torch import torch
from vllm.attention import Attention from vllm.attention import Attention
from vllm.attention.backends.abstract import MultipleOf
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
ModelConfig, ModelConfig,
@ -34,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import AttentionGroup
BLOCK_SIZE = 16 BLOCK_SIZE = 16
NUM_BLOCKS = 10 NUM_BLOCKS = 10
@ -150,8 +152,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
@ -181,6 +181,57 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
).all() ).all()
def _make_mock_backend_for_kernel_block_size(
supported_sizes: list[int | MultipleOf],
):
class _MockBackend:
@staticmethod
def get_supported_kernel_block_size():
return supported_sizes
return _MockBackend()
def _make_kv_cache_spec() -> FullAttentionSpec:
return FullAttentionSpec(block_size=1, num_kv_heads=1, head_size=1, dtype="float16")
def test_select_common_block_size_prefers_manager_block_size():
backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
selected_size = GPUModelRunner.select_common_block_size(128, attn_groups)
assert selected_size == 128
def test_select_common_block_size_uses_largest_shared_int():
backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
selected_size = GPUModelRunner.select_common_block_size(256, attn_groups)
assert selected_size == 64
def test_select_common_block_size_no_valid_option():
backend_a = _make_mock_backend_for_kernel_block_size([64])
backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
with pytest.raises(ValueError):
GPUModelRunner.select_common_block_size(48, attn_groups)
def test_update_states_new_request(model_runner, dist_init): def test_update_states_new_request(model_runner, dist_init):
req_id = "req_0" req_id = "req_0"
@ -216,8 +267,6 @@ def test_update_states_request_finished(model_runner, dist_init):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids={req_id}, finished_req_ids={req_id},
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
metadata_before = model_runner.input_batch.sampling_metadata metadata_before = model_runner.input_batch.sampling_metadata
@ -248,8 +297,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
@ -277,8 +324,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
metadata_before = model_runner.input_batch.sampling_metadata metadata_before = model_runner.input_batch.sampling_metadata
@ -370,8 +415,6 @@ def test_update_states_no_changes(model_runner, dist_init):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
metadata_before = model_runner.input_batch.sampling_metadata metadata_before = model_runner.input_batch.sampling_metadata
@ -407,8 +450,6 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
metadata_before = model_runner._update_states(scheduler_output) metadata_before = model_runner._update_states(scheduler_output)

View File

@ -270,21 +270,23 @@ class ipex_ops:
@staticmethod @staticmethod
def flash_attn_varlen_func( def flash_attn_varlen_func(
out: torch.Tensor,
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
cu_seqlens_q: torch.Tensor, cu_seqlens_q: torch.Tensor,
seqused_k: torch.Tensor, # we don't support this in ipex kernel
max_seqlen_q: int, max_seqlen_q: int,
max_seqlen_k: int, max_seqlen_k: int,
softmax_scale: float, softmax_scale: float | None = None,
causal: bool, causal: bool = False,
block_table: torch.Tensor, out: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None, block_table: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
window_size: list[int] | None = None, window_size: list[int] | None = None,
softcap: float | None = 0.0, softcap: float | None = 0.0,
seqused_k: torch.Tensor | None = None,
cu_seqlens_k: torch.Tensor | None = None, cu_seqlens_k: torch.Tensor | None = None,
# passed in qwen vl
dropout_p: float = 0.0,
# The following parameters are not used in ipex kernel currently, # The following parameters are not used in ipex kernel currently,
# we keep API compatible to CUDA's. # we keep API compatible to CUDA's.
scheduler_metadata=None, scheduler_metadata=None,
@ -295,31 +297,63 @@ class ipex_ops:
num_splits=0, num_splits=0,
s_aux: torch.Tensor | None = None, s_aux: torch.Tensor | None = None,
): ):
if out is None:
out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
real_window_size: tuple[int, int] real_window_size: tuple[int, int]
if window_size is None: if window_size is None:
real_window_size = (-1, -1) real_window_size = (-1, -1)
else: else:
assert len(window_size) == 2 assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1]) real_window_size = (window_size[0], window_size[1])
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out, if block_table is None:
q.contiguous(), assert cu_seqlens_k is not None, (
k, "cu_seqlens_k can't be None when calling varlen_attention."
v, )
cu_seqlens_q, if softmax_scale is None:
seqused_k, softmax_scale = q.shape[-1] ** (-0.5)
max_seqlen_q, ipex_ops.varlen_attention(
max_seqlen_k, q.contiguous(),
softmax_scale, k.contiguous(),
causal, v.contiguous(),
block_table, out,
alibi_slopes, cu_seqlens_q,
softcap=softcap, cu_seqlens_k,
window_size_left=real_window_size[0], None,
window_size_right=real_window_size[1], max_seqlen_q,
k_scale=1.0, max_seqlen_k,
v_scale=1.0, 0.0,
) softmax_scale,
False,
causal,
False,
None,
real_window_size[0],
real_window_size[1],
-1,
)
return out
else:
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
q.contiguous(),
k,
v,
cu_seqlens_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
block_table,
alibi_slopes,
sink=s_aux,
softcap=softcap,
window_size_left=real_window_size[0],
window_size_right=real_window_size[1],
k_scale=1.0,
v_scale=1.0,
)
@staticmethod @staticmethod
def get_scheduler_metadata( def get_scheduler_metadata(

View File

@ -123,6 +123,11 @@ def maybe_get_vit_flash_attn_backend(
): ):
attn_backend = _Backend.FLASH_ATTN attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True use_upstream_fa = True
elif current_platform.is_xpu():
assert attn_backend == _Backend.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend."
)
use_upstream_fa = False
else: else:
return _Backend.TORCH_SDPA, None return _Backend.TORCH_SDPA, None
@ -133,7 +138,7 @@ def maybe_get_vit_flash_attn_backend(
if use_upstream_fa: if use_upstream_fa:
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
else: else:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else: else:
flash_attn_varlen_func = None flash_attn_varlen_func = None
@ -521,22 +526,18 @@ class MultiHeadAttention(nn.Module):
# If vllm native fa is selected, we use it directly. # If vllm native fa is selected, we use it directly.
use_upstream_fa = False use_upstream_fa = False
if current_platform.is_xpu(): self.attn_backend = (
# currently, only torch_sdpa is supported on xpu backend
self.attn_backend = _Backend.TORCH_SDPA if backend
else: in {
self.attn_backend = ( _Backend.TORCH_SDPA,
backend _Backend.XFORMERS,
if backend _Backend.PALLAS,
in { _Backend.ROCM_AITER_FA,
_Backend.TORCH_SDPA, _Backend.FLASH_ATTN,
_Backend.XFORMERS, }
_Backend.PALLAS, else _Backend.TORCH_SDPA
_Backend.ROCM_AITER_FA, )
_Backend.FLASH_ATTN,
}
else _Backend.TORCH_SDPA
)
self.attn_backend, self._flash_attn_varlen_func = ( self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(

View File

@ -70,7 +70,7 @@ def flash_attn_maxseqlen_wrapper(
if use_upstream_fa: if use_upstream_fa:
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
else: else:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.attention.utils.fa_utils import flash_attn_varlen_func
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func( output = flash_attn_varlen_func(
q, q,

View File

@ -24,6 +24,7 @@ BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
MambaDType = Literal["auto", "float32"] MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"]
@config @config
@ -128,6 +129,17 @@ class CacheConfig:
gpu_memory_utilization. Note that kv_cache_memory_bytes gpu_memory_utilization. Note that kv_cache_memory_bytes
(when not-None) ignores gpu_memory_utilization""" (when not-None) ignores gpu_memory_utilization"""
kv_offloading_size: float | None = None
"""Size of the KV cache offloading buffer in GiB. When TP > 1, this is
the total buffer size summed across all TP ranks. By default, this is set
to None, which means no KV offloading is enabled. When set with
kv_offloading_backend, vLLM will enable KV cache offloading to CPU"""
kv_offloading_backend: KVOffloadingBackend | None = None
"""The backend to use for KV cache offloading. Supported backends include
'native' (vLLM native CPU offloading), 'lmcache' This option must be used
together with kv_offloading_size."""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,

View File

@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib import hashlib
from dataclasses import InitVar, field from collections.abc import Callable
from dataclasses import InitVar
from typing import Any, Literal from typing import Any, Literal
from pydantic import SkipValidation, model_validator from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from typing_extensions import Self from typing_extensions import Self
@ -31,28 +32,28 @@ class SchedulerConfig:
runner_type: RunnerType = "generate" runner_type: RunnerType = "generate"
"""The runner type to launch for the model.""" """The runner type to launch for the model."""
max_num_batched_tokens: SkipValidation[int] = None # type: ignore max_num_batched_tokens: int = Field(default=None, ge=1)
"""Maximum number of tokens to be processed in a single iteration. """Maximum number of tokens to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context.""" be set in `EngineArgs.create_engine_config` based on the usage context."""
max_num_seqs: SkipValidation[int] = None # type: ignore max_num_seqs: int = Field(default=None, ge=1)
"""Maximum number of sequences to be processed in a single iteration. """Maximum number of sequences to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context.""" be set in `EngineArgs.create_engine_config` based on the usage context."""
max_model_len: SkipValidation[int] = None # type: ignore max_model_len: int = Field(default=None, ge=1)
"""Maximum length of a sequence (including prompt and generated text). This """Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually is primarily set in `ModelConfig` and that value should be manually
duplicated here.""" duplicated here."""
max_num_partial_prefills: int = 1 max_num_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of sequences that can be """For chunked prefill, the maximum number of sequences that can be
partially prefilled concurrently.""" partially prefilled concurrently."""
max_long_partial_prefills: int = 1 max_long_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of prompts longer than """For chunked prefill, the maximum number of prompts longer than
long_prefill_token_threshold that will be prefilled concurrently. Setting long_prefill_token_threshold that will be prefilled concurrently. Setting
this less than max_num_partial_prefills will allow shorter prompts to jump this less than max_num_partial_prefills will allow shorter prompts to jump
@ -62,7 +63,7 @@ class SchedulerConfig:
"""For chunked prefill, a request is considered long if the prompt is """For chunked prefill, a request is considered long if the prompt is
longer than this number of tokens.""" longer than this number of tokens."""
num_lookahead_slots: int = 0 num_lookahead_slots: int = Field(default=0, ge=0)
"""The number of slots to allocate per sequence per """The number of slots to allocate per sequence per
step, beyond the known token ids. This is used in speculative step, beyond the known token ids. This is used in speculative
decoding to store KV activations of tokens which may or may not be decoding to store KV activations of tokens which may or may not be
@ -71,7 +72,7 @@ class SchedulerConfig:
NOTE: This will be replaced by speculative config in the future; it is NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then.""" present to enable correctness tests until then."""
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore enable_chunked_prefill: bool = Field(default=None)
"""If True, prefill requests can be chunked based """If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.""" on the remaining max_num_batched_tokens."""
@ -86,14 +87,14 @@ class SchedulerConfig:
""" """
# TODO (ywang96): Make this configurable. # TODO (ywang96): Make this configurable.
max_num_encoder_input_tokens: int = field(init=False) max_num_encoder_input_tokens: int = Field(init=False)
"""Multimodal encoder compute budget, only used in V1. """Multimodal encoder compute budget, only used in V1.
NOTE: This is not currently configurable. It will be overridden by NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger.""" max_num_batched_tokens in case max multimodal embedding size is larger."""
# TODO (ywang96): Make this configurable. # TODO (ywang96): Make this configurable.
encoder_cache_size: int = field(init=False) encoder_cache_size: int = Field(init=False)
"""Multimodal encoder cache size, only used in V1. """Multimodal encoder cache size, only used in V1.
NOTE: This is not currently configurable. It will be overridden by NOTE: This is not currently configurable. It will be overridden by
@ -106,7 +107,7 @@ class SchedulerConfig:
- "priority" means requests are handled based on given priority (lower - "priority" means requests are handled based on given priority (lower
value means earlier handling) and time of arrival deciding any ties).""" value means earlier handling) and time of arrival deciding any ties)."""
chunked_prefill_enabled: bool = field(init=False) chunked_prefill_enabled: bool = Field(init=False)
"""True if chunked prefill is enabled.""" """True if chunked prefill is enabled."""
disable_chunked_mm_input: bool = False disable_chunked_mm_input: bool = False
@ -155,6 +156,20 @@ class SchedulerConfig:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
@field_validator(
"max_num_batched_tokens",
"max_num_seqs",
"max_model_len",
"enable_chunked_prefill",
mode="wrap",
)
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
if value is None:
return value
return handler(value)
def __post_init__(self, is_encoder_decoder: bool) -> None: def __post_init__(self, is_encoder_decoder: bool) -> None:
if self.max_model_len is None: if self.max_model_len is None:
self.max_model_len = 8192 self.max_model_len = 8192
@ -260,19 +275,7 @@ class SchedulerConfig:
self.max_num_seqs * self.max_model_len, self.max_num_seqs * self.max_model_len,
) )
if self.num_lookahead_slots < 0: if self.max_num_partial_prefills > 1:
raise ValueError(
"num_lookahead_slots "
f"({self.num_lookahead_slots}) must be greater than or "
"equal to 0."
)
if self.max_num_partial_prefills < 1:
raise ValueError(
f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
"must be greater than or equal to 1."
)
elif self.max_num_partial_prefills > 1:
if not self.chunked_prefill_enabled: if not self.chunked_prefill_enabled:
raise ValueError( raise ValueError(
"Chunked prefill must be enabled to set " "Chunked prefill must be enabled to set "
@ -286,13 +289,10 @@ class SchedulerConfig:
f"than the max_model_len ({self.max_model_len})." f"than the max_model_len ({self.max_model_len})."
) )
if (self.max_long_partial_prefills < 1) or ( if self.max_long_partial_prefills > self.max_num_partial_prefills:
self.max_long_partial_prefills > self.max_num_partial_prefills
):
raise ValueError( raise ValueError(
f"max_long_partial_prefills ({self.max_long_partial_prefills}) " f"{self.max_long_partial_prefills=} must be less than or equal to "
"must be greater than or equal to 1 and less than or equal to " f"{self.max_num_partial_prefills=}."
f"max_num_partial_prefills ({self.max_num_partial_prefills})."
) )
return self return self

View File

@ -78,10 +78,6 @@ class SpeculativeConfig:
draft_tensor_parallel_size: int | None = Field(default=None, ge=1) draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1 """The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size.""" or the same as the target model's tensor parallel size."""
disable_logprobs: bool = True
"""If set to True, token log probabilities are not returned during
speculative decoding. If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams."""
# Draft model configuration # Draft model configuration
quantization: me_quant.QuantizationMethods | None = None quantization: me_quant.QuantizationMethods | None = None
@ -126,12 +122,6 @@ class SpeculativeConfig:
"""The configuration of the target model.""" """The configuration of the target model."""
target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
"""The parallel configuration for the target model.""" """The parallel configuration for the target model."""
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
"""Whether vLLM is configured to use chunked prefill or not. Used for
raising an error since it's not yet compatible with speculative decode."""
disable_log_stats: SkipValidation[bool] = None # type: ignore
"""Whether to disable the periodic printing of stage times in speculative
decoding."""
# params generated in the post-init stage # params generated in the post-init stage
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore draft_model_config: SkipValidation[ModelConfig] = None # type: ignore

View File

@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib import hashlib
from typing import Any, Literal from typing import Any, Literal, Self
from pydantic import model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
@ -56,7 +57,8 @@ class StructuredOutputsConfig:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
def __post_init__(self): @model_validator(mode="after")
def _validate_structured_output_config(self) -> Self:
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"): if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
raise ValueError( raise ValueError(
"disable_any_whitespace is only supported for " "disable_any_whitespace is only supported for "
@ -67,3 +69,4 @@ class StructuredOutputsConfig:
"disable_additional_properties is only supported " "disable_additional_properties is only supported "
"for the guidance backend." "for the guidance backend."
) )
return self

View File

@ -289,6 +289,48 @@ class VllmConfig:
return replace(self, model_config=model_config) return replace(self, model_config=model_config)
def _post_init_kv_transfer_config(self) -> None:
"""Update KVTransferConfig based on top-level configs in VllmConfig.
Right now, this function reads the offloading settings from
CacheConfig and configures the KVTransferConfig accordingly.
"""
if (kv_offloading_backend := self.cache_config.kv_offloading_backend) is None:
return
# If no KVTransferConfig is provided, create a default one.
if self.kv_transfer_config is None:
self.kv_transfer_config = KVTransferConfig()
if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
raise ValueError(
"You must set kv_offloading_size when kv_offloading_backend is set."
)
num_kv_ranks = (
self.parallel_config.tensor_parallel_size
* self.parallel_config.pipeline_parallel_size
)
if kv_offloading_backend == "native":
self.kv_transfer_config.kv_connector = "OffloadingConnector"
kv_bytes_per_rank = kv_offloading_size * (1 << 30) / num_kv_ranks
# NOTE(ApostaC): the actual calculation for num_cpu_blocks should be
# done after the model's KV cache is initialized
self.kv_transfer_config.kv_connector_extra_config.update(
{"kv_bytes_per_rank": kv_bytes_per_rank, "num_cpu_blocks": 0}
)
elif kv_offloading_backend == "lmcache":
self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
kv_gb_per_rank = kv_offloading_size / num_kv_ranks
self.kv_transfer_config.kv_connector_extra_config = {
"lmcache.local_cpu": True,
"lmcache.max_local_cpu_size": kv_gb_per_rank,
}
# This is the same for all backends
self.kv_transfer_config.kv_role = "kv_both"
def __post_init__(self): def __post_init__(self):
"""Verify configs are valid & consistent with each other.""" """Verify configs are valid & consistent with each other."""
@ -646,6 +688,9 @@ class VllmConfig:
if "-quant_fp8" not in custom_ops: if "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8") custom_ops.append("+quant_fp8")
# Handle the KV connector configs
self._post_init_kv_transfer_config()
def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when # remove the sizes that not multiple of tp_size when
# enable sequence parallelism # enable sequence parallelism

View File

@ -6,7 +6,7 @@ KV cache helper for store.
from collections.abc import Sequence from collections.abc import Sequence
from concurrent.futures import CancelledError, Future from concurrent.futures import CancelledError, Future
from typing import TYPE_CHECKING, Literal, cast from typing import TYPE_CHECKING, Literal
import torch import torch
@ -138,8 +138,11 @@ class KVOutputAggregator:
return cls(connector.get_finished_count() or world_size) return cls(connector.get_finished_count() or world_size)
def aggregate( def aggregate(
self, outputs: list[ModelRunnerOutput], output_rank: int = 0 self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
) -> ModelRunnerOutput: ) -> ModelRunnerOutput | None:
if not outputs[output_rank]:
return None
# Aggregate kv_connector_output from all workers # Aggregate kv_connector_output from all workers
def update_finished_set( def update_finished_set(
@ -161,6 +164,7 @@ class KVOutputAggregator:
aggregated_kv_connector_stats = None aggregated_kv_connector_stats = None
invalid_block_ids = set[int]() invalid_block_ids = set[int]()
for model_runner_output in outputs: for model_runner_output in outputs:
assert model_runner_output is not None
kv_output = model_runner_output.kv_connector_output kv_output = model_runner_output.kv_connector_output
if not kv_output: if not kv_output:
continue continue
@ -204,6 +208,7 @@ class KVOutputAggregator:
# select output of the worker specified by output_rank # select output of the worker specified by output_rank
output = outputs[output_rank] output = outputs[output_rank]
assert output is not None
output.kv_connector_output = KVConnectorOutput( output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None, finished_sending=finished_sending or None,
finished_recving=finished_recving or None, finished_recving=finished_recving or None,
@ -215,13 +220,16 @@ class KVOutputAggregator:
return output return output
def async_aggregate( def async_aggregate(
self, output_futures: Sequence[Future[ModelRunnerOutput]], output_rank: int = 0 self,
) -> Future[ModelRunnerOutput]: output_futures: Sequence[Future[ModelRunnerOutput | None]],
output_rank: int = 0,
) -> Future[ModelRunnerOutput | None]:
"""Takes a list of futures and returns a single future which resolves """Takes a list of futures and returns a single future which resolves
to the respective list of outputs.""" to the respective list of outputs."""
result_future: Future[ModelRunnerOutput] = Future() result_future: Future[ModelRunnerOutput | None] = Future()
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures) outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
remaining = len(output_futures)
def make_callback(idx): def make_callback(idx):
def callback(fut): def callback(fut):
@ -236,12 +244,10 @@ class KVOutputAggregator:
result_future.set_exception(e) result_future.set_exception(e)
# this check assumes io_thread_pool uses a single thread # this check assumes io_thread_pool uses a single thread
if all(outputs): nonlocal remaining
result_future.set_result( remaining -= 1
self.aggregate( if not remaining:
cast(list[ModelRunnerOutput], outputs), output_rank result_future.set_result(self.aggregate(outputs, output_rank))
)
)
return callback return callback

View File

@ -122,6 +122,15 @@ class KVConnectorRole(enum.Enum):
WORKER = 1 WORKER = 1
class KVConnectorHandshakeMetadata(ABC): # noqa: B024
"""
Metadata used for out of band connector handshake between
P/D workers. This needs to serializeable.
"""
pass
class KVConnectorMetadata(ABC): # noqa: B024 class KVConnectorMetadata(ABC): # noqa: B024
""" """
Abstract Metadata used to communicate between the Abstract Metadata used to communicate between the
@ -320,6 +329,18 @@ class KVConnectorBase_V1(ABC):
""" """
return None return None
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.
This metadata is used for out-of-band connector handshake
between P/D workers.
Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
None if no handshake metadata is available.
"""
return None
# ============================== # ==============================
# Scheduler-side methods # Scheduler-side methods
# ============================== # ==============================
@ -477,6 +498,17 @@ class KVConnectorBase_V1(ABC):
""" """
return None return None
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (KVConnectorHandshakeMetadata): the handshake metadata to set.
"""
return None
@classmethod @classmethod
def build_prom_metrics( def build_prom_metrics(
cls, cls,

View File

@ -27,6 +27,7 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp, CopyBlocksOp,
KVConnectorBase_V1, KVConnectorBase_V1,
KVConnectorHandshakeMetadata,
KVConnectorMetadata, KVConnectorMetadata,
KVConnectorRole, KVConnectorRole,
) )
@ -93,15 +94,12 @@ _NIXL_SUPPORTED_DEVICE = {
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
class NixlAgentMetadata( @dataclass
msgspec.Struct, class NixlAgentMetadata(KVConnectorHandshakeMetadata):
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True,
):
engine_id: str engine_id: str
agent_metadata: bytes agent_metadata: bytes
kv_caches_base_addr: list[int] kv_caches_base_addr: list[int]
device_id: int
num_blocks: int num_blocks: int
block_lens: list[int] block_lens: list[int]
attn_backend_name: str attn_backend_name: str
@ -223,6 +221,18 @@ class NixlConnector(KVConnectorBase_V1):
assert self.connector_scheduler is not None assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids) return self.connector_scheduler.request_finished(request, block_ids)
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (dict): the handshake metadata to set.
"""
assert self.connector_scheduler is not None
self.connector_scheduler.set_xfer_handshake_metadata(metadata)
############################################################ ############################################################
# Worker Side Methods # Worker Side Methods
############################################################ ############################################################
@ -299,6 +309,21 @@ class NixlConnector(KVConnectorBase_V1):
def shutdown(self): def shutdown(self):
if self.connector_worker is not None: if self.connector_worker is not None:
self.connector_worker.shutdown() self.connector_worker.shutdown()
if self.connector_scheduler is not None:
self.connector_scheduler.shutdown()
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.
This metadata is used for out-of-band connector handshake
between P/D workers.
Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
None if no handshake metadata is available.
"""
assert self.connector_worker is not None
return self.connector_worker.xfer_handshake_metadata
class NixlConnectorScheduler: class NixlConnectorScheduler:
@ -312,12 +337,16 @@ class NixlConnectorScheduler:
self.side_channel_port = ( self.side_channel_port = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT envs.VLLM_NIXL_SIDE_CHANNEL_PORT
+ vllm_config.parallel_config.data_parallel_rank + vllm_config.parallel_config.data_parallel_rank
* vllm_config.parallel_config.tensor_parallel_size
) )
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu" self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
logger.info("Initializing NIXL Scheduler %s", engine_id) logger.info("Initializing NIXL Scheduler %s", engine_id)
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: threading.Thread | None = None
self._encoded_xfer_handshake_metadata: dict[int, Any] = {}
self._stop_event = threading.Event()
# Requests that need to start recv/send. # Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in # New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker. # the scheduler. Used to make metadata passed to Worker.
@ -330,6 +359,89 @@ class NixlConnectorScheduler:
# remote prefill or aborted. # remote prefill or aborted.
self._reqs_not_processed: set[ReqId] = set() self._reqs_not_processed: set[ReqId] = set()
def shutdown(self):
self._stop_event.set()
if self._nixl_handshake_listener_t is not None:
self._nixl_handshake_listener_t.join()
self._nixl_handshake_listener_t = None
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (dict): the handshake metadata to set.
"""
encoded_data: dict[int, bytes] = {}
encoder = msgspec.msgpack.Encoder()
for tp_rank, rank_metadata in metadata.items():
if not isinstance(rank_metadata, NixlAgentMetadata):
raise ValueError(
"NixlConnectorScheduler expects NixlAgentMetadata for "
"handshake metadata."
)
encoded_data[tp_rank] = encoder.encode(rank_metadata)
logger.debug(
"Tp rank %d: encoded NixlAgentMetadata size: %s bytes",
tp_rank,
str(len(encoded_data[tp_rank])),
)
self._encoded_xfer_handshake_metadata = encoded_data
# Only start the listener when we have metadata to serve.
if self._nixl_handshake_listener_t is None:
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(
encoded_data,
ready_event,
self._stop_event,
self.side_channel_port,
),
daemon=True,
name="nixl_handshake_listener",
)
self._nixl_handshake_listener_t.start()
ready_event.wait() # Wait for listener ZMQ socket to be ready.
@staticmethod
def _nixl_handshake_listener(
encoded_data: dict[int, Any],
ready_event: threading.Event,
stop_event: threading.Event,
port: int,
):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach via HTTP endpoint soon.
# Listen for new requests for metadata.
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
path = make_zmq_path("tcp", host, port)
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
sock.setsockopt(zmq.RCVTIMEO, 1000)
ready_event.set()
while True:
try:
identity, _, msg = sock.recv_multipart()
except zmq.Again:
if stop_event.is_set():
break
continue
# Decode the message which contains (GET_META_MSG, rank)
msg, target_tp_rank = msgspec.msgpack.decode(msg)
logger.debug(
"Received message for tp rank %s",
target_tp_rank,
)
if msg != GET_META_MSG:
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]: ) -> tuple[int, bool]:
@ -537,8 +649,6 @@ class NixlConnectorScheduler:
class NixlConnectorWorker: class NixlConnectorWorker:
"""Implementation of Worker side methods""" """Implementation of Worker side methods"""
_POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms
@dataclass @dataclass
class TpKVTopology: class TpKVTopology:
""" """
@ -651,16 +761,6 @@ class NixlConnectorWorker:
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
# NIXL handshake port.
# NOTE(rob): Within a DP group, each DP rank gets its own
# base port (which is sent in the KVTransferParams).
# Each TP rank listens/queries on the base_port + tp_rank.
self.side_channel_port: int = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
+ vllm_config.parallel_config.data_parallel_rank
* vllm_config.parallel_config.tensor_parallel_size
)
# Metadata. # Metadata.
self.engine_id: EngineId = engine_id self.engine_id: EngineId = engine_id
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
@ -706,6 +806,7 @@ class NixlConnectorWorker:
# Map of engine_id -> kv_caches_base_addr. For TP case, each local # Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker. # rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[EngineId, list[int]] = {} self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
self.device_id: int = 0
# Number of NIXL regions. Currently one region per cache # Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer) # (so 1 per layer for MLA, otherwise 2 per layer)
@ -736,9 +837,8 @@ class NixlConnectorWorker:
# requests that skipped transfer (handshake or transfer failures) # requests that skipped transfer (handshake or transfer failures)
self._failed_recv_reqs: set[ReqId] = set() self._failed_recv_reqs: set[ReqId] = set()
# Background thread for handling new handshake requests. # Handshake metadata of this worker for NIXL transfers.
self._nixl_handshake_listener_t: threading.Thread | None = None self.xfer_handshake_metadata: NixlAgentMetadata | None = None
self._nixl_handshake_listener_stop_event: threading.Event | None = None
# Background thread for initializing new NIXL handshakes. # Background thread for initializing new NIXL handshakes.
self._handshake_initiation_executor = ThreadPoolExecutor( self._handshake_initiation_executor = ThreadPoolExecutor(
# NIXL is not guaranteed to be thread-safe, limit 1 worker. # NIXL is not guaranteed to be thread-safe, limit 1 worker.
@ -790,42 +890,6 @@ class NixlConnectorWorker:
total_num_kv_heads=self.model_config.get_total_num_kv_heads(), total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
) )
@staticmethod
def _nixl_handshake_listener(
metadata: NixlAgentMetadata,
ready_event: threading.Event,
stop_event: threading.Event,
base_port: int,
tp_rank: int,
):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach via HTTP endpoint soon.
encoder = msgspec.msgpack.Encoder()
encoded_data = encoder.encode(metadata)
size_in_bytes = len(encoded_data)
logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes))
# Listen for new requests for metadata.
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
path = make_zmq_path("tcp", host, base_port + tp_rank)
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)
while not stop_event.is_set():
events = dict(
poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000)
)
if sock not in events:
continue
identity, _, msg = sock.recv_multipart()
if msg != GET_META_MSG:
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data))
def _nixl_handshake( def _nixl_handshake(
self, self,
host: str, host: str,
@ -844,16 +908,17 @@ class NixlConnectorWorker:
# Handshake only with the remote TP rank that current local rank will # Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i. # pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size) p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
path = make_zmq_path("tcp", host, port + p_remote_rank) path = make_zmq_path("tcp", host, port)
logger.debug( logger.debug(
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank "Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
) )
# Send query for the request. # Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock: with zmq_ctx(zmq.REQ, path) as sock:
msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank))
# Set receive timeout to 5 seconds to avoid hanging on dead server # Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(GET_META_MSG) sock.send(msg)
metadata_bytes = sock.recv() metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes) metadata = decoder.decode(metadata_bytes)
@ -1042,6 +1107,10 @@ class NixlConnectorWorker:
assert tensor_size_bytes == curr_tensor_size_bytes, ( assert tensor_size_bytes == curr_tensor_size_bytes, (
"All kv cache tensors must have the same size" "All kv cache tensors must have the same size"
) )
# Need to make sure the device ID is non-negative for NIXL,
# Torch uses -1 to indicate CPU tensors while NIXL uses explicit
# memory type.
self.device_id = max(cache.get_device(), 0)
caches_data.append( caches_data.append(
(base_addr, curr_tensor_size_bytes, self.device_id, "") (base_addr, curr_tensor_size_bytes, self.device_id, "")
) )
@ -1139,10 +1208,11 @@ class NixlConnectorWorker:
assert len(self.block_window_per_layer) == self.num_layers assert len(self.block_window_per_layer) == self.num_layers
# After KV Caches registered, listen for new connections. # After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata( self.xfer_handshake_metadata = NixlAgentMetadata(
engine_id=self.engine_id, engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(), agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
device_id=self.device_id,
num_blocks=self.num_blocks, num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer, block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name, attn_backend_name=self.backend_name,
@ -1150,22 +1220,6 @@ class NixlConnectorWorker:
if not self.use_host_buffer if not self.use_host_buffer
else self.host_buffer_kv_cache_layout, else self.host_buffer_kv_cache_layout,
) )
ready_event, stop_event = threading.Event(), threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(
metadata,
ready_event,
stop_event,
self.side_channel_port,
self.tp_rank,
),
daemon=True,
name="nixl_handshake_listener",
)
self._nixl_handshake_listener_t.start()
self._nixl_handshake_listener_stop_event = stop_event
ready_event.wait() # Wait for listener ZMQ socket to be ready.
def add_remote_agent( def add_remote_agent(
self, self,
@ -1267,7 +1321,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes. # self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset addr = base_addr + block_offset + rank_offset
# (addr, len, device id) # (addr, len, device id)
blocks_data.append((addr, kv_block_len, remote_tp_rank)) blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
if self._use_flashinfer: if self._use_flashinfer:
# With FlashInfer index V separately to allow head splitting. # With FlashInfer index V separately to allow head splitting.
@ -1275,7 +1329,9 @@ class NixlConnectorWorker:
block_offset = block_id * nixl_agent_meta.block_lens[i] block_offset = block_id * nixl_agent_meta.block_lens[i]
addr = base_addr + block_offset + rank_offset addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_lens[i] // 2 v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) blocks_data.append(
(v_addr, kv_block_len, nixl_agent_meta.device_id)
)
logger.debug( logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and local rank %s", "Created %s blocks for dst engine %s with remote rank %s and local rank %s",
@ -1843,14 +1899,6 @@ class NixlConnectorWorker:
def shutdown(self): def shutdown(self):
"""Shutdown the connector worker.""" """Shutdown the connector worker."""
self._handshake_initiation_executor.shutdown(wait=False) self._handshake_initiation_executor.shutdown(wait=False)
if self._nixl_handshake_listener_stop_event is not None:
self._nixl_handshake_listener_stop_event.set()
self._nixl_handshake_listener_stop_event = None
if self._nixl_handshake_listener_t is not None:
# Generous timeout to allow the thread to exit
self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10)
assert not self._nixl_handshake_listener_t.is_alive()
self._nixl_handshake_listener_t = None
for handles in self._recving_transfers.values(): for handles in self._recving_transfers.values():
for handle, _ in handles: for handle, _ in handles:
self.nixl_wrapper.release_xfer_handle(handle) self.nixl_wrapper.release_xfer_handle(handle)

View File

@ -54,7 +54,13 @@ from vllm.config import (
VllmConfig, VllmConfig,
get_attr_docs, get_attr_docs,
) )
from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo from vllm.config.cache import (
BlockSize,
CacheDType,
KVOffloadingBackend,
MambaDType,
PrefixCachingHashAlgo,
)
from vllm.config.device import Device from vllm.config.device import Device
from vllm.config.model import ( from vllm.config.model import (
ConvertOption, ConvertOption,
@ -553,6 +559,11 @@ class EngineArgs:
kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
kv_offloading_size: float | None = CacheConfig.kv_offloading_size
kv_offloading_backend: KVOffloadingBackend | None = (
CacheConfig.kv_offloading_backend
)
def __post_init__(self): def __post_init__(self):
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
# without having to manually construct a # without having to manually construct a
@ -896,6 +907,12 @@ class EngineArgs:
cache_group.add_argument( cache_group.add_argument(
"--mamba-block-size", **cache_kwargs["mamba_block_size"] "--mamba-block-size", **cache_kwargs["mamba_block_size"]
) )
cache_group.add_argument(
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
)
cache_group.add_argument(
"--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
)
# Multimodal related configs # Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig) multimodal_kwargs = get_kwargs(MultiModalConfig)
@ -1246,8 +1263,6 @@ class EngineArgs:
self, self,
target_model_config: ModelConfig, target_model_config: ModelConfig,
target_parallel_config: ParallelConfig, target_parallel_config: ParallelConfig,
enable_chunked_prefill: bool,
disable_log_stats: bool,
) -> SpeculativeConfig | None: ) -> SpeculativeConfig | None:
"""Initializes and returns a SpeculativeConfig object based on """Initializes and returns a SpeculativeConfig object based on
`speculative_config`. `speculative_config`.
@ -1267,8 +1282,6 @@ class EngineArgs:
{ {
"target_model_config": target_model_config, "target_model_config": target_model_config,
"target_parallel_config": target_parallel_config, "target_parallel_config": target_parallel_config,
"enable_chunked_prefill": enable_chunked_prefill,
"disable_log_stats": disable_log_stats,
} }
) )
return SpeculativeConfig(**self.speculative_config) return SpeculativeConfig(**self.speculative_config)
@ -1391,6 +1404,8 @@ class EngineArgs:
mamba_cache_dtype=self.mamba_cache_dtype, mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size, mamba_block_size=self.mamba_block_size,
kv_offloading_size=self.kv_offloading_size,
kv_offloading_backend=self.kv_offloading_backend,
) )
ray_runtime_env = None ray_runtime_env = None
@ -1561,8 +1576,6 @@ class EngineArgs:
speculative_config = self.create_speculative_config( speculative_config = self.create_speculative_config(
target_model_config=model_config, target_model_config=model_config,
target_parallel_config=parallel_config, target_parallel_config=parallel_config,
enable_chunked_prefill=self.enable_chunked_prefill,
disable_log_stats=self.disable_log_stats,
) )
# make sure num_lookahead_slots is set appropriately depending on # make sure num_lookahead_slots is set appropriately depending on
@ -1813,7 +1826,7 @@ class EngineArgs:
incremental_prefill_supported = ( incremental_prefill_supported = (
pooling_type is not None pooling_type is not None
and pooling_type.lower() == "last" and pooling_type.lower() == "last"
and is_causal and bool(is_causal)
) )
action = "Enabling" if incremental_prefill_supported else "Disabling" action = "Enabling" if incremental_prefill_supported else "Disabling"

View File

@ -241,6 +241,7 @@ async def build_async_engine_client_from_engine_args(
) )
# Don't keep the dummy data in memory # Don't keep the dummy data in memory
assert async_llm is not None
await async_llm.reset_mm_cache() await async_llm.reset_mm_cache()
yield async_llm yield async_llm

View File

@ -345,22 +345,7 @@ class OpenAIServing:
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError raise NotImplementedError
else:
processed_inputs = processor.input_preprocessor._prompt_to_llm_inputs(
prompt
)
if processed_inputs["type"] == "embeds":
raise NotImplementedError
# This is a workaround to fix multimodal beam search; this is a
# bandaid fix for 2 small problems:
# 1. Multi_modal_data on the processed_inputs currently resolves to
# `None`.
# 2. preprocessing above expands the multimodal placeholders. However,
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
prompt_text: str | None prompt_text: str | None
prompt_token_ids: list[int] prompt_token_ids: list[int]
multi_modal_data: MultiModalDataDict | None multi_modal_data: MultiModalDataDict | None
@ -373,9 +358,16 @@ class OpenAIServing:
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data = prompt.get("multi_modal_data") # type: ignore multi_modal_data = prompt.get("multi_modal_data") # type: ignore
mm_processor_kwargs: dict[str, Any] | None = processed_inputs.get( mm_processor_kwargs: dict[str, Any] | None = None
"mm_processor_kwargs"
) # type: ignore # This is a workaround to fix multimodal beam search; this is a
# bandaid fix for 2 small problems:
# 1. Multi_modal_data on the processed_inputs currently resolves to
# `None`.
# 2. preprocessing above expands the multimodal placeholders. However,
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
tokenized_length = len(prompt_token_ids) tokenized_length = len(prompt_token_ids)

View File

@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json import json
import re
import uuid import uuid
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any from typing import Any
import regex as re
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
DeltaFunctionCall, DeltaFunctionCall,

View File

@ -15,9 +15,7 @@ from vllm.distributed.parallel_state import (
from vllm.lora.layers.base import BaseLayerWithLoRA from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
_get_config_dtype_str, _get_config_dtype_str,
mxfp4_w4a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
modular_marlin_fused_moe, modular_marlin_fused_moe,
@ -26,13 +24,16 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe, modular_triton_fused_moe,
try_get_optimal_moe_config, try_get_optimal_moe_config,
) )
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
class FusedMoEWithLoRA(BaseLayerWithLoRA): class FusedMoEWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: FusedMoE) -> None: def __init__(self, base_layer: FusedMoE) -> None:
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.device = base_layer.w2_weight.device self.device = base_layer.w2_weight.device
@ -42,17 +43,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
moe_state_dict = {} moe_state_dict = {}
top_k = self.base_layer.top_k top_k = self.base_layer.top_k
if self.base_layer.quant_config is None: self.base_layer.ensure_moe_quant_config_init()
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config = self.base_layer.quant_method.moe_quant_config
elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
quant_config = self.base_layer.quant_config
else:
quant_config = mxfp4_w4a16_moe_quant_config(
w1_bias=self.base_layer.w13_bias,
w2_bias=self.base_layer.w2_bias,
w1_scale=self.base_layer.w13_weight_scale,
w2_scale=self.base_layer.w2_weight_scale,
)
m_fused_moe_fn = ( m_fused_moe_fn = (
modular_triton_fused_moe( modular_triton_fused_moe(
@ -69,7 +61,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
moe_state_dict["hidden_states"] = kwargs["hidden_states"] moe_state_dict["hidden_states"] = kwargs["hidden_states"]
moe_state_dict["topk_ids"] = kwargs["topk_ids"] moe_state_dict["topk_ids"] = kwargs["topk_ids"]
moe_state_dict["topk_weights"] = kwargs["topk_weights"] moe_state_dict["topk_weights"] = kwargs["topk_weights"]
moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
moe_state_dict["expert_map"] = kwargs["expert_map"] moe_state_dict["expert_map"] = kwargs["expert_map"]
moe_state_dict["apply_router_weight_on_input"] = kwargs[ moe_state_dict["apply_router_weight_on_input"] = kwargs[
"apply_router_weight_on_input" "apply_router_weight_on_input"
@ -86,7 +77,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
hidden_states = moe_state_dict["hidden_states"] hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"] topk_weights = moe_state_dict["topk_weights"]
curr_topk_ids = moe_state_dict["topk_ids"] curr_topk_ids = moe_state_dict["topk_ids"]
global_num_experts = moe_state_dict["global_num_experts"]
expert_map = moe_state_dict["expert_map"] expert_map = moe_state_dict["expert_map"]
config_dtype = _get_config_dtype_str( config_dtype = _get_config_dtype_str(
@ -118,7 +109,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
curr_topk_ids, curr_topk_ids,
num_tokens, num_tokens,
config["BLOCK_SIZE_M"], config["BLOCK_SIZE_M"],
global_num_experts, self.base_layer.local_num_experts,
max_loras, max_loras,
expert_map, expert_map,
) )
@ -236,14 +227,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
) -> None: ) -> None:
"""Initializes lora matrices.""" """Initializes lora matrices."""
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
self.w1_lora_a_stacked = torch.zeros( self.w1_lora_a_stacked = torch.zeros(
( (
max_loras, max_loras,
self.base_layer.global_num_experts, self.base_layer.local_num_experts,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.hidden_size, self.base_layer.hidden_size,
), ),
@ -253,7 +240,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w1_lora_b_stacked = torch.zeros( self.w1_lora_b_stacked = torch.zeros(
( (
max_loras, max_loras,
self.base_layer.global_num_experts, self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition, self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank, lora_config.max_lora_rank,
), ),
@ -264,7 +251,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w2_lora_a_stacked = torch.zeros( self.w2_lora_a_stacked = torch.zeros(
( (
max_loras, max_loras,
self.base_layer.global_num_experts, self.base_layer.local_num_experts,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition, self.base_layer.intermediate_size_per_partition,
), ),
@ -274,7 +261,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w2_lora_b_stacked = torch.zeros( self.w2_lora_b_stacked = torch.zeros(
( (
max_loras, max_loras,
self.base_layer.global_num_experts, self.base_layer.local_num_experts,
self.base_layer.hidden_size, self.base_layer.hidden_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
), ),
@ -285,7 +272,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w3_lora_a_stacked = torch.zeros( self.w3_lora_a_stacked = torch.zeros(
( (
max_loras, max_loras,
self.base_layer.global_num_experts, self.base_layer.local_num_experts,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.hidden_size, self.base_layer.hidden_size,
), ),
@ -295,7 +282,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w3_lora_b_stacked = torch.zeros( self.w3_lora_b_stacked = torch.zeros(
( (
max_loras, max_loras,
self.base_layer.global_num_experts, self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition, self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank, lora_config.max_lora_rank,
), ),
@ -308,7 +295,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked = [] self.lora_a_stacked = []
self.lora_b_stacked = [] self.lora_b_stacked = []
for lora_id in range(max_loras): for lora_id in range(max_loras):
for experts_id in range(self.base_layer.global_num_experts): for experts_id in range(self.base_layer.local_num_experts):
# gate_proj,down_proj,up_proj # gate_proj,down_proj,up_proj
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id]) self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id]) self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])

View File

@ -88,14 +88,17 @@ def _fused_moe_lora_kernel(
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n # calculate pid_m,pid_n
pid_sk = pid % SPLIT_K
pid_m_n = pid // SPLIT_K
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group group_id = pid_m_n // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m pid_n = (pid_m_n % num_pid_in_group) // group_size_m
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx) num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
@ -113,7 +116,7 @@ def _fused_moe_lora_kernel(
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_ind = stride_tl * lora_idx + offs_token_id token_ind = stride_tl * lora_idx + offs_token_id
@ -131,7 +134,8 @@ def _fused_moe_lora_kernel(
cur_b_ptr cur_b_ptr
+ lora_idx * stride_bl + lora_idx * stride_bl
+ expert_id * stride_be + expert_id * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
) )
# accumulator # accumulator

View File

@ -56,6 +56,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
ep_size: int = 1, ep_size: int = 1,
tp_rank: int = 0, tp_rank: int = 0,
tp_size: int = 1, tp_size: int = 1,
use_dp: bool = False,
): ):
super().__init__(quant_config) super().__init__(quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), ( assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
@ -67,6 +68,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.tp_size = tp_size self.tp_size = tp_size
self.out_dtype = out_dtype self.out_dtype = out_dtype
self.use_dp = use_dp
@property @property
def activation_formats( def activation_formats(
@ -117,7 +119,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
""" """
workspace1 = (M, K) workspace1 = (M, K)
workspace2 = (0,) workspace2 = (0,)
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K) # For TP, the quantization is fused with fused_moe call.
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
# The workspace is determined by `aq`, since it comes after any # The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation. # potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape) return (workspace1, workspace2, output_shape)
@ -214,6 +217,7 @@ def flashinfer_cutlass_moe_fp4(
FlashInferExperts( FlashInferExperts(
out_dtype=hidden_states.dtype, out_dtype=hidden_states.dtype,
quant_config=quant_config, quant_config=quant_config,
use_dp=False,
), ),
) )

View File

@ -170,6 +170,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
self._apply_router_weight_on_input( self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input a1, topk_weights, topk_ids, apply_router_weight_on_input
) )
if not self.use_dp:
return a1, None, None, topk_ids, topk_weights
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
@ -179,14 +181,13 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
quant_config.block_shape, quant_config.block_shape,
is_fp4_scale_swizzled=not self.use_dp, is_fp4_scale_swizzled=not self.use_dp,
) )
if self.use_dp: topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv( [topk_weights, topk_ids, a1q, a1q_scale],
[topk_weights, topk_ids, a1q, a1q_scale], dim=0,
dim=0, sizes=get_local_sizes(),
sizes=get_local_sizes(), )
) if quant_config.quant_dtype == "nvfp4":
if quant_config.quant_dtype == "nvfp4": a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights return a1q, a1q_scale, None, topk_ids, topk_weights

View File

@ -672,8 +672,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
elif self.fused_experts is not None: elif self.fused_experts is not None:
if self.moe.has_bias:
raise ValueError("FusedMoEModularKernel does not support bias.")
result = self.fused_experts( result = self.fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,

View File

@ -40,18 +40,36 @@ logger = init_logger(__name__)
def kda_attention( def kda_attention(
hidden_states: torch.Tensor, q_proj_states: torch.Tensor,
output: torch.Tensor, k_proj_states: torch.Tensor,
v_proj_states: torch.Tensor,
g1: torch.Tensor,
g2: torch.Tensor,
beta: torch.Tensor,
core_attn_out: torch.Tensor,
layer_name: str, layer_name: str,
) -> None: ) -> None:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states, output=output) self._forward(
q_proj_states=q_proj_states,
k_proj_states=k_proj_states,
v_proj_states=v_proj_states,
g1=g1,
g2=g2,
beta=beta,
core_attn_out=core_attn_out,
)
def kda_attention_fake( def kda_attention_fake(
hidden_states: torch.Tensor, q_proj_states: torch.Tensor,
output: torch.Tensor, k_proj_states: torch.Tensor,
v_proj_states: torch.Tensor,
g1: torch.Tensor,
g2: torch.Tensor,
beta: torch.Tensor,
core_attn_out: torch.Tensor,
layer_name: str, layer_name: str,
) -> None: ) -> None:
return return
@ -60,7 +78,7 @@ def kda_attention_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="kda_attention", op_name="kda_attention",
op_func=kda_attention, op_func=kda_attention,
mutates_args=["output"], mutates_args=["core_attn_out"],
fake_impl=kda_attention_fake, fake_impl=kda_attention_fake,
) )
@ -242,36 +260,54 @@ class KimiDeltaAttention(nn.Module, MambaBase):
positions: torch.Tensor, positions: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
) -> None: ) -> None:
return torch.ops.vllm.kda_attention( num_tokens = hidden_states.size(0)
hidden_states, q = self.q_proj(hidden_states)[0]
output, k = self.k_proj(hidden_states)[0]
v = self.v_proj(hidden_states)[0]
beta = self.b_proj(hidden_states)[0].float().sigmoid()
g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
beta = beta.unsqueeze(0)
g1 = g1.unsqueeze(0)
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
core_attn_out = torch.zeros(
(1, num_tokens, self.local_num_heads, self.head_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops.vllm.kda_attention(
q,
k,
v,
g1,
g2,
beta,
core_attn_out,
self.prefix, self.prefix,
) )
core_attn_out = self.o_norm(core_attn_out, g2)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
output[:] = self.o_proj(core_attn_out)[0]
def _forward( def _forward(
self, self,
hidden_states: torch.Tensor, q_proj_states: torch.Tensor,
output: torch.Tensor, k_proj_states: torch.Tensor,
v_proj_states: torch.Tensor,
g1: torch.Tensor,
g2: torch.Tensor,
beta: torch.Tensor,
core_attn_out: torch.Tensor,
) -> None: ) -> None:
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is None: if attn_metadata is None:
# V1 profile run # # V1 profile run
# Mimic the memory allocation in the real run
q = torch.empty_like(hidden_states)
k = torch.empty_like(hidden_states)
v = torch.empty_like(hidden_states)
g = hidden_states.new_empty(
hidden_states.size(0),
self.local_num_heads,
self.head_dim,
dtype=torch.float32,
)
beta = torch.empty(
hidden_states.size(0), self.local_num_heads, dtype=torch.float32
)
core_attn_out = torch.empty_like(hidden_states)
return return
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
@ -288,10 +324,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
conv_state_k = conv_state_k.transpose(-1, -2) conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2) conv_state_v = conv_state_v.transpose(-1, -2)
q_proj_states = self.q_proj(hidden_states)[0]
k_proj_states = self.k_proj(hidden_states)[0]
v_proj_states = self.v_proj(hidden_states)[0]
q_conv_weights = self.q_conv1d.weight.view( q_conv_weights = self.q_conv1d.weight.view(
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2) self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
) )
@ -374,14 +406,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v) lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
) )
beta = self.b_proj(hidden_states)[0].float().sigmoid()
g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
zero_idx = non_spec_state_indices_tensor[~has_initial_state] zero_idx = non_spec_state_indices_tensor[~has_initial_state]
recurrent_state[zero_idx] = 0 recurrent_state[zero_idx] = 0
@ -393,7 +417,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
q=q, q=q,
k=k, k=k,
v=v, v=v,
g=g, g=g1,
beta=beta, beta=beta,
initial_state=initial_state, initial_state=initial_state,
output_final_state=True, output_final_state=True,
@ -410,17 +434,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
q=q, q=q,
k=k, k=k,
v=v, v=v,
g=g, g=g1,
beta=beta, beta=beta,
initial_state=recurrent_state, initial_state=recurrent_state,
use_qk_l2norm_in_kernel=True, use_qk_l2norm_in_kernel=True,
cu_seqlens=non_spec_query_start_loc, cu_seqlens=non_spec_query_start_loc,
ssm_state_indices=non_spec_state_indices_tensor, ssm_state_indices=non_spec_state_indices_tensor,
) )
assert core_attn_out_non_spec.shape == core_attn_out.shape
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0] core_attn_out[:] = core_attn_out_non_spec
g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
core_attn_out = self.o_norm(core_attn_out_non_spec, g)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
output[:] = self.o_proj(core_attn_out)[0]

View File

@ -1769,29 +1769,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
elif (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
)
assert self.moe_quant_config is not None
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else: else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
# only (no EP). # only (no EP).

View File

@ -79,6 +79,7 @@ def select_nvfp4_gemm_impl(
ep_size=moe.moe_parallel_config.ep_size, ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank, tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size, tp_size=moe.moe_parallel_config.tp_size,
use_dp=moe.moe_parallel_config.dp_size > 1,
) )
# native cutlass experts currently don't support DP; TP case won't call this # native cutlass experts currently don't support DP; TP case won't call this

View File

@ -26,6 +26,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights.""" """Inference-only GLM-4V model compatible with HuggingFace weights."""
import itertools
import math import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial from functools import partial
@ -36,7 +37,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import BatchFeature from transformers import BatchFeature, PretrainedConfig
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
from transformers.models.glm4v.image_processing_glm4v import ( from transformers.models.glm4v.image_processing_glm4v import (
Glm4vImageProcessor, Glm4vImageProcessor,
@ -89,6 +90,7 @@ from ..layers.activation import SiluAndMul
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsLoRA, SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
) )
@ -1386,7 +1388,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
dummy_inputs=Glm4vDummyInputsBuilder, dummy_inputs=Glm4vDummyInputsBuilder,
) )
class Glm4vForConditionalGeneration( class Glm4vForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
): ):
merge_by_field_config = True merge_by_field_config = True
@ -1613,6 +1615,149 @@ class Glm4vForConditionalGeneration(
multimodal_embeddings += tuple(video_embeddings) multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: "PretrainedConfig",
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for GLM4V."""
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = []
if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_frame_num,
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
for t_idx in range(llm_grid_t):
t_index = (
torch.tensor(t_idx)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,

View File

@ -17,7 +17,9 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
@ -56,12 +58,14 @@ from vllm.multimodal.processing import (
PromptUpdate, PromptUpdate,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsLoRA, SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
) )
@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous() cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous()
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
elif current_platform.is_rocm():
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
@ -398,18 +405,28 @@ class KeyeSiglipAttention(nn.Module):
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
self.use_upstream_fa = False self.attn_backend, self.flash_attn_varlen_func = (
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( maybe_get_vit_flash_attn_backend(
torch.get_default_dtype() self.attn_backend,
): use_upstream_fa=False,
self.attn_backend = _Backend.FLASH_ATTN attn_backend_override=attn_backend_override,
self.use_upstream_fa = True )
)
if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
}:
raise RuntimeError( raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now." f"Keye-VL does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
}
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -457,15 +474,10 @@ class KeyeSiglipAttention(nn.Module):
self.head_dim, self.head_dim,
) )
if self.attn_backend == _Backend.FLASH_ATTN: if self.is_flash_attn_backend:
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func( output = self.flash_attn_varlen_func(
q, q,
k, k,
v, v,
@ -1542,7 +1554,7 @@ class BaseKeyeModule(nn.Module):
dummy_inputs=KeyeDummyInputsBuilder, dummy_inputs=KeyeDummyInputsBuilder,
) )
class KeyeForConditionalGeneration( class KeyeForConditionalGeneration(
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
): ):
def _build_projector( def _build_projector(
self, self,
@ -1611,3 +1623,142 @@ class KeyeForConditionalGeneration(
return tuple( return tuple(
self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos) self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
) )
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
seq_len: int | None = None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
video_grid_thw = video_grid_thw[0]
"""Get mrope input positions and delta value (Keye series)."""
def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
"""
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
if grid_thw.numel() == 0:
return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1]
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()
video_grid_thw = split_thw(video_grid_thw)
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
image_nums = len(image_grid_thw)
frame_nums = len(video_grid_thw)
llm_pos_ids_list: list = []
st = 0
remain_images, remain_frames = image_nums, frame_nums
image_index, video_index = 0, 0
for _ in range(image_nums + frame_nums):
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_frames > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_frames -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
t_index = (
(
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
)
.long()
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta

View File

@ -22,7 +22,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
@ -61,7 +60,7 @@ class KimiMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
quant_config: QKVParallelLinear | None = None, quant_config: QuantizationConfig | None = None,
reduce_results: bool = True, reduce_results: bool = True,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
@ -155,6 +154,7 @@ class KimiMoE(nn.Module):
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=False, reduce_results=False,
prefix=f"{prefix}.shared_experts",
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -340,7 +340,7 @@ class KimiDecoderLayer(nn.Module):
self.block_sparse_moe = KimiMoE( self.block_sparse_moe = KimiMoE(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.block_sparse_moe",
) )
self.mlp = self.block_sparse_moe self.mlp = self.block_sparse_moe
else: else:

View File

@ -49,7 +49,7 @@ from functools import cached_property
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers.activations import ACT2FN, PytorchGELUTanh from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_flash_attn_2_available from transformers.utils import is_flash_attn_2_available
@ -651,7 +651,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
"num_heads": config.num_attention_heads, "num_heads": config.num_attention_heads,
"hidden_dim": config.hidden_size, "hidden_dim": config.hidden_size,
"mlp_dim": config.intermediate_size, "mlp_dim": config.intermediate_size,
"activation": PytorchGELUTanh(), "activation": ACT2FN["gelu_pytorch_tanh"],
"attn_bias": True, "attn_bias": True,
"attn_implementation": config._attn_implementation, "attn_implementation": config._attn_implementation,
}, },

View File

@ -364,6 +364,8 @@ class Qwen2_5_VisionAttention(nn.Module):
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
self.use_upstream_fa = True self.use_upstream_fa = True
if current_platform.is_xpu():
self.use_upstream_fa = False
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA, _Backend.ROCM_AITER_FA,
@ -856,10 +858,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device) max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device)
if ( if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == _Backend.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

View File

@ -34,7 +34,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from transformers import AutoConfig, BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
from transformers.models.qwen2_vl.configuration_qwen2_vl import ( from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLConfig, Qwen2VLConfig,
@ -789,10 +789,7 @@ class Qwen2VisionTransformer(nn.Module):
self, cu_seqlens: torch.Tensor self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]: ) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None max_seqlen, seqlens = None, None
if ( if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
@ -1654,9 +1651,7 @@ class Tarsier2Processor(Qwen2VLProcessor):
class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo): class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
def get_hf_config(self) -> Qwen2VLConfig: def get_hf_config(self) -> Qwen2VLConfig:
model_path = self.ctx.model_config.model model_path = self.ctx.model_config.model
original_config = AutoConfig.from_pretrained(model_path) correct_config = Qwen2VLConfig.from_pretrained(model_path)
config_dict = original_config.to_dict()
correct_config = Qwen2VLConfig.from_dict(config_dict)
return correct_config return correct_config

View File

@ -115,6 +115,12 @@ class XPUPlatform(Platform):
device_props = torch.xpu.get_device_properties(device_id) device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory return device_props.total_memory
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from vllm.attention.backends.registry import _Backend
return _Backend.FLASH_ATTN
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):
return torch.no_grad() return torch.no_grad()

View File

@ -896,6 +896,8 @@ def get_kernel_options(
return kernel_options return kernel_options
else: else:
preferred_block = 32 if query.dtype == torch.float32 else 64 preferred_block = 32 if query.dtype == torch.float32 else 64
block_lower_bound = 16
block_m_candidate = ensure_divisible(preferred_block, block_m) block_m_candidate = ensure_divisible(preferred_block, block_m)
block_n_candidate = ensure_divisible(preferred_block, block_n) block_n_candidate = ensure_divisible(preferred_block, block_n)
@ -910,6 +912,9 @@ def get_kernel_options(
max(1, block_n_candidate // 2), block_n max(1, block_n_candidate // 2), block_n
) )
block_m_candidate = max(block_m_candidate, block_lower_bound)
block_n_candidate = max(block_n_candidate, block_lower_bound)
kernel_options["BLOCK_M"] = block_m_candidate kernel_options["BLOCK_M"] = block_m_candidate
kernel_options["BLOCK_N"] = block_n_candidate kernel_options["BLOCK_N"] = block_n_candidate

View File

@ -6,7 +6,7 @@ from typing import ClassVar
import torch import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
@ -40,6 +40,10 @@ class FlashInferMLABackend(MLACommonBackend):
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]: def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
return FlashInferMLAMetadataBuilder return FlashInferMLAMetadataBuilder
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
return [32, 64]
g_fi_workspace = torch.zeros( g_fi_workspace = torch.zeros(
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE, FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,

View File

@ -15,8 +15,12 @@ class AsyncScheduler(Scheduler):
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
) -> None: ) -> None:
super()._update_after_schedule(scheduler_output) super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
for req_id in scheduler_output.num_scheduled_tokens: for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id] request = self.requests[req_id]
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
if ( if (
request.num_computed_tokens request.num_computed_tokens
== request.num_tokens + request.num_output_placeholders == request.num_tokens + request.num_output_placeholders
@ -25,6 +29,10 @@ class AsyncScheduler(Scheduler):
# TODO(woosuk): Support speculative decoding. # TODO(woosuk): Support speculative decoding.
request.num_output_placeholders += 1 request.num_output_placeholders += 1
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
)
def _update_request_with_output( def _update_request_with_output(
self, self,
request: Request, request: Request,

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@ -40,6 +40,12 @@ class SchedulerInterface(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def get_grammar_bitmask(
self, scheduler_output: "SchedulerOutput"
) -> "GrammarOutput | None":
raise NotImplementedError
@abstractmethod @abstractmethod
def update_from_output( def update_from_output(
self, self,

View File

@ -181,12 +181,17 @@ class SchedulerOutput:
# freed from the encoder cache. # freed from the encoder cache.
free_encoder_mm_hashes: list[str] free_encoder_mm_hashes: list[str]
# ids of structured outputs requests included in the bitmask, in the # Whether the scheduled requests have all the output tokens they
# same order as the corresponding stacked rows of the bitmask. # need to perform grammar bitmask computation.
# There may be more than one row per request in the case of speculative decoding. pending_structured_output_tokens: bool = False
structured_output_request_ids: list[str]
# the bitmask for the whole batch
grammar_bitmask: "npt.NDArray[np.int32] | None"
# KV Cache Connector metadata. # KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None kv_connector_metadata: KVConnectorMetadata | None = None
@dataclass
class GrammarOutput:
# ids of structured output requests.
structured_output_request_ids: list[str]
# Bitmask ordered as structured_output_request_ids.
grammar_bitmask: "npt.NDArray[np.int32]"

View File

@ -5,7 +5,7 @@ import itertools
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING, Any from typing import Any
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@ -24,7 +24,12 @@ from vllm.v1.core.encoder_cache_manager import (
) )
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.core.sched.output import (
CachedRequestData,
GrammarOutput,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
@ -35,10 +40,6 @@ from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
logger = init_logger(__name__) logger = init_logger(__name__)
@ -619,9 +620,6 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens, scheduled_spec_decode_tokens,
req_to_new_blocks, req_to_new_blocks,
) )
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
)
# Record the request ids that were scheduled in this step. # Record the request ids that were scheduled in this step.
self.prev_step_scheduled_req_ids.clear() self.prev_step_scheduled_req_ids.clear()
@ -641,8 +639,6 @@ class Scheduler(SchedulerInterface):
# the previous and the current steps. # the previous and the current steps.
finished_req_ids=self.finished_req_ids, finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
) )
# NOTE(Kuntai): this function is designed for multiple purposes: # NOTE(Kuntai): this function is designed for multiple purposes:
@ -872,9 +868,8 @@ class Scheduler(SchedulerInterface):
def get_grammar_bitmask( def get_grammar_bitmask(
self, self,
scheduled_request_ids: Iterable[str], scheduler_output: SchedulerOutput,
scheduled_spec_decode_tokens: dict[str, list[int]], ) -> GrammarOutput | None:
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
# Collect list of scheduled request ids that use structured output. # Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order. # The corresponding rows of the bitmask will be in this order.
# PERF: in case of chunked prefill, # PERF: in case of chunked prefill,
@ -883,18 +878,18 @@ class Scheduler(SchedulerInterface):
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids = [ structured_output_request_ids = [
req_id req_id
for req_id in scheduled_request_ids for req_id in scheduler_output.num_scheduled_tokens
if (req := self.requests.get(req_id)) and req.use_structured_output if (req := self.requests.get(req_id)) and req.use_structured_output
] ]
if not structured_output_request_ids: if not structured_output_request_ids:
return structured_output_request_ids, None return None
bitmask = self.structured_output_manager.grammar_bitmask( bitmask = self.structured_output_manager.grammar_bitmask(
self.requests, self.requests,
structured_output_request_ids, structured_output_request_ids,
scheduled_spec_decode_tokens, scheduler_output.scheduled_spec_decode_tokens,
) )
return structured_output_request_ids, bitmask return GrammarOutput(structured_output_request_ids, bitmask)
def update_from_output( def update_from_output(
self, self,

View File

@ -12,7 +12,7 @@ from concurrent.futures import Future
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from inspect import isclass, signature from inspect import isclass, signature
from logging import DEBUG from logging import DEBUG
from typing import Any, TypeVar from typing import Any, TypeVar, cast
import msgspec import msgspec
import zmq import zmq
@ -163,6 +163,27 @@ class EngineCore:
vllm_config, mm_registry vllm_config, mm_registry
) )
# If a KV connector is initialized for scheduler, we want to collect
# handshake metadata from all workers so the connector in the scheduler
# will have the full context
kv_connector = self.scheduler.get_kv_connector()
if kv_connector is not None:
# Collect and store KV connector xfer metadata from workers
# (after KV cache registration)
xfer_handshake_metadata = (
self.model_executor.get_kv_connector_handshake_metadata()
)
if xfer_handshake_metadata:
# xfer_handshake_metadata is list of dicts from workers
# Each dict already has structure {tp_rank: metadata}
# Merge all worker dicts into a single dict
content: dict[int, Any] = {}
for worker_dict in xfer_handshake_metadata:
if worker_dict is not None:
content.update(worker_dict)
kv_connector.set_xfer_handshake_metadata(content)
# Setup batch queue for pipeline parallelism. # Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously # Batch queue for scheduled batches. This enables us to asynchronously
# schedule and execute batches, and is required by pipeline parallelism # schedule and execute batches, and is required by pipeline parallelism
@ -178,7 +199,7 @@ class EngineCore:
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
if ( if (
self.vllm_config.cache_config.enable_prefix_caching self.vllm_config.cache_config.enable_prefix_caching
or self.scheduler.get_kv_connector() is not None or kv_connector is not None
): ):
caching_hash_fn = get_hash_fn_by_name( caching_hash_fn = get_hash_fn_by_name(
vllm_config.cache_config.prefix_caching_hash_algo vllm_config.cache_config.prefix_caching_hash_algo
@ -313,9 +334,12 @@ class EngineCore:
if not self.scheduler.has_requests(): if not self.scheduler.has_requests():
return {}, False return {}, False
scheduler_output = self.scheduler.schedule() scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output): with self.log_error_detail(scheduler_output):
model_output = self.model_executor.execute_model(scheduler_output) model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
engine_core_outputs = self.scheduler.update_from_output( engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output scheduler_output, model_output
@ -355,20 +379,47 @@ class EngineCore:
assert len(batch_queue) < self.batch_queue_size assert len(batch_queue) < self.batch_queue_size
model_executed = False model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests(): if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule() scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True) exec_future = self.model_executor.execute_model(
batch_queue.appendleft((future, scheduler_output)) scheduler_output, non_block=True
)
model_executed = scheduler_output.total_num_scheduled_tokens > 0 model_executed = scheduler_output.total_num_scheduled_tokens > 0
if (
model_executed if scheduler_output.pending_structured_output_tokens:
and len(batch_queue) < self.batch_queue_size # We need to defer sampling until we have processed the model output
and not batch_queue[-1][0].done() # from the prior step.
): deferred_scheduler_output = scheduler_output
# Don't block on next worker response unless the queue is full # Block-wait for execute to return (continues running async on the GPU).
# or there are no more requests to schedule. with self.log_error_detail(scheduler_output):
return None, True exec_result = exec_future.result()
assert exec_result is None
else:
# We aren't waiting for any tokens, get any grammar output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
if exec_result is None:
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
else:
# No sampling required (e.g. all requests finished).
future = cast(Future[ModelRunnerOutput], exec_future)
# Add this step's future to the queue.
batch_queue.appendleft((future, scheduler_output))
if (
model_executed
and len(batch_queue) < self.batch_queue_size
and not batch_queue[-1][0].done()
):
# Don't block on next worker response unless the queue is full
# or there are no more requests to schedule.
return None, True
elif not batch_queue: elif not batch_queue:
# Queue is empty. We should not reach here since this method should # Queue is empty. We should not reach here since this method should
@ -384,6 +435,19 @@ class EngineCore:
engine_core_outputs = self.scheduler.update_from_output( engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output scheduler_output, model_output
) )
# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
batch_queue.appendleft((future, deferred_scheduler_output))
return engine_core_outputs, model_executed return engine_core_outputs, model_executed
def shutdown(self): def shutdown(self):

View File

@ -9,11 +9,14 @@ from typing import TYPE_CHECKING, Literal, TypeVar, overload
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.engine import ReconfigureDistributedRequest
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@ -177,30 +180,51 @@ class Executor(ABC):
): ):
raise NotImplementedError raise NotImplementedError
def get_kv_connector_handshake_metadata(
self,
) -> list[dict[int, KVConnectorHandshakeMetadata]]:
return self.collective_rpc("get_kv_connector_handshake_metadata")
@overload @overload
def execute_model( def execute_model(
self, self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False
scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput | None:
non_block: Literal[False] = False,
) -> ModelRunnerOutput:
pass pass
@overload @overload
def execute_model( def execute_model(
self, self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True
scheduler_output: SchedulerOutput, ) -> Future[ModelRunnerOutput | None]:
non_block: Literal[True] = True,
) -> Future[ModelRunnerOutput]:
pass pass
def execute_model( def execute_model(
self, scheduler_output: SchedulerOutput, non_block: bool = False self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]: ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
output = self.collective_rpc( # type: ignore[call-overload] output = self.collective_rpc( # type: ignore[call-overload]
"execute_model", args=(scheduler_output,), non_block=non_block "execute_model", args=(scheduler_output,), non_block=non_block
) )
return output[0] return output[0]
@overload
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False
) -> ModelRunnerOutput:
pass
@overload
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True
) -> Future[ModelRunnerOutput]:
pass
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
output = self.collective_rpc( # type: ignore[call-overload]
"sample_tokens", args=(grammar_output,), non_block=non_block
)
return output[0]
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch") self.collective_rpc("execute_dummy_batch")

View File

@ -46,7 +46,7 @@ from vllm.utils.system_utils import (
get_mp_context, get_mp_context,
set_process_title, set_process_title,
) )
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase from vllm.v1.worker.worker_base import WorkerWrapperBase
@ -132,15 +132,12 @@ class MultiprocExecutor(Executor):
uw.death_writer.close() uw.death_writer.close()
self._ensure_worker_termination([uw.proc for uw in unready_workers]) self._ensure_worker_termination([uw.proc for uw in unready_workers])
# For pipeline parallel, we use a thread pool for asynchronous # Note: must use only 1 IO thread to keep dequeue sequence
# execute_model. # from the response queue.
if self.max_concurrent_batches > 1: # _async_aggregate_workers_output also assumes a single IO thread.
# Note: must use only 1 IO thread to keep dequeue sequence self.io_thread_pool = ThreadPoolExecutor(
# from the response queue max_workers=1, thread_name_prefix="mp_exec_io"
# _async_aggregate_workers_output also assumes a single IO thread )
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io"
)
self.output_rank = self._get_output_rank() self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None self.has_connector = self.vllm_config.kv_transfer_config is not None
@ -180,15 +177,27 @@ class MultiprocExecutor(Executor):
self.failure_callback = callback self.failure_callback = callback
def execute_model( # type: ignore[override] def execute_model( # type: ignore[override]
self, self, scheduler_output: SchedulerOutput, non_block: bool = False
scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
non_block: bool = False, return self._execute_with_aggregation(
"execute_model", scheduler_output, non_block=non_block
)
def sample_tokens( # type: ignore[override]
self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]: ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
return self._execute_with_aggregation( # type: ignore[return-value]
"sample_tokens", grammar_output, non_block=non_block
)
def _execute_with_aggregation(
self, method: str, *args, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if not self.has_connector: if not self.has_connector:
# get output only from a single worker (output_rank) # get output only from a single worker (output_rank)
(output,) = self.collective_rpc( (output,) = self.collective_rpc(
"execute_model", method,
args=(scheduler_output,), args=args,
unique_reply_rank=self.output_rank, unique_reply_rank=self.output_rank,
non_block=non_block, non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
@ -197,8 +206,8 @@ class MultiprocExecutor(Executor):
# get output from all workers # get output from all workers
outputs = self.collective_rpc( outputs = self.collective_rpc(
"execute_model", method,
args=(scheduler_output,), args=args,
non_block=non_block, non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
) )

View File

@ -19,7 +19,7 @@ from vllm.utils.network_utils import (
get_ip, get_ip,
get_open_port, get_open_port,
) )
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import ( from vllm.v1.executor.ray_utils import (
@ -41,6 +41,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
COMPLETED_NONE_FUTURE.set_result(None)
@dataclass @dataclass
class RayWorkerMetaData: class RayWorkerMetaData:
@ -96,6 +99,8 @@ class RayDistributedExecutor(Executor):
# KV connector setup # KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None self.has_connector = self.vllm_config.kv_transfer_config is not None
self.scheduler_output: SchedulerOutput | None = None
@property @property
def max_concurrent_batches(self) -> int: def max_concurrent_batches(self) -> int:
"""Ray distributed executor supports pipeline parallelism, """Ray distributed executor supports pipeline parallelism,
@ -381,22 +386,46 @@ class RayDistributedExecutor(Executor):
self.shutdown() self.shutdown()
def execute_model( # type: ignore[override] def execute_model( # type: ignore[override]
self, scheduler_output: SchedulerOutput, non_block: bool = False self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if self.scheduler_output is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
self.scheduler_output = scheduler_output
return COMPLETED_NONE_FUTURE if non_block else None
def sample_tokens( # type: ignore[override]
self,
grammar_output: "GrammarOutput | None",
non_block: bool = False,
) -> ModelRunnerOutput | Future[ModelRunnerOutput]: ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
"""Execute the model on the Ray workers. """Execute the model on the Ray workers.
The scheduler output to use should have been provided in
a prior call to execute_model().
Args: Args:
scheduler_output: The scheduler output to execute. grammar_output: The structured outputs grammar bitmask, if applicable.
non_block: If True, the method will return a Future. non_block: If True, the method will return a Future.
Returns: Returns:
The model runner output. The model runner output.
""" """
scheduler_output = self.scheduler_output
if scheduler_output is None:
return None # noqa
self.scheduler_output = None
# Build the compiled DAG for the first time. # Build the compiled DAG for the first time.
if self.forward_dag is None: # type: ignore if self.forward_dag is None: # type: ignore
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
refs = self.forward_dag.execute(scheduler_output) # type: ignore refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
if not self.has_connector: if not self.has_connector:
# Get output only from a single worker (output_rank) # Get output only from a single worker (output_rank)

View File

@ -19,7 +19,7 @@ from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase from vllm.v1.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
@ -82,36 +82,41 @@ try:
def execute_model_ray( def execute_model_ray(
self, self,
scheduler_output: Union[ execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
"SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"] | tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
],
) -> Union[ ) -> Union[
"ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"] "ModelRunnerOutput",
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
]: ]:
# This method is used by Ray Compiled Graph to execute the model, # This method is used by Ray Compiled Graph to execute the model,
# and it needs a special logic of self.setup_device_if_necessary() # and it needs a special logic of self.setup_device_if_necessary()
self.setup_device_if_necessary() self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized" assert self.worker is not None, "Worker is not initialized"
if isinstance(scheduler_output, tuple): if len(execute_model_input) == 3:
scheduler_output, intermediate_tensors = scheduler_output scheduler_output, grammar_output, intermediate_tensors = (
execute_model_input
)
else: else:
scheduler_output, intermediate_tensors = scheduler_output, None scheduler_output, grammar_output = execute_model_input
intermediate_tensors = None
assert self.worker.model_runner is not None assert self.worker.model_runner is not None
output = self.worker.model_runner.execute_model( output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors scheduler_output, intermediate_tensors
) )
if isinstance(output, IntermediateTensors): if isinstance(output, IntermediateTensors):
output = scheduler_output, output output = scheduler_output, grammar_output, output
elif not get_pp_group().is_last_rank: elif not get_pp_group().is_last_rank:
# Case where there are no scheduled requests # Case where there are no scheduled requests
# but may still be finished requests. # but may still be finished requests.
assert not output or not output.req_ids assert not output or not output.req_ids
output = scheduler_output, None output = scheduler_output, grammar_output, None
# Ensure outputs crossing Ray compiled DAG are serializable. elif output is None:
# AsyncModelRunnerOutput holds CUDA events and cannot be output = self.worker.model_runner.sample_tokens(grammar_output)
# pickled. # Ensure outputs crossing Ray compiled DAG are serializable.
if isinstance(output, AsyncModelRunnerOutput): # AsyncModelRunnerOutput holds CUDA events and cannot be
output = output.get_output() # pickled.
if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output()
return output return output
def override_env_vars(self, vars: dict[str, str]): def override_env_vars(self, vars: dict[str, str]):

View File

@ -16,6 +16,7 @@ from diskcache import Cache
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
import outlines_core as oc import outlines_core as oc
@ -24,7 +25,6 @@ if TYPE_CHECKING:
import xgrammar as xgr import xgrammar as xgr
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
else: else:
xgr = LazyLoader("xgr", globals(), "xgrammar") xgr = LazyLoader("xgr", globals(), "xgrammar")
@ -47,6 +47,7 @@ CACHE = None
def apply_grammar_bitmask( def apply_grammar_bitmask(
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
grammar_output: GrammarOutput,
input_batch: InputBatch, input_batch: InputBatch,
logits: torch.Tensor, logits: torch.Tensor,
) -> None: ) -> None:
@ -58,9 +59,9 @@ def apply_grammar_bitmask(
input_batch (InputBatch): The input of model runner. input_batch (InputBatch): The input of model runner.
logits (torch.Tensor): The output logits of model forward. logits (torch.Tensor): The output logits of model forward.
""" """
grammar_bitmask = scheduler_output.grammar_bitmask # Serialization of np.ndarray is much more efficient than a tensor,
if grammar_bitmask is None: # so we receive it in that format.
return grammar_bitmask = grammar_output.grammar_bitmask
# We receive the structured output bitmask from the scheduler, # We receive the structured output bitmask from the scheduler,
# compacted to contain bitmasks only for structured output requests. # compacted to contain bitmasks only for structured output requests.
@ -79,7 +80,7 @@ def apply_grammar_bitmask(
cumulative_offset += len( cumulative_offset += len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
) )
if req_id in scheduler_output.structured_output_request_ids: if req_id in grammar_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index struct_out_req_batch_indices[req_id] = logit_index
out_indices = [] out_indices = []
@ -91,7 +92,7 @@ def apply_grammar_bitmask(
dtype=grammar_bitmask.dtype, dtype=grammar_bitmask.dtype,
) )
cumulative_index = 0 cumulative_index = 0
for req_id in scheduler_output.structured_output_request_ids: for req_id in grammar_output.structured_output_request_ids:
num_spec_tokens = len( num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
) )
@ -101,22 +102,28 @@ def apply_grammar_bitmask(
sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i] sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
out_indices.append(logit_index + i) out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask
# Copy async to device as tensor.
grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
logits.device, non_blocking=True
)
# If the length of out indices and the logits have the same shape # If the length of out indices and the logits have the same shape
# we don't need to pass indices to the kernel, # we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits. # since the bitmask is already aligned with the logits.
skip_out_indices = len(out_indices) == logits.shape[0] skip_out_indices = len(out_indices) == logits.shape[0]
# Serialization of np.ndarray is much more efficient than a tensor, index_tensor = None
# so we receive it in that format. if not skip_out_indices:
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous() # xgrammar expects a python list of indices but it will actually work with
# a tensor. If we copy the tensor ourselves here we can do it in a non_blocking
# manner and there should be no cpu sync within xgrammar.
index_tensor = torch.tensor(
out_indices, dtype=torch.int32, device="cpu", pin_memory=True
)
index_tensor = index_tensor.to(logits.device, non_blocking=True)
xgr.apply_token_bitmask_inplace( xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
logits,
grammar_bitmask.to(logits.device, non_blocking=True),
indices=out_indices if not skip_out_indices else None,
)
class OutlinesVocabulary: class OutlinesVocabulary:

View File

@ -204,7 +204,7 @@ class InputBatch:
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy() self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# lora related # lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32) self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

View File

@ -109,6 +109,7 @@ from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT, EMPTY_MODEL_RUNNER_OUTPUT,
AsyncModelRunnerOutput, AsyncModelRunnerOutput,
DraftTokenIds, DraftTokenIds,
KVConnectorOutput,
LogprobsLists, LogprobsLists,
LogprobsTensors, LogprobsTensors,
ModelRunnerOutput, ModelRunnerOutput,
@ -150,7 +151,7 @@ from .utils import (
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
@ -218,6 +219,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
return output return output
class ExecuteModelState(NamedTuple):
"""Ephemeral cached state transferred between execute_model() and
sample_tokens(), after execute_model() returns None."""
scheduler_output: "SchedulerOutput"
logits: torch.Tensor
spec_decode_metadata: SpecDecodeMetadata | None
spec_decode_common_attn_metadata: CommonAttentionMetadata | None
hidden_states: torch.Tensor
sample_hidden_states: torch.Tensor
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__( def __init__(
self, self,
@ -509,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
# Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
if self.mm_budget: if self.mm_budget:
self.mm_budget.reset_cache() self.mm_budget.reset_cache()
@ -2113,7 +2131,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_input_tokens: int, # Padded num_input_tokens: int, # Padded
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
) -> tuple[ ) -> tuple[
int,
torch.Tensor | None, torch.Tensor | None,
torch.Tensor | None, torch.Tensor | None,
torch.Tensor, torch.Tensor,
@ -2207,7 +2224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_kwargs.update(encoder_inputs) model_kwargs.update(encoder_inputs)
return ( return (
num_scheduled_tokens,
input_ids, input_ids,
inputs_embeds, inputs_embeds,
positions, positions,
@ -2425,13 +2441,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: ) -> ModelRunnerOutput | IntermediateTensors | None:
if self.execute_model_state is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with record_function_or_nullcontext("Preprocess"): with record_function_or_nullcontext("Preprocess"):
with self.synchronize_input_prep(): with self.synchronize_input_prep():
# Update persistent batch states. # Update persistent batch states.
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens: if not num_scheduled_tokens:
if not has_kv_transfer_group(): if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do. # Return empty ModelRunnerOutput if no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT return EMPTY_MODEL_RUNNER_OUTPUT
@ -2471,7 +2493,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) )
( (
num_scheduled_tokens,
input_ids, input_ids,
inputs_embeds, inputs_embeds,
positions, positions,
@ -2559,6 +2580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Rare case. # Rare case.
assert not self.is_pooling_model assert not self.is_pooling_model
sample_hidden_states = hidden_states[logits_indices]
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
all_gather_tensors = { all_gather_tensors = {
"residual": not is_residual_scattered_for_sp( "residual": not is_residual_scattered_for_sp(
@ -2572,7 +2594,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) )
logits = None logits = None
else: else:
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
model_output_broadcast_data = {} model_output_broadcast_data = {}
@ -2585,9 +2606,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert model_output_broadcast_data is not None assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"] logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present self.execute_model_state = ExecuteModelState(
if scheduler_output.structured_output_request_ids: scheduler_output,
apply_grammar_bitmask(scheduler_output, self.input_batch, logits) logits,
spec_decode_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
)
return None
@torch.inference_mode
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used.
return None # noqa
# Unpack ephemeral state.
(
scheduler_output,
logits,
spec_decode_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
) = self.execute_model_state
# Clear ephemeral state.
self.execute_model_state = None
# Apply structured output bitmasks if present.
if grammar_output is not None:
apply_grammar_bitmask(
scheduler_output, grammar_output, self.input_batch, logits
)
with record_function_or_nullcontext("Sample"): with record_function_or_nullcontext("Sample"):
sampler_output = self._sample(logits, spec_decode_metadata) sampler_output = self._sample(logits, spec_decode_metadata)
@ -2646,7 +2703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampler_output, sampler_output,
logits, logits,
hidden_states, hidden_states,
num_scheduled_tokens, scheduler_output.total_num_scheduled_tokens,
spec_decode_metadata, spec_decode_metadata,
) )
@ -3978,6 +4035,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def create_attn_groups( def create_attn_groups(
attn_backends_map: dict[AttentionGroupKey, list[str]], attn_backends_map: dict[AttentionGroupKey, list[str]],
kv_cache_group_id: int,
) -> list[AttentionGroup]: ) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = [] attn_groups: list[AttentionGroup] = []
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
@ -3987,6 +4045,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_spec, kv_cache_spec,
self.vllm_config, self.vllm_config,
self.device, self.device,
kv_cache_group_id,
num_metadata_builders=1 num_metadata_builders=1
if not self.parallel_config.enable_dbo if not self.parallel_config.enable_dbo
else 2, else 2,
@ -4005,8 +4064,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Resolve cudagraph_mode before actually initialize metadata_builders # Resolve cudagraph_mode before actually initialize metadata_builders
self._check_and_update_cudagraph_mode(attention_backend_set) self._check_and_update_cudagraph_mode(attention_backend_set)
for attn_backends_map in attention_backend_maps: for i, attn_backend_map in enumerate(attention_backend_maps):
self.attn_groups.append(create_attn_groups(attn_backends_map)) self.attn_groups.append(create_attn_groups(attn_backend_map, i))
# Calculate reorder batch threshold (if needed) # Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold() self.calculate_reorder_batch_threshold()
@ -4149,89 +4208,88 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
group.get_metadata_builder().reorder_batch_threshold group.get_metadata_builder().reorder_batch_threshold
for group in self._attn_group_iterator() for group in self._attn_group_iterator()
] ]
# If there are no attention groups (attention-free model) or no backend
# reports a threshold, leave reordering disabled.
if len(reorder_batch_thresholds) == 0:
self.reorder_batch_threshold = None
return
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
def _find_compatible_block_sizes( @staticmethod
self, def select_common_block_size(
kv_manager_block_size: int, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
backend_cls: type[AttentionBackend],
return_all: bool = False,
) -> list[int]:
"""
Find compatible block sizes for a backend.
Args:
kv_manager_block_size: Physical block size of KV cache
backend_cls: Attention backend class
return_all: Return all compatible sizes if True, max size if False
Returns:
Compatible block size(s) based on return_all parameter
Raises:
ValueError: If no compatible block size found
"""
supported_block_size = backend_cls.get_supported_kernel_block_size()
compatible_sizes = []
for block_size in supported_block_size:
if isinstance(block_size, int):
if kv_manager_block_size % block_size == 0:
compatible_sizes.append(block_size)
elif (
isinstance(block_size, MultipleOf)
and kv_manager_block_size % block_size.base == 0
):
compatible_sizes.append(kv_manager_block_size)
if not compatible_sizes:
raise ValueError(f"No compatible block size for {kv_manager_block_size}")
return compatible_sizes if return_all else [max(compatible_sizes)]
def _select_common_block_size(
self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int: ) -> int:
""" """
Select common block size for all backends. Select a block size that is supported by all backends and is a factor of
kv_manager_block_size.
If kv_manager_block_size is supported by all backends, return it directly.
Otherwise, return the max supported size.
Args: Args:
kv_manager_block_size: Block size of KV cache kv_manager_block_size: Block size of KV cache
attn_groups: List of attention groups attn_groups: List of attention groups
Returns: Returns:
Block size supported by all backends, The selected block size
prioritizing cache_config.block_size
Raises: Raises:
ValueError: If no common block size found ValueError: If no valid block size found
""" """
all_backend_supports = []
for attn_group in attn_groups: def block_size_is_supported(
compatible_sizes = self._find_compatible_block_sizes( backends: list[type[AttentionBackend]], block_size: int
kv_manager_block_size, attn_group.backend, return_all=True ) -> bool:
) """
supported_sizes = sorted(list(set(compatible_sizes)), reverse=True) Check if the block size is supported by all backends.
all_backend_supports.append(set(supported_sizes)) """
for backend in backends:
is_supported = False
for supported_size in backend.get_supported_kernel_block_size():
if isinstance(supported_size, int):
if block_size == supported_size:
is_supported = True
elif isinstance(supported_size, MultipleOf):
if block_size % supported_size.base == 0:
is_supported = True
else:
raise ValueError(f"Unknown supported size: {supported_size}")
if not is_supported:
return False
return True
common_supported_sizes = set.intersection(*all_backend_supports) backends = [group.backend for group in attn_groups]
if not common_supported_sizes: # Case 1: if the block_size of kv cache manager is supported by all backends,
error_msg = f"No common block size for {kv_manager_block_size}. " # return it directly
for i, attn_group in enumerate(attn_groups): if block_size_is_supported(backends, kv_manager_block_size):
supported = all_backend_supports[i] return kv_manager_block_size
error_msg += (
f"Backend {attn_group.backend} supports: {sorted(supported)}. "
)
raise ValueError(error_msg)
if self.cache_config.block_size in common_supported_sizes: # Case 2: otherwise, the block_size must be an `int`-format supported size of
return self.cache_config.block_size # at least one backend. Iterate over all `int`-format supported sizes in
# descending order and return the first one that is supported by all backends.
# Simple proof:
# If the supported size b is in MultipleOf(x_i) format for all attention
# backends i, and b a factor of kv_manager_block_size, then
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
# return kv_manager_block_size in case 1.
all_int_supported_sizes = set(
supported_size
for backend in backends
for supported_size in backend.get_supported_kernel_block_size()
if isinstance(supported_size, int)
)
return max(common_supported_sizes) for supported_size in sorted(all_int_supported_sizes, reverse=True):
if kv_manager_block_size % supported_size != 0:
continue
if block_size_is_supported(backends, supported_size):
return supported_size
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: def may_reinitialize_input_batch(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> None:
""" """
Re-initialize the input batch if the block sizes are different from Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there `[self.cache_config.block_size]`. This usually happens when there
@ -4239,6 +4297,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Args: Args:
kv_cache_config: The KV cache configuration. kv_cache_config: The KV cache configuration.
kernel_block_sizes: The kernel block sizes for each KV cache group.
""" """
block_sizes = [ block_sizes = [
kv_cache_group.kv_cache_spec.block_size kv_cache_group.kv_cache_spec.block_size
@ -4246,9 +4305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
] ]
# Generate kernel_block_sizes that matches each block_size
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
self.cache_config.block_size self.cache_config.block_size
]: ]:
@ -4349,7 +4405,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# all backends in the group. # all backends in the group.
attn_groups = self.attn_groups[kv_cache_group_id] attn_groups = self.attn_groups[kv_cache_group_id]
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
selected_kernel_size = self._select_common_block_size( selected_kernel_size = self.select_common_block_size(
kv_manager_block_size, attn_groups kv_manager_block_size, attn_groups
) )
kernel_block_sizes.append(selected_kernel_size) kernel_block_sizes.append(selected_kernel_size)
@ -4367,6 +4423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self, self,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor], kv_cache_raw_tensors: dict[str, torch.Tensor],
kernel_block_sizes: list[int],
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
""" """
Reshape the KV cache tensors to the desired shape and dtype. Reshape the KV cache tensors to the desired shape and dtype.
@ -4375,6 +4432,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config: The KV cache config kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape. correct size but uninitialized shape.
kernel_block_sizes: The kernel block sizes for each KV cache group.
Returns: Returns:
Dict[str, torch.Tensor]: A map between layer names to their Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache. corresponding memory buffer for KV cache.
@ -4384,6 +4442,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for group in self._kv_cache_spec_attn_group_iterator(): for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend attn_backend = group.backend
if group.kv_cache_group_id == len(kernel_block_sizes):
# There may be a last group for layers without kv cache.
continue
kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
for layer_name in group.layer_names: for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers: if layer_name in self.runner_only_attn_layers:
continue continue
@ -4392,24 +4454,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec): if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True has_attn = True
kv_manager_block_size = kv_cache_spec.block_size num_blocks_per_kv_block = (
kernel_size_list = self._find_compatible_block_sizes( kv_cache_spec.block_size // kernel_block_size
kv_manager_block_size, attn_backend, return_all=False
) )
kernel_size = kernel_size_list[0]
num_blocks_per_kv_block = kv_manager_block_size // kernel_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block kernel_num_blocks = num_blocks * num_blocks_per_kv_block
kv_cache_shape = attn_backend.get_kv_cache_shape( kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks, kernel_num_blocks,
kernel_size, kernel_block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size, kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype, cache_dtype_str=self.cache_config.cache_dtype,
) )
dtype = kv_cache_spec.dtype dtype = kv_cache_spec.dtype
try: try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501 kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape) assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError): except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape))) kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@ -4492,13 +4551,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) )
def initialize_kv_cache_tensors( def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
""" """
Initialize the memory buffer for KV cache. Initialize the memory buffer for KV cache.
Args: Args:
kv_cache_config: The KV cache config kv_cache_config: The KV cache config
kernel_block_sizes: The kernel block sizes for each KV cache group.
Returns: Returns:
Dict[str, torch.Tensor]: A map between layer names to their Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache. corresponding memory buffer for KV cache.
@ -4507,7 +4568,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
# Change the memory buffer to the desired shape # Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors( kv_caches = self._reshape_kv_cache_tensors(
kv_cache_config, kv_cache_raw_tensors kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
) )
# Set up cross-layer KV cache sharing # Set up cross-layer KV cache sharing
@ -4566,9 +4627,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.may_add_encoder_only_layers_to_kv_cache_config() self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config) self.initialize_attn_backend(kv_cache_config)
# The kernel block size for all KV cache groups. For example, if
# kv_cache_manager uses block_size 256 for a given group, but the attention
# backends for that group only supports block_size 64, we will return
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
# tokens each.
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
# Reinitialize need to after initialize_attn_backend # Reinitialize need to after initialize_attn_backend
self.may_reinitialize_input_batch(kv_cache_config) self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(
kv_cache_config, kernel_block_sizes
)
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)

View File

@ -6,6 +6,7 @@ import copy
import gc import gc
import os import os
from contextlib import AbstractContextManager, nullcontext from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import torch import torch
@ -19,7 +20,11 @@ from vllm.distributed import (
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce, set_custom_all_reduce,
) )
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
get_kv_transfer_group,
has_kv_transfer_group,
)
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
@ -33,6 +38,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
from vllm.v1.core.sched.output import GrammarOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ( from vllm.v1.outputs import (
@ -348,6 +354,21 @@ class Worker(WorkerBase):
return int(self.available_kv_cache_memory_bytes) return int(self.available_kv_cache_memory_bytes)
def get_kv_connector_handshake_metadata(self) -> dict | None:
"""Get KV connector metadata from this worker if available."""
if not has_kv_transfer_group():
return None
connector = get_kv_transfer_group()
# Return None for connectors that don't need to exchange handshake
# metadata across workers.
if (metadata := connector.get_handshake_metadata()) is None:
return None
tp_rank = get_tp_group().rank_in_group
return {tp_rank: metadata}
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec() return self.model_runner.get_kv_cache_spec()
@ -489,11 +510,16 @@ class Worker(WorkerBase):
def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks() return self.model_runner.get_supported_tasks()
@torch.inference_mode()
def sample_tokens(
self, grammar_output: "GrammarOutput"
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self, scheduler_output: "SchedulerOutput"
scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput | None:
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
intermediate_tensors = None intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0 forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@ -512,13 +538,13 @@ class Worker(WorkerBase):
) )
output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): if isinstance(output, (ModelRunnerOutput, NoneType)):
return output return output
assert isinstance(output, IntermediateTensors) assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
assert ( assert (
parallel_config.distributed_executor_backend != ("external_launcher") parallel_config.distributed_executor_backend != "external_launcher"
and not get_pp_group().is_last_rank and not get_pp_group().is_last_rank
) )

View File

@ -139,7 +139,7 @@ class InputBatch:
self.min_tokens: dict[int, tuple[int, set[int]]] = {} self.min_tokens: dict[int, tuple[int, set[int]]] = {}
# lora related # lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32) self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

View File

@ -92,7 +92,7 @@ from .utils import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
@ -372,6 +372,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else: else:
self.sample_from_logits_func = self.sample_from_logits self.sample_from_logits_func = self.sample_from_logits
# For passing scheduler_output between successive
# execute_model() and sample_tokens() calls.
self.scheduler_output: SchedulerOutput | None = None
self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
if self.mm_budget: if self.mm_budget:
self.mm_budget.reset_cache() self.mm_budget.reset_cache()
@ -1078,7 +1083,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput: ) -> ModelRunnerOutput | None:
if self.scheduler_output is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
# Update cached state # Update cached state
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens: if not scheduler_output.total_num_scheduled_tokens:
@ -1088,14 +1098,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output, self.vllm_config) return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
mm_embed_inputs = None
if self.supports_mm_inputs: if self.supports_mm_inputs:
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output) self._execute_mm_encoder(scheduler_output)
mm_embed_inputs = self._gather_mm_embeddings(scheduler_output) mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
else:
mm_embed_inputs = None
torch_xla.sync(wait=False) torch_xla.sync(wait=False)
self.scheduler_output = scheduler_output
self.mm_embed_inputs = mm_embed_inputs
return None
@torch.no_grad()
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput:
if self.scheduler_output is None:
# Nothing to do (PP non-final rank case), output isn't used.
return None # noqa
scheduler_output = self.scheduler_output
mm_embed_inputs = self.mm_embed_inputs
self.scheduler_output = None
self.mm_embed_inputs = None
# Prepare inputs, the requests might be split into multiple # Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution. # executions, combine the result of each execution.
start_index = 0 start_index = 0
@ -1131,9 +1157,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
self.input_batch, padded_num_reqs, self.device self.input_batch, padded_num_reqs, self.device
) )
if scheduler_output.grammar_bitmask is not None: if grammar_output is not None:
require_struct_decoding, grammar_bitmask_padded, arange = ( require_struct_decoding, grammar_bitmask_padded, arange = (
self.prepare_structured_decoding_input(logits, scheduler_output) self.prepare_structured_decoding_input(logits, grammar_output)
) )
logits = self.structured_decode( logits = self.structured_decode(
require_struct_decoding, grammar_bitmask_padded, logits, arange require_struct_decoding, grammar_bitmask_padded, logits, arange
@ -1954,10 +1980,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.model.get_input_embeddings(*args, **kwargs) return self.model.get_input_embeddings(*args, **kwargs)
def prepare_structured_decoding_input( def prepare_structured_decoding_input(
self, logits: torch.Tensor, scheduler_output: "SchedulerOutput" self, logits: torch.Tensor, grammar_output: "GrammarOutput"
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
grammar_bitmask = scheduler_output.grammar_bitmask grammar_bitmask = grammar_output.grammar_bitmask
assert grammar_bitmask is not None
num_reqs, _ = logits.shape num_reqs, _ = logits.shape
# Reset pre-allocated tensors # Reset pre-allocated tensors
@ -1965,7 +1990,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.require_structured_out_cpu.zero_() self.require_structured_out_cpu.zero_()
cumulative_mask_idx = 0 cumulative_mask_idx = 0
for req_id in scheduler_output.structured_output_request_ids: for req_id in grammar_output.structured_output_request_ids:
if req_id not in self.input_batch.req_id_to_index: if req_id not in self.input_batch.req_id_to_index:
continue continue
batch_index = self.input_batch.req_id_to_index[req_id] batch_index = self.input_batch.req_id_to_index[req_id]

View File

@ -17,7 +17,6 @@ from vllm.distributed import (
) )
from vllm.distributed.kv_transfer import ( from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized, ensure_kv_transfer_initialized,
has_kv_transfer_group,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -27,7 +26,7 @@ from vllm.platforms.tpu import USE_TPU_INFERENCE
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats from vllm.v1.utils import report_usage_stats
@ -255,13 +254,13 @@ class TPUWorker:
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
return int(tpu_kv_cache_bytes) return int(tpu_kv_cache_bytes)
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
def execute_model( def execute_model(
self, self, scheduler_output: "SchedulerOutput"
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput | None: ) -> ModelRunnerOutput | None:
output = self.model_runner.execute_model(scheduler_output) return self.model_runner.execute_model(scheduler_output)
# every worker's output is needed when kv_transfer_group is set up
return output if self.is_driver_worker or has_kv_transfer_group() else None
def profile(self, is_start: bool = True): def profile(self, is_start: bool = True):
if self.rank < 1: if self.rank < 1:

View File

@ -140,6 +140,7 @@ class AttentionGroup:
metadata_builders: list[AttentionMetadataBuilder] metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str] layer_names: list[str]
kv_cache_spec: KVCacheSpec kv_cache_spec: KVCacheSpec
kv_cache_group_id: int
@staticmethod @staticmethod
def create_with_metadata_builders( def create_with_metadata_builders(
@ -148,13 +149,16 @@ class AttentionGroup:
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
kv_cache_group_id: int,
num_metadata_builders: int = 1, num_metadata_builders: int = 1,
) -> "AttentionGroup": ) -> "AttentionGroup":
metadata_builders = [ metadata_builders = [
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device) backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
for _ in range(num_metadata_builders) for _ in range(num_metadata_builders)
] ]
return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec) return AttentionGroup(
backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
)
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder: def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
assert len(self.metadata_builders) > ubatch_id assert len(self.metadata_builders) > ubatch_id

View File

@ -20,10 +20,12 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.serial_utils import run_method from vllm.v1.serial_utils import run_method
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput
else: else:
SchedulerOutput = object SchedulerOutput = object
GrammarOutput = object
AsyncModelRunnerOutput = object
ModelRunnerOutput = object ModelRunnerOutput = object
logger = init_logger(__name__) logger = init_logger(__name__)
@ -122,7 +124,21 @@ class WorkerBase:
"""Load model onto target device.""" """Load model onto target device."""
raise NotImplementedError raise NotImplementedError
def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput: def execute_model(
self, scheduler_output: SchedulerOutput
) -> ModelRunnerOutput | None:
"""If this method returns None, sample_tokens should be called immediately after
to obtain the ModelRunnerOutput.
Note that this design may be changed in future if/when structured outputs
parallelism is re-architected.
"""
raise NotImplementedError
def sample_tokens(
self, grammar_output: GrammarOutput
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
"""Should be called immediately after execute_model iff it returned None."""
raise NotImplementedError raise NotImplementedError
def get_cache_block_size_bytes(self) -> int: def get_cache_block_size_bytes(self) -> int:
@ -344,7 +360,7 @@ class WorkerWrapperBase:
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
*args, *args,
**kwargs, **kwargs,
) -> ModelRunnerOutput: ) -> ModelRunnerOutput | None:
self._apply_mm_cache(scheduler_output) self._apply_mm_cache(scheduler_output)
assert self.worker is not None assert self.worker is not None