mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 03:57:02 +08:00
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
This commit is contained in:
commit
8e8218ebac
@ -441,7 +441,7 @@ steps:
|
||||
--ignore=lora/test_llm_with_multi_loras.py \
|
||||
--ignore=lora/test_olmoe_tp.py \
|
||||
--ignore=lora/test_deepseekv2_tp.py \
|
||||
--ignore=lora/test_gptoss.py \
|
||||
--ignore=lora/test_gptoss_tp.py \
|
||||
--ignore=lora/test_qwen3moe_tp.py
|
||||
parallelism: 4
|
||||
|
||||
@ -1217,6 +1217,8 @@ steps:
|
||||
- pytest -v -s -x lora/test_llama_tp.py
|
||||
- pytest -v -s -x lora/test_llm_with_multi_loras.py
|
||||
- pytest -v -s -x lora/test_olmoe_tp.py
|
||||
- pytest -v -s -x lora/test_gptoss_tp.py
|
||||
|
||||
|
||||
- label: Weight Loading Multiple GPU Test # 33min
|
||||
timeout_in_minutes: 45
|
||||
|
||||
@ -340,6 +340,16 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s v1/attention
|
||||
|
||||
- label: V1 Test attention (B200) # 10min
|
||||
timeout_in_minutes: 30
|
||||
gpu: b200
|
||||
source_file_dependencies:
|
||||
- vllm/v1/attention
|
||||
- tests/v1/attention
|
||||
commands:
|
||||
- export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
|
||||
- pytest -v -s v1/attention
|
||||
|
||||
- label: V1 Test others (CPU) # 5 mins
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -417,7 +427,7 @@ steps:
|
||||
--ignore=lora/test_llm_with_multi_loras.py \
|
||||
--ignore=lora/test_olmoe_tp.py \
|
||||
--ignore=lora/test_deepseekv2_tp.py \
|
||||
--ignore=lora/test_gptoss.py \
|
||||
--ignore=lora/test_gptoss_tp.py \
|
||||
--ignore=lora/test_qwen3moe_tp.py
|
||||
|
||||
parallelism: 4
|
||||
@ -1119,6 +1129,7 @@ steps:
|
||||
- pytest -v -s -x lora/test_llama_tp.py
|
||||
- pytest -v -s -x lora/test_llm_with_multi_loras.py
|
||||
- pytest -v -s -x lora/test_olmoe_tp.py
|
||||
- pytest -v -s -x lora/test_gptoss_tp.py
|
||||
|
||||
|
||||
- label: Weight Loading Multiple GPU Test # 33min
|
||||
|
||||
@ -1429,8 +1429,6 @@ async def main() -> None:
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
if not os.path.exists(args.model):
|
||||
raise OSError(f"Path does not exist: {args.model}")
|
||||
logger.info("Loading tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ This doc serves as a collection of handy tips for optimizing your vLLM on TPU wo
|
||||
|
||||
## Get started
|
||||
|
||||
Looking for setup and installation instructions? Find them [here](../getting_started/installation/google_tpu.md).
|
||||
Looking for setup and installation instructions? Find them [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/).
|
||||
|
||||
### TPU workload sizing
|
||||
|
||||
|
||||
133
docs/features/batch_invariance.md
Normal file
133
docs/features/batch_invariance.md
Normal file
@ -0,0 +1,133 @@
|
||||
# Batch Invariance
|
||||
|
||||
!!! note
|
||||
Batch invariance is currently in beta. Some features are still under active development.
|
||||
Track progress and planned improvements at <https://github.com/vllm-project/vllm/issues/27433>
|
||||
|
||||
This document shows how to enable batch invariance in vLLM. Batch invariance ensures that the output of a model is deterministic and independent of the batch size or the order of requests in a batch.
|
||||
|
||||
## Motivation
|
||||
|
||||
Batch invariance is crucial for several use cases:
|
||||
|
||||
- **Framework debugging**: Deterministic outputs make it easier to debug issues in the inference framework, as the same input will always produce the same output regardless of batching.
|
||||
- **Model debugging**: Helps identify issues in model implementations by ensuring consistent behavior across different batch configurations.
|
||||
- **Reinforcement Learning (RL)**: RL training often requires deterministic rollouts for reproducibility and stable training.
|
||||
- **Large-scale inference systems**: Systems that use vLLM as a component benefit from deterministic behavior for testing, validation, and consistency guarantees.
|
||||
|
||||
## Hardware Requirements
|
||||
|
||||
Batch invariance currently requires NVIDIA GPUs with compute capability 9.0 or higher:
|
||||
|
||||
- **H-series**: H100, H200
|
||||
- **B-series**: B100, B200
|
||||
|
||||
## Enabling Batch Invariance
|
||||
|
||||
Batch invariance can be enabled by setting the `VLLM_BATCH_INVARIANT` environment variable to `1`:
|
||||
|
||||
```bash
|
||||
export VLLM_BATCH_INVARIANT=1
|
||||
```
|
||||
|
||||
### Online Inference (Server Mode)
|
||||
|
||||
To start a vLLM server with batch invariance enabled:
|
||||
|
||||
```bash
|
||||
VLLM_BATCH_INVARIANT=1 vllm serve meta-llama/Llama-3.1-8B-Instruct
|
||||
```
|
||||
|
||||
Then use the OpenAI-compatible client:
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key="EMPTY",
|
||||
base_url="http://localhost:8000/v1",
|
||||
)
|
||||
|
||||
# These requests will produce deterministic outputs
|
||||
# regardless of batch size or order
|
||||
response = client.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
prompt="The future of AI is",
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
print(response.choices[0].text)
|
||||
```
|
||||
|
||||
### Offline Inference
|
||||
|
||||
For offline batch inference with batch invariance:
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = "1"
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
prompts = [
|
||||
"The future of AI is",
|
||||
"Machine learning enables",
|
||||
"Deep learning models can",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
top_p=0.95,
|
||||
max_tokens=100,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
# Outputs will be deterministic regardless of batch size
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Generated: {generated_text!r}\n")
|
||||
```
|
||||
|
||||
## Tested Models
|
||||
|
||||
Batch invariance has been tested and verified on the following models:
|
||||
|
||||
- **DeepSeek series**: `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-V3-0324`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`
|
||||
- **Qwen3 (Dense)**: `Qwen/Qwen3-1.7B`, `Qwen/Qwen3-8B`
|
||||
- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct`
|
||||
- **Llama 3**: `meta-llama/Llama-3.1-8B-Instruct`, `meta-llama/Llama-3.2-1B-Instruct`
|
||||
|
||||
Other models may also work, but these have been explicitly validated. If you encounter issues with a specific model, please report them on the [GitHub issue tracker](https://github.com/vllm-project/vllm/issues/new/choose).
|
||||
|
||||
## Implementation Details
|
||||
|
||||
When batch invariance is enabled, vLLM:
|
||||
|
||||
1. Uses deterministic kernel implementations for attention and other operations
|
||||
2. Ensures consistent numerical behavior across different batch sizes
|
||||
3. Disables certain optimizations that may introduce non-determinism (such as custom all-reduce operations in tensor parallel mode)
|
||||
|
||||
!!! note
|
||||
Enabling batch invariance may impact performance compared to the default non-deterministic mode. This trade-off is intentional to guarantee reproducibility.
|
||||
|
||||
## Future Improvements
|
||||
|
||||
The batch invariance feature is under active development. Planned improvements include:
|
||||
|
||||
- Support for additional GPU architectures
|
||||
- Expanded model coverage
|
||||
- Performance optimizations
|
||||
- Additional testing and validation
|
||||
|
||||
For the latest status and to contribute ideas, see the [tracking issue](https://github.com/vllm-project/vllm/issues/27433).
|
||||
@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
|
||||
- Default: 5600
|
||||
- **Required for both prefiller and decoder instances**
|
||||
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
|
||||
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
|
||||
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use ports 5600 and 5601 on that node).
|
||||
- Used for the initial NIXL handshake between the prefiller and the decoder
|
||||
|
||||
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
|
||||
|
||||
@ -2,4 +2,4 @@ nav:
|
||||
- README.md
|
||||
- gpu.md
|
||||
- cpu.md
|
||||
- google_tpu.md
|
||||
- TPU: https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/
|
||||
|
||||
@ -11,7 +11,6 @@ vLLM supports the following hardware platforms:
|
||||
- [ARM AArch64](cpu.md#arm-aarch64)
|
||||
- [Apple silicon](cpu.md#apple-silicon)
|
||||
- [IBM Z (S390X)](cpu.md#ibm-z-s390x)
|
||||
- [Google TPU](google_tpu.md)
|
||||
|
||||
## Hardware Plugins
|
||||
|
||||
@ -20,6 +19,7 @@ The backends below live **outside** the main `vllm` repository and follow the
|
||||
|
||||
| Accelerator | PyPI / package | Repository |
|
||||
|-------------|----------------|------------|
|
||||
| Google TPU | `tpu-inference` | <https://github.com/vllm-project/tpu-inference> |
|
||||
| Ascend NPU | `vllm-ascend` | <https://github.com/vllm-project/vllm-ascend> |
|
||||
| Intel Gaudi (HPU) | N/A, install from source | <https://github.com/vllm-project/vllm-gaudi> |
|
||||
| MetaX MACA GPU | N/A, install from source | <https://github.com/MetaX-MACA/vLLM-metax> |
|
||||
|
||||
@ -1,193 +0,0 @@
|
||||
# Google TPU
|
||||
|
||||
Tensor Processing Units (TPUs) are Google's custom-developed application-specific
|
||||
integrated circuits (ASICs) used to accelerate machine learning workloads. TPUs
|
||||
are available in different versions each with different hardware specifications.
|
||||
For more information about TPUs, see [TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm).
|
||||
For more information on the TPU versions supported with vLLM, see:
|
||||
|
||||
- [TPU v6e](https://cloud.google.com/tpu/docs/v6e)
|
||||
- [TPU v5e](https://cloud.google.com/tpu/docs/v5e)
|
||||
- [TPU v5p](https://cloud.google.com/tpu/docs/v5p)
|
||||
- [TPU v4](https://cloud.google.com/tpu/docs/v4)
|
||||
|
||||
These TPU versions allow you to configure the physical arrangements of the TPU
|
||||
chips. This can improve throughput and networking performance. For more
|
||||
information see:
|
||||
|
||||
- [TPU v6e topologies](https://cloud.google.com/tpu/docs/v6e#configurations)
|
||||
- [TPU v5e topologies](https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config)
|
||||
- [TPU v5p topologies](https://cloud.google.com/tpu/docs/v5p#tpu-v5p-config)
|
||||
- [TPU v4 topologies](https://cloud.google.com/tpu/docs/v4#tpu-v4-config)
|
||||
|
||||
In order for you to use Cloud TPUs you need to have TPU quota granted to your
|
||||
Google Cloud Platform project. TPU quotas specify how many TPUs you can use in a
|
||||
GCP project and are specified in terms of TPU version, the number of TPUs you
|
||||
want to use, and quota type. For more information, see [TPU quota](https://cloud.google.com/tpu/docs/quota#tpu_quota).
|
||||
|
||||
For TPU pricing information, see [Cloud TPU pricing](https://cloud.google.com/tpu/pricing).
|
||||
|
||||
You may need additional persistent storage for your TPU VMs. For more
|
||||
information, see [Storage options for Cloud TPU data](https://cloud.google.com/tpu/docs/storage-options).
|
||||
|
||||
!!! warning
|
||||
There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Google Cloud TPU VM
|
||||
- TPU versions: v6e, v5e, v5p, v4
|
||||
- Python: 3.11 or newer
|
||||
|
||||
### Provision Cloud TPUs
|
||||
|
||||
You can provision Cloud TPUs using the [Cloud TPU API](https://cloud.google.com/tpu/docs/reference/rest)
|
||||
or the [queued resources](https://cloud.google.com/tpu/docs/queued-resources)
|
||||
API (preferred). This section shows how to create TPUs using the queued resource API. For
|
||||
more information about using the Cloud TPU API, see [Create a Cloud TPU using the Create Node API](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#create-node-api).
|
||||
Queued resources enable you to request Cloud TPU resources in a queued manner.
|
||||
When you request queued resources, the request is added to a queue maintained by
|
||||
the Cloud TPU service. When the requested resource becomes available, it's
|
||||
assigned to your Google Cloud project for your immediate exclusive use.
|
||||
|
||||
!!! note
|
||||
In all of the following commands, replace the ALL CAPS parameter names with
|
||||
appropriate values. See the parameter descriptions table for more information.
|
||||
|
||||
### Provision Cloud TPUs with GKE
|
||||
|
||||
For more information about using TPUs with GKE, see:
|
||||
|
||||
- [About TPUs in GKE](https://cloud.google.com/kubernetes-engine/docs/concepts/tpus)
|
||||
- [Deploy TPU workloads in GKE Standard](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus)
|
||||
- [Plan for TPUs in GKE](https://cloud.google.com/kubernetes-engine/docs/concepts/plan-tpus)
|
||||
|
||||
## Configure a new environment
|
||||
|
||||
### Provision a Cloud TPU with the queued resource API
|
||||
|
||||
Create a TPU v5e with 4 TPU chips:
|
||||
|
||||
```bash
|
||||
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
|
||||
--node-id TPU_NAME \
|
||||
--project PROJECT_ID \
|
||||
--zone ZONE \
|
||||
--accelerator-type ACCELERATOR_TYPE \
|
||||
--runtime-version RUNTIME_VERSION \
|
||||
--service-account SERVICE_ACCOUNT
|
||||
```
|
||||
|
||||
| Parameter name | Description |
|
||||
|--------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| QUEUED_RESOURCE_ID | The user-assigned ID of the queued resource request. |
|
||||
| TPU_NAME | The user-assigned name of the TPU which is created when the queued resource request is allocated. |
|
||||
| PROJECT_ID | Your Google Cloud project |
|
||||
| ZONE | The GCP zone where you want to create your Cloud TPU. The value you use depends on the version of TPUs you are using. For more information, see [TPU regions and zones] |
|
||||
| ACCELERATOR_TYPE | The TPU version you want to use. Specify the TPU version, for example `v5litepod-4` specifies a v5e TPU with 4 cores, `v6e-1` specifies a v6e TPU with 1 core. For more information, see [TPU versions]. |
|
||||
| RUNTIME_VERSION | The TPU VM runtime version to use. For example, use `v2-alpha-tpuv6e` for a VM loaded with one or more v6e TPU(s). |
|
||||
| SERVICE_ACCOUNT | The email address for your service account. You can find it in the IAM Cloud Console under *Service Accounts*. For example: `tpu-service-account@<your_project_ID>.iam.gserviceaccount.com` |
|
||||
|
||||
Connect to your TPU VM using SSH:
|
||||
|
||||
```bash
|
||||
gcloud compute tpus tpu-vm ssh TPU_NAME --project PROJECT_ID --zone ZONE
|
||||
```
|
||||
|
||||
!!! note
|
||||
When configuring `RUNTIME_VERSION` ("TPU software version") on GCP, ensure it matches the TPU generation you've selected by referencing the [TPU VM images] compatibility matrix. Using an incompatible version may prevent vLLM from running correctly.
|
||||
|
||||
[TPU versions]: https://cloud.google.com/tpu/docs/runtimes
|
||||
[TPU VM images]: https://cloud.google.com/tpu/docs/runtimes
|
||||
[TPU regions and zones]: https://cloud.google.com/tpu/docs/regions-zones
|
||||
|
||||
## Set up using Python
|
||||
|
||||
### Pre-built wheels
|
||||
|
||||
Currently, there are no pre-built TPU wheels.
|
||||
|
||||
### Build wheel from source
|
||||
|
||||
Install Miniconda:
|
||||
|
||||
```bash
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
||||
bash Miniconda3-latest-Linux-x86_64.sh
|
||||
source ~/.bashrc
|
||||
```
|
||||
|
||||
Create and activate a Conda environment for vLLM:
|
||||
|
||||
```bash
|
||||
conda create -n vllm python=3.12 -y
|
||||
conda activate vllm
|
||||
```
|
||||
|
||||
Clone the vLLM repository and go to the vLLM directory:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vllm-project/vllm.git && cd vllm
|
||||
```
|
||||
|
||||
Uninstall the existing `torch` and `torch_xla` packages:
|
||||
|
||||
```bash
|
||||
pip uninstall torch torch-xla -y
|
||||
```
|
||||
|
||||
Install build dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements/tpu.txt
|
||||
sudo apt-get install --no-install-recommends --yes libopenblas-base libopenmpi-dev libomp-dev
|
||||
```
|
||||
|
||||
Run the setup script:
|
||||
|
||||
```bash
|
||||
VLLM_TARGET_DEVICE="tpu" python -m pip install -e .
|
||||
```
|
||||
|
||||
## Set up using Docker
|
||||
|
||||
### Pre-built images
|
||||
|
||||
See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image, making sure to substitute the image name `vllm/vllm-openai` with `vllm/vllm-tpu`.
|
||||
|
||||
### Build image from source
|
||||
|
||||
You can use [docker/Dockerfile.tpu](../../../docker/Dockerfile.tpu) to build a Docker image with TPU support.
|
||||
|
||||
```bash
|
||||
docker build -f docker/Dockerfile.tpu -t vllm-tpu .
|
||||
```
|
||||
|
||||
Run the Docker image with the following command:
|
||||
|
||||
```bash
|
||||
# Make sure to add `--privileged --net host --shm-size=16G`.
|
||||
docker run --privileged --net host --shm-size=16G -it vllm-tpu
|
||||
```
|
||||
|
||||
!!! note
|
||||
Since TPU relies on XLA which requires static shapes, vLLM bucketizes the
|
||||
possible input shapes and compiles an XLA graph for each shape. The
|
||||
compilation time may take 20~30 minutes in the first run. However, the
|
||||
compilation time reduces to ~5 minutes afterwards because the XLA graphs are
|
||||
cached on disk (in `VLLM_XLA_CACHE_PATH` or `~/.cache/vllm/xla_cache` by default).
|
||||
|
||||
!!! tip
|
||||
If you encounter the following error:
|
||||
|
||||
```console
|
||||
from torch._C import * # noqa: F403
|
||||
ImportError: libopenblas.so.0: cannot open shared object file: No such
|
||||
file or directory
|
||||
```
|
||||
|
||||
Install OpenBLAS with the following command:
|
||||
|
||||
```bash
|
||||
sudo apt-get install --no-install-recommends --yes libopenblas-base libopenmpi-dev libomp-dev
|
||||
```
|
||||
@ -63,6 +63,17 @@ This guide will help you quickly get started with vLLM to perform:
|
||||
rocm/vllm-dev:nightly
|
||||
```
|
||||
|
||||
=== "Google TPU"
|
||||
|
||||
To run vLLM on Google TPUs, you need to install the `vllm-tpu` package.
|
||||
|
||||
```bash
|
||||
uv pip install vllm-tpu
|
||||
```
|
||||
|
||||
!!! note
|
||||
For more detailed instructions, including Docker, installing from source, and troubleshooting, please refer to the [vLLM on TPU documentation](https://docs.vllm.ai/projects/tpu/en/latest/).
|
||||
|
||||
!!! note
|
||||
For more details, and for non-CUDA platforms, please refer to the [installation guide](installation/README.md) for specific instructions on how to install vLLM.
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ requests >= 2.26.0
|
||||
tqdm
|
||||
blake3
|
||||
py-cpuinfo
|
||||
transformers >= 4.56.0
|
||||
transformers >= 4.56.0, < 5
|
||||
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
|
||||
protobuf # Required by LlamaTokenizer.
|
||||
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
||||
|
||||
@ -29,7 +29,7 @@ opencv-python-headless >= 4.11.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
|
||||
mteb>=1.38.11, <2 # required for mteb test
|
||||
transformers==4.56.2
|
||||
transformers==4.57.1
|
||||
tokenizers==0.22.0
|
||||
schemathesis>=3.39.15 # Required for openai schema test.
|
||||
# quantization
|
||||
|
||||
@ -37,7 +37,7 @@ datamodel_code_generator # required for minicpm3 test
|
||||
# TODO: Use lm-eval[api]==0.4.10 once released
|
||||
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
|
||||
mteb[bm25s]>=1.38.11, <2 # required for mteb test
|
||||
transformers==4.56.2
|
||||
transformers==4.57.1
|
||||
tokenizers==0.22.0
|
||||
schemathesis>=3.39.15 # Required for openai schema test.
|
||||
# quantization
|
||||
|
||||
@ -1196,7 +1196,7 @@ tqdm==4.66.6
|
||||
# transformers
|
||||
tqdm-multiprocess==0.0.11
|
||||
# via lm-eval
|
||||
transformers==4.56.2
|
||||
transformers==4.57.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# genai-perf
|
||||
|
||||
@ -6,6 +6,9 @@ from copy import deepcopy
|
||||
|
||||
from tblib import pickling_support
|
||||
|
||||
# Import fixture
|
||||
from tests.v1.entrypoints.conftest import sample_json_schema # noqa
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
# Install support for pickling exceptions so that we can nicely propagate
|
||||
|
||||
@ -237,7 +237,7 @@ def deepseekv2_lora_files():
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def gptoss20b_lora_files():
|
||||
return snapshot_download(repo_id="LevinZheng/gpt-oss-20b-lora-adapter")
|
||||
return snapshot_download(repo_id="jeeejeee/gpt-oss-20b-lora-adapter-text2sql")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# NOTE To avoid overloading the CI pipeline, this test script will
|
||||
# not be triggered on CI and is primarily intended for local testing
|
||||
# and verification.
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
@ -1,52 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
MODEL_PATH = "openai/gpt-oss-20b"
|
||||
|
||||
PROMPT_TEMPLATE = "<|begin▁of▁sentence|>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501
|
||||
|
||||
|
||||
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    """Generate a reply to "Who are you?" and return the stripped completions.

    Args:
        llm: Engine used for generation.
        lora_path: Adapter location forwarded to ``LoRARequest``.
        lora_id: Adapter id; a falsy value (``0``) generates without an adapter.

    Returns:
        One stripped completion string per prompt, in generation order.
    """
    questions = [PROMPT_TEMPLATE.format(context="Who are you?")]
    # temperature=0 keeps the output comparable against a fixed expected string.
    params = vllm.SamplingParams(temperature=0, max_tokens=64)

    if lora_id:
        adapter = LoRARequest(str(lora_id), lora_id, lora_path)
    else:
        adapter = None

    results = llm.generate(questions, params, lora_request=adapter)

    # Echo every prompt/completion pair while collecting the completions.
    completions: list[str] = []
    for result in results:
        shown_prompt = result.prompt
        completion = result.outputs[0].text.strip()
        completions.append(completion)
        print(f"Prompt: {shown_prompt!r}, Generated text: {completion!r}")
    return completions
|
||||
|
||||
|
||||
# FIXME: Load gpt-oss adapter
|
||||
def test_gptoss20b_lora(gptoss20b_lora_files):
    """Check the LoRA-adapted gpt-oss-20b completion for "Who are you?".

    Args:
        gptoss20b_lora_files: Fixture value passed to ``do_sample`` as the
            adapter path.

    Raises:
        AssertionError: If the number of completions differs from the number
            of expected strings, or a completion does not start with its
            expected text.
    """
    # NOTE(review): an earlier comment here stated that enforce_eager=True is
    # enabled to reduce VRAM usage and avoid CUDA OOM in the lora-test CI, but
    # no such argument is passed to the engine below — confirm which is intended.
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_loras=4,
        trust_remote_code=True,
    )

    expected_lora_output = [
        "I am an AI language model developed by OpenAI. "
        "I am here to help you with any questions or "
        "tasks you may have."
    ]

    output1 = do_sample(llm, gptoss20b_lora_files, lora_id=1)
    print(output1)

    # Compare pairwise with an explicit count check so a mismatch reports what
    # was generated instead of failing with a bare assert or passing silently.
    assert len(output1) == len(expected_lora_output), (
        f"Expected {len(expected_lora_output)} outputs, got {len(output1)}"
    )
    for generated, expected in zip(output1, expected_lora_output):
        assert generated.startswith(expected), (
            f"Expected output starting with {expected!r}, got {generated!r}"
        )
|
||||
106
tests/lora/test_gptoss_tp.py
Normal file
106
tests/lora/test_gptoss_tp.py
Normal file
@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
|
||||
MODEL_PATH = "openai/gpt-oss-20b"
|
||||
|
||||
PROMPT_TEMPLATE = """<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
|
||||
Knowledge cutoff: 2024-06
|
||||
Current date: 2025-10-29
|
||||
|
||||
Reasoning: medium
|
||||
|
||||
# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>user<|message|>I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
|
||||
"
|
||||
##Instruction:
|
||||
farm contains tables such as city, farm, farm_competition, competition_record. Table city has columns such as City_ID, Official_Name, Status, Area_km_2, Population, Census_Ranking. City_ID is the primary key.
|
||||
Table farm has columns such as Farm_ID, Year, Total_Horses, Working_Horses, Total_Cattle, Oxen, Bulls, Cows, Pigs, Sheep_and_Goats. Farm_ID is the primary key.
|
||||
Table farm_competition has columns such as Competition_ID, Year, Theme, Host_city_ID, Hosts. Competition_ID is the primary key.
|
||||
Table competition_record has columns such as Competition_ID, Farm_ID, Rank. Competition_ID is the primary key.
|
||||
The Host_city_ID of farm_competition is the foreign key of City_ID of city.
|
||||
The Farm_ID of competition_record is the foreign key of Farm_ID of farm.
|
||||
The Competition_ID of competition_record is the foreign key of Competition_ID of farm_competition.
|
||||
|
||||
|
||||
###Input:
|
||||
{context}
|
||||
|
||||
###Response:<|end|><|start|>assistant<|channel|>final<|message|>""" # noqa: E501
|
||||
|
||||
EXPECTED_LORA_OUTPUT = [
|
||||
"SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;",
|
||||
"SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;",
|
||||
"SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
|
||||
"SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
|
||||
]
|
||||
|
||||
|
||||
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
    """Generate SQL completions through a LoRA adapter and verify them.

    Four text-to-SQL questions are formatted with ``PROMPT_TEMPLATE`` and
    generated with ``temperature=0`` and at most 64 new tokens. Each stripped
    completion must start with the matching ``EXPECTED_LORA_OUTPUT`` entry.

    Args:
        llm: Engine used for generation.
        lora_path: Adapter location forwarded to ``LoRARequest``.
        lora_id: Adapter id; a falsy value (``0``) generates without an adapter.

    Raises:
        AssertionError: If the number of completions differs from the number
            of expected outputs, or a completion does not start with its
            expected SQL.
    """
    prompts = [
        PROMPT_TEMPLATE.format(
            context="What is the average number of working horses of farms with more than 5000 total number of horses?"  # noqa: E501
        ),  # noqa: E501
        PROMPT_TEMPLATE.format(
            context="Give the average number of working horses on farms with more than 5000 total horses."  # noqa: E501
        ),  # noqa: E501
        PROMPT_TEMPLATE.format(
            context="What are the maximum and minimum number of cows across all farms."
        ),
        PROMPT_TEMPLATE.format(
            context="Return the maximum and minimum number of cows across all farms."
        ),
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
    )
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Check the count explicitly, then compare pairwise, so a failure reports
    # the offending completion rather than a bare assert or an IndexError.
    assert len(generated_texts) == len(EXPECTED_LORA_OUTPUT), (
        f"Expected {len(EXPECTED_LORA_OUTPUT)} outputs, got {len(generated_texts)}"
    )
    for generated, expected in zip(generated_texts, EXPECTED_LORA_OUTPUT):
        assert generated.startswith(expected), (
            f"Expected output starting with {expected!r}, got {generated!r}"
        )
|
||||
|
||||
|
||||
def test_gpt_oss_lora(gptoss20b_lora_files):
    """Run the text-to-SQL LoRA check twice on one engine, with ids 1 and 2.

    Args:
        gptoss20b_lora_files: Fixture value passed on as the adapter path.
    """
    # Avoid OOM
    compilation = vllm.config.CompilationConfig(cudagraph_specialize_lora=False)
    engine = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=4,
        max_lora_rank=8,
        compilation_config=compilation,
    )

    # The same adapter files are registered under two different ids.
    for adapter_id in (1, 2):
        generate_and_test(engine, gptoss20b_lora_files, lora_id=adapter_id)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
def test_gpt_oss_lora_tp2(gptoss20b_lora_files):
    """Run the text-to-SQL LoRA check with ``tensor_parallel_size=2``.

    Args:
        gptoss20b_lora_files: Fixture value passed on as the adapter path.
    """
    # Avoid OOM
    compilation = vllm.config.CompilationConfig(cudagraph_specialize_lora=False)
    engine = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        tensor_parallel_size=2,
        compilation_config=compilation,
    )

    # The same adapter files are registered under two different ids.
    for adapter_id in (1, 2):
        generate_and_test(engine, gptoss20b_lora_files, lora_id=adapter_id)
|
||||
@ -1,6 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
# NOTE To avoid overloading the CI pipeline, this test script will not
|
||||
# be triggered on CI and is primarily intended for local testing and verification.
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
86
tests/models/multimodal/generation/test_keye.py
Normal file
86
tests/models/multimodal/generation/test_keye.py
Normal file
@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import asdict
|
||||
from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
from PIL.Image import Image
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from vllm import LLM, EngineArgs, SamplingParams
|
||||
from vllm.multimodal.utils import encode_image_base64
|
||||
|
||||
MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"
|
||||
|
||||
QUESTION = "What is the content of each image?"
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
    """Immutable bundle describing one multimodal generation request.

    The first three fields are required; the remaining ones default to ``None``.
    """

    # NOTE(review): the test visible in this file does not reference this
    # class — confirm whether it is still needed.

    # Required fields.
    engine_args: EngineArgs
    prompt: str
    image_data: list[Image]

    # Optional fields, each defaulting to None.
    stop_token_ids: list[int] | None = None
    chat_template: str | None = None
    sampling_params: SamplingParams | None = None
|
||||
|
||||
|
||||
@pytest.mark.core_model
@pytest.mark.parametrize("question", [QUESTION])
def test_keye_vl(
    image_assets,
    question: str,
):
    """Smoke-test multi-image generation with the model named by ``MODEL_NAME``.

    Every image in ``image_assets`` is sent in a single user turn together with
    ``question``. The test passes when each completion is longer than 10
    characters.

    Args:
        image_assets: Fixture yielding assets that expose ``pil_image``.
        question: Text appended after the image placeholders.
    """
    pil_images = [asset.pil_image for asset in image_assets]
    data_urls = [
        f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in pil_images
    ]

    args = EngineArgs(
        model=MODEL_NAME,
        trust_remote_code=True,
        max_model_len=8192,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(data_urls)},
    )

    # One user message: all image placeholders first, then the question text.
    user_content = [{"type": "image", "image": url} for url in data_urls]
    user_content.append({"type": "text", "text": question})
    chat = [{"role": "user", "content": user_content}]

    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
    prompt = processor.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    # Flatten the engine args into keyword arguments and pin the seed.
    llm = LLM(**{**asdict(args), "seed": 42})

    params = SamplingParams(temperature=0.0, max_tokens=256, stop_token_ids=None)
    request = {"prompt": prompt, "multi_modal_data": {"image": pil_images}}
    outputs = llm.generate(request, sampling_params=params)

    separator = "-" * 50
    print(separator)
    for completion in outputs:
        text = completion.outputs[0].text
        print(text)
        assert len(text) > 10, f"Generated text is too short: {text}"
        print(separator)
|
||||
@ -186,6 +186,8 @@ def create_reduced_config(
|
||||
if "text_config" in config_dict:
|
||||
original_text_layers = config_dict["text_config"]["num_hidden_layers"]
|
||||
config_dict["text_config"]["num_hidden_layers"] = text_layers
|
||||
original_layer_types = config_dict["text_config"]["layer_types"]
|
||||
config_dict["text_config"]["layer_types"] = original_layer_types[:text_layers]
|
||||
print(f"Reduced text layers from {original_text_layers} to {text_layers}")
|
||||
|
||||
original_num_experts = config_dict["text_config"]["num_local_experts"]
|
||||
|
||||
@ -882,27 +882,27 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
|
||||
_TRANSFORMERS_BACKEND_MODELS = {
|
||||
"TransformersEmbeddingModel": _HfExamplesInfo(
|
||||
"BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"
|
||||
"BAAI/bge-base-en-v1.5", min_transformers_version="5.0.0"
|
||||
),
|
||||
"TransformersForSequenceClassification": _HfExamplesInfo(
|
||||
"papluca/xlm-roberta-base-language-detection",
|
||||
min_transformers_version="4.57.0.dev0",
|
||||
min_transformers_version="5.0.0",
|
||||
),
|
||||
"TransformersForCausalLM": _HfExamplesInfo(
|
||||
"hmellor/Ilama-3.2-1B", trust_remote_code=True
|
||||
),
|
||||
"TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
|
||||
"TransformersMoEForCausalLM": _HfExamplesInfo(
|
||||
"allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"
|
||||
"allenai/OLMoE-1B-7B-0924", min_transformers_version="5.0.0"
|
||||
),
|
||||
"TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
|
||||
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"
|
||||
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="5.0.0"
|
||||
),
|
||||
"TransformersMoEEmbeddingModel": _HfExamplesInfo(
|
||||
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
|
||||
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
|
||||
),
|
||||
"TransformersMoEForSequenceClassification": _HfExamplesInfo(
|
||||
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
|
||||
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
|
||||
),
|
||||
"TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"),
|
||||
"TransformersMultiModalForSequenceClassification": _HfExamplesInfo(
|
||||
|
||||
@ -82,7 +82,7 @@ def test_models(
|
||||
from packaging.version import Version
|
||||
|
||||
installed = Version(transformers.__version__)
|
||||
required = Version("4.57.0.dev0")
|
||||
required = Version("5.0.0")
|
||||
if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
|
||||
pytest.skip(
|
||||
"MoE models with the Transformers backend require "
|
||||
|
||||
@ -14,16 +14,19 @@ import torch
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import _Backend, backend_to_class_str
|
||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
|
||||
from vllm.config.vllm import set_current_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.attention.backends.mla.common import QueryLenSupport
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
@ -31,17 +34,46 @@ BACKENDS_TO_TEST = [
|
||||
_Backend.CUTLASS_MLA,
|
||||
_Backend.FLASHMLA,
|
||||
_Backend.FLASH_ATTN_MLA,
|
||||
_Backend.FLASHINFER_MLA,
|
||||
_Backend.TRITON_MLA,
|
||||
]
|
||||
|
||||
# Remove CUTLASS_MLA from the list if not using sm100
|
||||
# Remove sm100 backends from the list if not using sm100
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
|
||||
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
|
||||
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)
|
||||
|
||||
# Remove FLASH_ATTN_MLA from the list if not supported
|
||||
if not flash_attn_supports_mla():
|
||||
BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)
|
||||
|
||||
# Remove FLASHMLA from the list if not supported
|
||||
if not is_flashmla_dense_supported()[0]:
|
||||
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
|
||||
|
||||
SPEC_DECODE_BACKENDS = []
|
||||
for backend in BACKENDS_TO_TEST:
|
||||
builder_cls, _ = try_get_attention_backend(backend)
|
||||
query_len_support = getattr(
|
||||
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
|
||||
)
|
||||
if query_len_support != QueryLenSupport.SINGLE_ONLY:
|
||||
SPEC_DECODE_BACKENDS.append(backend)
|
||||
|
||||
BACKEND_BLOCK_SIZES = {}
|
||||
for backend in BACKENDS_TO_TEST:
|
||||
backend_class_str = backend_to_class_str(backend)
|
||||
backend_class = resolve_obj_by_qualname(backend_class_str)
|
||||
supported_sizes = backend_class.get_supported_kernel_block_size()
|
||||
if supported_sizes:
|
||||
default_size = supported_sizes[0]
|
||||
block_size = (
|
||||
default_size if isinstance(default_size, int) else default_size.base
|
||||
)
|
||||
else:
|
||||
block_size = 16
|
||||
BACKEND_BLOCK_SIZES[backend] = block_size
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
@ -236,6 +268,26 @@ class MockAttentionLayer:
|
||||
self._q_scale = torch.tensor(1.0, device=device)
|
||||
self._k_scale = torch.tensor(1.0, device=device)
|
||||
self._v_scale = torch.tensor(1.0, device=device)
|
||||
self._prob_scale = torch.tensor(1.0, device=device)
|
||||
self._q_scale_float = 1.0
|
||||
self._k_scale_float = 1.0
|
||||
self._v_scale_float = 1.0
|
||||
|
||||
def forward(self, *_args, **_kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MockMLAAttentionLayer(AttentionLayerBase):
    """Stand-in MLA attention layer used to fill ``static_forward_context``.

    It only carries the attention implementation object; the abstract
    accessors are deliberately left unimplemented.
    """

    def __init__(self, impl):
        """Store the attention implementation on the mock layer."""
        self.impl = impl

    def get_attn_backend(self):
        """Not supported by the mock."""
        raise NotImplementedError

    def get_kv_cache_spec(self, vllm_config):
        """Not supported by the mock."""
        raise NotImplementedError
|
||||
|
||||
|
||||
def run_attention_backend(
|
||||
@ -262,13 +314,6 @@ def run_attention_backend(
|
||||
# Set the current vllm config so that get_current_vllm_config() works
|
||||
# in the backend implementations
|
||||
with set_current_vllm_config(vllm_config):
|
||||
# Build metadata
|
||||
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
|
||||
attn_metadata = builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
|
||||
# Instantiate MLA implementation
|
||||
num_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
@ -302,6 +347,19 @@ def run_attention_backend(
|
||||
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
||||
impl.process_weights_after_loading(act_dtype)
|
||||
|
||||
# Populate static_forward_context with mock attention layers
|
||||
for layer_name in layer_names:
|
||||
vllm_config.compilation_config.static_forward_context[layer_name] = (
|
||||
MockMLAAttentionLayer(impl)
|
||||
)
|
||||
|
||||
# Build metadata
|
||||
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
|
||||
attn_metadata = builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
|
||||
# Create mock layer and output buffer
|
||||
mock_layer = MockAttentionLayer(device)
|
||||
num_tokens = query.shape[0]
|
||||
@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
simulated paged KV cache.
|
||||
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
|
||||
"""
|
||||
from vllm.v1.attention.backends.mla.common import QueryLenSupport
|
||||
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
|
||||
spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
|
||||
|
||||
block_size = 16
|
||||
unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
|
||||
default_block_size = unique_block_sizes[0]
|
||||
required_blocks = sum(
|
||||
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
|
||||
(seq_len + default_block_size - 1) // default_block_size
|
||||
for seq_len in batch_spec.seq_lens
|
||||
)
|
||||
# Add 1 for null block at index 0, and some buffer
|
||||
num_gpu_blocks = required_blocks + 1 + 100
|
||||
@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
model_name=model,
|
||||
max_model_len=max(batch_spec.seq_lens),
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
block_size=block_size,
|
||||
block_size=default_block_size,
|
||||
)
|
||||
|
||||
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
|
||||
@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
|
||||
# 1. Setup
|
||||
batch_size = batch_spec.batch_size
|
||||
seq_lens = batch_spec.seq_lens
|
||||
@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
)
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
kv_lora_rank = 512
|
||||
qk_rope_head_dim = 64
|
||||
qk_nope_head_dim = 128
|
||||
@ -598,33 +652,83 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
)
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)
|
||||
|
||||
# Create metadata using original batch spec
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec, vllm_config.cache_config.block_size, device
|
||||
)
|
||||
# 3. Create metadata and KV caches for each block size
|
||||
# Group backends by block size and test each group
|
||||
metadata_per_block_size = {}
|
||||
kv_cache_per_block_size = {}
|
||||
|
||||
# 3. Simulate Paged KV Cache and a realistic slot_mapping
|
||||
kv_cache = create_and_prepopulate_kv_cache(
|
||||
kv_c_contexts=kv_c_contexts,
|
||||
k_pe_contexts=k_pe_contexts,
|
||||
block_size=block_size,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
num_blocks=vllm_config.cache_config.num_gpu_blocks,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
randomize_blocks=True,
|
||||
)
|
||||
for block_size in unique_block_sizes:
|
||||
# Create metadata for this block size
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec, block_size, device
|
||||
)
|
||||
|
||||
# Pad block table to meet requirement:
|
||||
# block_num % (128 / block_size) == 0
|
||||
required_divisor = int(128 / block_size)
|
||||
current_block_num = common_attn_metadata.block_table_tensor.shape[1]
|
||||
if current_block_num % required_divisor != 0:
|
||||
# Pad to next multiple of required_divisor
|
||||
padded_block_num = (
|
||||
(current_block_num + required_divisor - 1) // required_divisor
|
||||
) * required_divisor
|
||||
padding_cols = padded_block_num - current_block_num
|
||||
padding = torch.zeros(
|
||||
(common_attn_metadata.block_table_tensor.shape[0], padding_cols),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
common_attn_metadata.block_table_tensor = torch.cat(
|
||||
[common_attn_metadata.block_table_tensor, padding], dim=1
|
||||
)
|
||||
|
||||
metadata_per_block_size[block_size] = common_attn_metadata
|
||||
|
||||
# Create KV cache for this block size
|
||||
required_blocks_for_size = sum(
|
||||
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
|
||||
)
|
||||
num_blocks_for_size = required_blocks_for_size + 1 + 100
|
||||
|
||||
kv_cache = create_and_prepopulate_kv_cache(
|
||||
kv_c_contexts=kv_c_contexts,
|
||||
k_pe_contexts=k_pe_contexts,
|
||||
block_size=block_size,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
num_blocks=num_blocks_for_size,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
randomize_blocks=True,
|
||||
)
|
||||
kv_cache_per_block_size[block_size] = kv_cache
|
||||
|
||||
# 4. Run vLLM backends and compare
|
||||
failures = []
|
||||
for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
|
||||
# Skip backends that don't support spec decode for spec decode tests
|
||||
if is_spec_decode_test and backend_name not in spec_decode_backends:
|
||||
if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
|
||||
continue
|
||||
|
||||
# Get the appropriate block_size, metadata, and cache for this backend
|
||||
block_size = BACKEND_BLOCK_SIZES[backend_name]
|
||||
common_attn_metadata = metadata_per_block_size[block_size]
|
||||
kv_cache = kv_cache_per_block_size[block_size]
|
||||
|
||||
# Create kv_cache_spec with the correct block_size for this backend
|
||||
backend_kv_cache_spec = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
|
||||
vllm_config.parallel_config
|
||||
),
|
||||
head_size=vllm_config.model_config.get_head_size(),
|
||||
dtype=vllm_config.model_config.dtype,
|
||||
sliding_window=vllm_config.model_config.get_sliding_window(),
|
||||
)
|
||||
|
||||
backend_output = run_attention_backend(
|
||||
backend_name,
|
||||
kv_cache_spec,
|
||||
backend_kv_cache_spec,
|
||||
["placeholder"],
|
||||
vllm_config,
|
||||
device,
|
||||
@ -644,32 +748,48 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
expected_output = sdpa_outputs[backend_name]
|
||||
|
||||
# Check shape and dtype consistency
|
||||
assert backend_output.shape == expected_output.shape, (
|
||||
f"[{backend_name}] shape {backend_output.shape} != "
|
||||
f"SDPA shape {expected_output.shape}"
|
||||
)
|
||||
assert backend_output.dtype == expected_output.dtype, (
|
||||
f"[{backend_name}] dtype {backend_output.dtype} != "
|
||||
f"SDPA dtype {expected_output.dtype}"
|
||||
)
|
||||
try:
|
||||
assert backend_output.shape == expected_output.shape, (
|
||||
f"[{backend_name}] shape {backend_output.shape} != "
|
||||
f"SDPA shape {expected_output.shape}"
|
||||
)
|
||||
assert backend_output.dtype == expected_output.dtype, (
|
||||
f"[{backend_name}] dtype {backend_output.dtype} != "
|
||||
f"SDPA dtype {expected_output.dtype}"
|
||||
)
|
||||
|
||||
assert torch.isfinite(backend_output).all(), (
|
||||
f"[{backend_name}] produced non-finite values"
|
||||
)
|
||||
assert torch.isfinite(backend_output).all(), (
|
||||
f"[{backend_name}] produced non-finite values"
|
||||
)
|
||||
|
||||
# Check numerical similarity
|
||||
rtol = 1e-2
|
||||
atol = 5e-1
|
||||
# Check numerical similarity
|
||||
rtol = 1e-2
|
||||
atol = 5e-1
|
||||
|
||||
max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
|
||||
max_rel_diff = torch.max(
|
||||
torch.abs(backend_output - expected_output) / torch.abs(expected_output)
|
||||
).item()
|
||||
all_close = torch.allclose(
|
||||
backend_output, expected_output, rtol=rtol, atol=atol
|
||||
)
|
||||
max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
|
||||
max_rel_diff = torch.max(
|
||||
torch.abs(backend_output - expected_output) / torch.abs(expected_output)
|
||||
).item()
|
||||
all_close = torch.allclose(
|
||||
backend_output, expected_output, rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
assert all_close, (
|
||||
f"[{backend_name}] output differs from SDPA baseline. "
|
||||
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
|
||||
)
|
||||
assert all_close, (
|
||||
f"[{backend_name}] output differs from SDPA baseline. "
|
||||
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
|
||||
)
|
||||
except AssertionError as e:
|
||||
failures.append(str(e))
|
||||
|
||||
# Report all failures at once
|
||||
if failures:
|
||||
# Create a summary for the single-line failure message
|
||||
backend_names = []
|
||||
for f in failures:
|
||||
if "[_Backend." in f:
|
||||
backend_name = f.split("[")[1].split("]")[0]
|
||||
backend_names.append(backend_name)
|
||||
|
||||
summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
|
||||
detailed_msg = "\n".join(failures)
|
||||
pytest.fail(f"{summary}\n{detailed_msg}")
|
||||
|
||||
@ -285,7 +285,17 @@ full_cg_backend_configs = {
|
||||
name="CutlassMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
|
||||
"FORCE_NUM_KV_SPLITS": "1", # TODO: remove this when hang issue is fixed
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
specific_gpu_arch=(10, 0),
|
||||
),
|
||||
# FlashInfer MLA on Blackwell
|
||||
"FlashInferMLA": BackendConfig(
|
||||
name="FlashInferMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
|
||||
@ -337,8 +337,6 @@ def test_stop_via_update_from_output():
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
@ -385,8 +383,6 @@ def test_stop_via_update_from_output():
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
@ -431,8 +427,6 @@ def test_stop_via_update_from_output():
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
@ -472,8 +466,6 @@ def test_stop_via_update_from_output():
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
@ -1988,7 +1980,6 @@ def test_schedule_skip_tokenizer_init():
|
||||
scheduler.add_request(request)
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == len(requests)
|
||||
assert output.grammar_bitmask is None
|
||||
|
||||
|
||||
def test_schedule_skip_tokenizer_init_structured_output_request():
|
||||
|
||||
@ -7,6 +7,7 @@ import torch._dynamo.config as dynamo_config
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.sampling_params import StructuredOutputsParams
|
||||
|
||||
from ...conftest import VllmRunner
|
||||
from ...models.utils import check_outputs_equal
|
||||
@ -15,9 +16,12 @@ MODEL = "Qwen/Qwen3-0.6B"
|
||||
|
||||
|
||||
@dynamo_config.patch(cache_size_limit=16)
|
||||
def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
|
||||
def test_preempt_and_async_scheduling_e2e(
|
||||
sample_json_schema, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Test consistency of combos of async scheduling, preemption,
|
||||
uni/multiproc executor, and various sampling parameters."""
|
||||
uni/multiproc executor, and various sampling parameters
|
||||
including structured outputs."""
|
||||
|
||||
first_prompt = (
|
||||
"The following numbers of the sequence "
|
||||
@ -35,6 +39,12 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
|
||||
dict(bad_words=["the", " the"]),
|
||||
dict(logprobs=2),
|
||||
dict(logprobs=2, presence_penalty=-1.0),
|
||||
dict(structured_outputs=StructuredOutputsParams(json=sample_json_schema)),
|
||||
dict(
|
||||
structured_outputs=StructuredOutputsParams(json=sample_json_schema),
|
||||
logprobs=2,
|
||||
presence_penalty=-1.0,
|
||||
),
|
||||
]
|
||||
|
||||
default_params = dict(
|
||||
@ -248,7 +248,7 @@ def test_engine_core_concurrent_batches():
|
||||
self,
|
||||
scheduler_output,
|
||||
non_block=False,
|
||||
) -> Future[ModelRunnerOutput]:
|
||||
) -> Future[ModelRunnerOutput | None]:
|
||||
"""Make execute_model non-blocking."""
|
||||
|
||||
# DummyExecutor used only for testing async case.
|
||||
@ -263,6 +263,23 @@ def test_engine_core_concurrent_batches():
|
||||
# Use the thread pool instead of creating a new thread
|
||||
return self.thread_pool.submit(_execute)
|
||||
|
||||
def sample_tokens(
    self, grammar_output, non_block=False
) -> Future[ModelRunnerOutput]:
    """Run ``sample_tokens`` on the executor's thread pool.

    Only the asynchronous path is supported, so ``non_block`` must be truthy.
    The returned future resolves to a deep copy of the first element of the
    ``collective_rpc`` result.
    """
    # DummyExecutor used only for testing async case.
    assert non_block

    def _sample_and_snapshot():
        rpc_results = self.collective_rpc("sample_tokens", args=(grammar_output,))
        first_result = rpc_results[0]
        # Copy before returning: the first result may be reused by the
        # next batch.
        return copy.deepcopy(first_result)

    # Submit to the existing pool rather than spawning a new thread.
    return self.thread_pool.submit(_sample_and_snapshot)
|
||||
|
||||
@property
def max_concurrent_batches(self) -> int:
    """Number of batches this test executor reports as runnable at once."""
    return 2
|
||||
|
||||
@ -31,7 +31,9 @@ class CustomMultiprocExecutor(MultiprocExecutor):
|
||||
# Drop marker to show that this was run
|
||||
with open(".marker", "w"):
|
||||
...
|
||||
return super().collective_rpc(method, timeout, args, kwargs)
|
||||
return super().collective_rpc(
|
||||
method, timeout, args, kwargs, non_block, unique_reply_rank
|
||||
)
|
||||
|
||||
|
||||
CustomMultiprocExecutorAsync = CustomMultiprocExecutor
|
||||
|
||||
65
tests/v1/kv_connector/unit/test_config.py
Normal file
65
tests/v1/kv_connector/unit/test_config.py
Normal file
@ -0,0 +1,65 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Tests for KV cache offloading configuration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
    [
        ("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
        # bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
        ("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30) / 4),
        ("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
        # size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
        ("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
        (None, None, 1, 1, None, None),
    ],
)
def test_kv_connector(
    kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
):
    """Check the KV transfer config that ``VllmConfig`` ends up with.

    For each offloading backend / size / parallelism combination, the
    resulting ``kv_transfer_config`` must name the expected connector, use
    the ``kv_both`` role and carry the expected per-rank size. When no
    backend is requested, no transfer config is expected at all.
    """
    # Seed a pre-existing extra config only when a backend is expected, so
    # the assertions below can tell "preserved" from "replaced".
    initial_transfer_config = None
    if expected_backend is not None:
        initial_transfer_config = KVTransferConfig(
            kv_connector_extra_config={"existing_key": "existing_value"}
        )

    cache_config = CacheConfig(
        kv_offloading_backend=kv_offloading_backend,
        kv_offloading_size=kv_offloading_size,
    )
    parallel_config = ParallelConfig(
        tensor_parallel_size=tp, pipeline_parallel_size=pp
    )
    vllm_config = VllmConfig(
        cache_config=cache_config,
        kv_transfer_config=initial_transfer_config,
        parallel_config=parallel_config,
    )

    resolved = vllm_config.kv_transfer_config

    # No KV transfer config expected
    if expected_backend is None:
        assert resolved is expected_backend
        return

    extra_config = resolved.kv_connector_extra_config

    assert resolved.kv_connector == expected_backend
    assert resolved.kv_role == "kv_both"

    if kv_offloading_backend == "native":
        assert extra_config["kv_bytes_per_rank"] == expected_bytes
        assert extra_config["num_cpu_blocks"] == 0
        # Existing config should be preserved
        assert extra_config["existing_key"] == "existing_value"
    elif kv_offloading_backend == "lmcache":
        assert extra_config["lmcache.local_cpu"] is True
        assert extra_config["lmcache.max_local_cpu_size"] == expected_bytes
        # Existing config should be replaced
        assert "existing_key" not in extra_config
|
||||
@ -26,8 +26,6 @@ def _make_empty_scheduler_output():
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
kv_connector_metadata=SharedStorageConnectorMetadata(),
|
||||
)
|
||||
|
||||
|
||||
@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
NixlAgentMetadata,
|
||||
NixlConnector,
|
||||
NixlConnectorMetadata,
|
||||
NixlConnectorScheduler,
|
||||
NixlConnectorWorker,
|
||||
NixlKVConnectorStats,
|
||||
)
|
||||
@ -283,6 +284,92 @@ def test_prompt_less_than_block_size():
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||
|
||||
|
||||
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_kv_transfer_handshake(dist_init):
    """Unit test for basic NixlConnector interface functionality.

    A prefill-side worker connector registers its KV caches and hands its
    handshake metadata to the scheduler-side connector. A decode-side worker
    connector then performs the handshake, and the metadata it passes to
    ``add_remote_agent`` must equal what the prefill worker published.
    """

    # Test setup: we create a scheduler that contains a NixlConnector
    # of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
    # all workers of the instance.
    vllm_config = create_vllm_config()
    # in case the test runs on non-GPU machine
    vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
    scheduler = create_scheduler(vllm_config)

    # Create two NixlConnector of role WORKER, one is the worker of
    # the scheduler (prefill), the other is a worker of decode instance.

    # Prefill connector will register KV cache to populate proper handshake
    # metadata.
    prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }
    prefill_connector.register_kv_caches(kv_caches)

    # Simulate EngineCore initialization that would
    # gather connector metadata from all workers, the scheduler connector
    # expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
    # where the first key is the dp_rank, the second key is the tp_rank.
    metadata = {0: prefill_connector.get_handshake_metadata()}
    scheduler_connector = scheduler.get_kv_connector()
    scheduler_connector.set_xfer_handshake_metadata(metadata)

    # Everything after the metadata is published runs under try/finally so
    # that a failing assertion still shuts the scheduler connector down.
    try:
        # Simulate a request that finishes prefill, which returns
        # corresponding NixlConnectorMetadata for decode instance.
        BLOCK_SIZE = vllm_config.cache_config.block_size
        NUM_EXTERNAL_FULL_BLOCKS = 2
        NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

        request = create_request(
            request_id=1,
            block_size=BLOCK_SIZE,
            num_tokens=NUM_TOKENS,
            do_remote_decode=True,
        )
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
            request, [0, 1, 2]
        )
        assert delay

        # Decode connector will be able to create handshake with the prefill
        # connector.
        decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)

        # Here we are testing the retrieval of NIXLAgentMetadata.
        # Knowing the implementation detail, we override the add_remote_agent
        # to validate the metadata received is the same as the one in
        # prefill_connector.
        with patch.object(
            decode_connector.connector_worker, "add_remote_agent"
        ) as mock_add_remote_agent:
            # ``return_value`` (not ``return_type``) is the Mock attribute
            # that controls what the patched method returns.
            mock_add_remote_agent.return_value = "remote_agent"

            decode_connector.connector_worker._nixl_handshake(
                kv_connector_metadata["remote_host"],
                kv_connector_metadata["remote_port"],
                kv_connector_metadata["tp_size"],
                kv_connector_metadata["remote_engine_id"],
            )

            received_metadata = mock_add_remote_agent.call_args.args
            assert received_metadata[1] == 0  # remote_tp_rank
            assert received_metadata[2] == 1  # remote_tp_size
            assert metadata[0] == received_metadata[0]
    finally:
        # Need to shutdown the background thread to release NIXL side channel
        # port
        scheduler_connector.shutdown()
|
||||
|
||||
|
||||
class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
REMOTE_ENGINE_ID = "remote_engine"
|
||||
|
||||
@ -313,6 +400,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
engine_id=self.REMOTE_ENGINE_ID,
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
device_id=0,
|
||||
num_blocks=1,
|
||||
block_lens=self.block_len_per_layer,
|
||||
attn_backend_name=self.backend_name,
|
||||
@ -559,6 +647,7 @@ class TestNixlHandshake:
|
||||
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
device_id=0,
|
||||
num_blocks=1,
|
||||
block_lens=worker.block_len_per_layer,
|
||||
attn_backend_name=worker.backend_name,
|
||||
@ -611,6 +700,7 @@ class TestNixlHandshake:
|
||||
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
device_id=0,
|
||||
num_blocks=1,
|
||||
# prefill TP=1, decode TP=2, remote block_lens is double to local
|
||||
block_lens=[i * 2 for i in worker.block_len_per_layer],
|
||||
@ -891,9 +981,7 @@ def test_scheduler_kv_connector_stats_aggregation():
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=[0],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=set(),
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
free_encoder_mm_hashes=[],
|
||||
)
|
||||
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output)
|
||||
@ -1005,6 +1093,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
|
||||
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
|
||||
# Request-0 times out and is cleared!
|
||||
assert "0" not in req_to_blocks
|
||||
# Need to shutdown the background thread to release NIXL side channel port
|
||||
llm.llm_engine.engine_core.shutdown()
|
||||
|
||||
|
||||
def test_register_kv_caches(dist_init):
|
||||
@ -1177,13 +1267,15 @@ def test_shutdown_cleans_up_resources(dist_init):
|
||||
"""Test that shutdown() properly cleans up all resources."""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
scheduler = NixlConnectorScheduler(
|
||||
vllm_config, vllm_config.kv_transfer_config.engine_id
|
||||
)
|
||||
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
|
||||
nixl_wrapper = worker.nixl_wrapper
|
||||
|
||||
with (
|
||||
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
|
||||
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
|
||||
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
|
||||
patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
|
||||
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
|
||||
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
|
||||
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
|
||||
@ -1204,8 +1296,12 @@ def test_shutdown_cleans_up_resources(dist_init):
|
||||
worker.shutdown()
|
||||
|
||||
mock_exec.shutdown.assert_called_with(wait=False)
|
||||
mock_event.set.assert_called_once()
|
||||
mock_listener.join.assert_called_once_with(timeout=1.0)
|
||||
|
||||
# Same sequence on scheduler.shutdown()
|
||||
scheduler.shutdown()
|
||||
scheduler.shutdown()
|
||||
scheduler.shutdown()
|
||||
mock_listener.join.assert_called_once()
|
||||
|
||||
mock_rel_xfer.assert_called_once_with(123)
|
||||
assert mock_rel_dlist.call_count == 2
|
||||
|
||||
@ -92,8 +92,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
|
||||
@ -171,8 +169,6 @@ def test_update_states_request_finished(model_runner):
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids={req_id},
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
@ -201,8 +197,6 @@ def test_update_states_request_resumed(model_runner):
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
@ -230,8 +224,6 @@ def test_update_states_request_resumed(model_runner):
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
@ -261,8 +253,6 @@ def test_update_states_no_changes(model_runner):
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
@ -296,8 +286,6 @@ def test_update_states_request_unscheduled(model_runner):
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention.backends.abstract import MultipleOf
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
ModelConfig,
|
||||
@ -34,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
NUM_BLOCKS = 10
|
||||
@ -150,8 +152,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
|
||||
@ -181,6 +181,57 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
||||
).all()
|
||||
|
||||
|
||||
def _make_mock_backend_for_kernel_block_size(
|
||||
supported_sizes: list[int | MultipleOf],
|
||||
):
|
||||
class _MockBackend:
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size():
|
||||
return supported_sizes
|
||||
|
||||
return _MockBackend()
|
||||
|
||||
|
||||
def _make_kv_cache_spec() -> FullAttentionSpec:
    """Return a minimal ``FullAttentionSpec`` with every size field set to 1."""
    spec_kwargs = {
        "block_size": 1,
        "num_kv_heads": 1,
        "head_size": 1,
        "dtype": "float16",
    }
    return FullAttentionSpec(**spec_kwargs)
|
||||
|
||||
|
||||
def test_select_common_block_size_prefers_manager_block_size():
    """Expect the manager block size (128) to be selected for these backends."""
    # 128 is a multiple of both 32 and 16.
    attn_groups = [
        AttentionGroup(
            _make_mock_backend_for_kernel_block_size(sizes),
            [],
            [],
            _make_kv_cache_spec(),
            0,
        )
        for sizes in ([MultipleOf(32)], [64, MultipleOf(16)])
    ]

    assert GPUModelRunner.select_common_block_size(128, attn_groups) == 128
|
||||
|
||||
|
||||
def test_select_common_block_size_uses_largest_shared_int():
    """Expect 64, the only size both backends list, for a manager size of 256."""
    attn_groups = [
        AttentionGroup(
            _make_mock_backend_for_kernel_block_size(sizes),
            [],
            [],
            _make_kv_cache_spec(),
            0,
        )
        for sizes in ([128, 64], [64, 32])
    ]

    assert GPUModelRunner.select_common_block_size(256, attn_groups) == 64
|
||||
|
||||
|
||||
def test_select_common_block_size_no_valid_option():
    """Expect a ValueError for manager size 48 with these two backends."""
    # One backend lists only the fixed size 64; the other lists MultipleOf(16).
    attn_groups = [
        AttentionGroup(
            _make_mock_backend_for_kernel_block_size(sizes),
            [],
            [],
            _make_kv_cache_spec(),
            0,
        )
        for sizes in ([64], [MultipleOf(16)])
    ]

    with pytest.raises(ValueError):
        GPUModelRunner.select_common_block_size(48, attn_groups)
|
||||
|
||||
|
||||
def test_update_states_new_request(model_runner, dist_init):
|
||||
req_id = "req_0"
|
||||
|
||||
@ -216,8 +267,6 @@ def test_update_states_request_finished(model_runner, dist_init):
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids={req_id},
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
@ -248,8 +297,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
@ -277,8 +324,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
@ -370,8 +415,6 @@ def test_update_states_no_changes(model_runner, dist_init):
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
@ -407,8 +450,6 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
metadata_before = model_runner._update_states(scheduler_output)
|
||||
|
||||
@ -270,21 +270,23 @@ class ipex_ops:
|
||||
|
||||
@staticmethod
|
||||
def flash_attn_varlen_func(
|
||||
out: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
seqused_k: torch.Tensor, # we don't support this in ipex kernel
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
softmax_scale: float,
|
||||
causal: bool,
|
||||
block_table: torch.Tensor,
|
||||
alibi_slopes: torch.Tensor | None,
|
||||
softmax_scale: float | None = None,
|
||||
causal: bool = False,
|
||||
out: torch.Tensor | None = None,
|
||||
block_table: torch.Tensor | None = None,
|
||||
alibi_slopes: torch.Tensor | None = None,
|
||||
window_size: list[int] | None = None,
|
||||
softcap: float | None = 0.0,
|
||||
seqused_k: torch.Tensor | None = None,
|
||||
cu_seqlens_k: torch.Tensor | None = None,
|
||||
# passed in qwen vl
|
||||
dropout_p: float = 0.0,
|
||||
# The following parameters are not used in ipex kernel currently,
|
||||
# we keep API compatible to CUDA's.
|
||||
scheduler_metadata=None,
|
||||
@ -295,31 +297,63 @@ class ipex_ops:
|
||||
num_splits=0,
|
||||
s_aux: torch.Tensor | None = None,
|
||||
):
|
||||
if out is None:
|
||||
out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
|
||||
real_window_size: tuple[int, int]
|
||||
if window_size is None:
|
||||
real_window_size = (-1, -1)
|
||||
else:
|
||||
assert len(window_size) == 2
|
||||
real_window_size = (window_size[0], window_size[1])
|
||||
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||
out,
|
||||
q.contiguous(),
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
seqused_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
softmax_scale,
|
||||
causal,
|
||||
block_table,
|
||||
alibi_slopes,
|
||||
softcap=softcap,
|
||||
window_size_left=real_window_size[0],
|
||||
window_size_right=real_window_size[1],
|
||||
k_scale=1.0,
|
||||
v_scale=1.0,
|
||||
)
|
||||
|
||||
if block_table is None:
|
||||
assert cu_seqlens_k is not None, (
|
||||
"cu_seqlens_k can't be None when calling varlen_attention."
|
||||
)
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
ipex_ops.varlen_attention(
|
||||
q.contiguous(),
|
||||
k.contiguous(),
|
||||
v.contiguous(),
|
||||
out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
None,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
False,
|
||||
None,
|
||||
real_window_size[0],
|
||||
real_window_size[1],
|
||||
-1,
|
||||
)
|
||||
return out
|
||||
else:
|
||||
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||
out,
|
||||
q.contiguous(),
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
seqused_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
softmax_scale,
|
||||
causal,
|
||||
block_table,
|
||||
alibi_slopes,
|
||||
sink=s_aux,
|
||||
softcap=softcap,
|
||||
window_size_left=real_window_size[0],
|
||||
window_size_right=real_window_size[1],
|
||||
k_scale=1.0,
|
||||
v_scale=1.0,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_scheduler_metadata(
|
||||
|
||||
@ -123,6 +123,11 @@ def maybe_get_vit_flash_attn_backend(
|
||||
):
|
||||
attn_backend = _Backend.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
elif current_platform.is_xpu():
|
||||
assert attn_backend == _Backend.FLASH_ATTN, (
|
||||
"XPU platform only supports FLASH_ATTN as vision attention backend."
|
||||
)
|
||||
use_upstream_fa = False
|
||||
else:
|
||||
return _Backend.TORCH_SDPA, None
|
||||
|
||||
@ -133,7 +138,7 @@ def maybe_get_vit_flash_attn_backend(
|
||||
if use_upstream_fa:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
else:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
else:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
@ -521,22 +526,18 @@ class MultiHeadAttention(nn.Module):
|
||||
# If vllm native fa is selected, we use it directly.
|
||||
use_upstream_fa = False
|
||||
|
||||
if current_platform.is_xpu():
|
||||
# currently, only torch_sdpa is supported on xpu
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
else:
|
||||
self.attn_backend = (
|
||||
backend
|
||||
if backend
|
||||
in {
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.PALLAS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
_Backend.FLASH_ATTN,
|
||||
}
|
||||
else _Backend.TORCH_SDPA
|
||||
)
|
||||
self.attn_backend = (
|
||||
backend
|
||||
if backend
|
||||
in {
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.PALLAS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
_Backend.FLASH_ATTN,
|
||||
}
|
||||
else _Backend.TORCH_SDPA
|
||||
)
|
||||
|
||||
self.attn_backend, self._flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
|
||||
@ -70,7 +70,7 @@ def flash_attn_maxseqlen_wrapper(
|
||||
if use_upstream_fa:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
else:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||
output = flash_attn_varlen_func(
|
||||
q,
|
||||
|
||||
@ -24,6 +24,7 @@ BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
|
||||
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
||||
MambaDType = Literal["auto", "float32"]
|
||||
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
|
||||
KVOffloadingBackend = Literal["native", "lmcache"]
|
||||
|
||||
|
||||
@config
|
||||
@ -128,6 +129,17 @@ class CacheConfig:
|
||||
gpu_memory_utilization. Note that kv_cache_memory_bytes
|
||||
(when not-None) ignores gpu_memory_utilization"""
|
||||
|
||||
kv_offloading_size: float | None = None
|
||||
"""Size of the KV cache offloading buffer in GiB. When TP > 1, this is
|
||||
the total buffer size summed across all TP ranks. By default, this is set
|
||||
to None, which means no KV offloading is enabled. When set with
|
||||
kv_offloading_backend, vLLM will enable KV cache offloading to CPU"""
|
||||
|
||||
kv_offloading_backend: KVOffloadingBackend | None = None
|
||||
"""The backend to use for KV cache offloading. Supported backends include
|
||||
'native' (vLLM native CPU offloading) and 'lmcache'. This option must be used
|
||||
together with kv_offloading_size."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
|
||||
@ -2,10 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from dataclasses import InitVar, field
|
||||
from collections.abc import Callable
|
||||
from dataclasses import InitVar
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import SkipValidation, model_validator
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
@ -31,28 +32,28 @@ class SchedulerConfig:
|
||||
runner_type: RunnerType = "generate"
|
||||
"""The runner type to launch for the model."""
|
||||
|
||||
max_num_batched_tokens: SkipValidation[int] = None # type: ignore
|
||||
max_num_batched_tokens: int = Field(default=None, ge=1)
|
||||
"""Maximum number of tokens to be processed in a single iteration.
|
||||
|
||||
This config has no static default. If left unspecified by the user, it will
|
||||
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
||||
|
||||
max_num_seqs: SkipValidation[int] = None # type: ignore
|
||||
max_num_seqs: int = Field(default=None, ge=1)
|
||||
"""Maximum number of sequences to be processed in a single iteration.
|
||||
|
||||
This config has no static default. If left unspecified by the user, it will
|
||||
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
||||
|
||||
max_model_len: SkipValidation[int] = None # type: ignore
|
||||
max_model_len: int = Field(default=None, ge=1)
|
||||
"""Maximum length of a sequence (including prompt and generated text). This
|
||||
is primarily set in `ModelConfig` and that value should be manually
|
||||
duplicated here."""
|
||||
|
||||
max_num_partial_prefills: int = 1
|
||||
max_num_partial_prefills: int = Field(default=1, ge=1)
|
||||
"""For chunked prefill, the maximum number of sequences that can be
|
||||
partially prefilled concurrently."""
|
||||
|
||||
max_long_partial_prefills: int = 1
|
||||
max_long_partial_prefills: int = Field(default=1, ge=1)
|
||||
"""For chunked prefill, the maximum number of prompts longer than
|
||||
long_prefill_token_threshold that will be prefilled concurrently. Setting
|
||||
this less than max_num_partial_prefills will allow shorter prompts to jump
|
||||
@ -62,7 +63,7 @@ class SchedulerConfig:
|
||||
"""For chunked prefill, a request is considered long if the prompt is
|
||||
longer than this number of tokens."""
|
||||
|
||||
num_lookahead_slots: int = 0
|
||||
num_lookahead_slots: int = Field(default=0, ge=0)
|
||||
"""The number of slots to allocate per sequence per
|
||||
step, beyond the known token ids. This is used in speculative
|
||||
decoding to store KV activations of tokens which may or may not be
|
||||
@ -71,7 +72,7 @@ class SchedulerConfig:
|
||||
NOTE: This will be replaced by speculative config in the future; it is
|
||||
present to enable correctness tests until then."""
|
||||
|
||||
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
|
||||
enable_chunked_prefill: bool = Field(default=None)
|
||||
"""If True, prefill requests can be chunked based
|
||||
on the remaining max_num_batched_tokens."""
|
||||
|
||||
@ -86,14 +87,14 @@ class SchedulerConfig:
|
||||
"""
|
||||
|
||||
# TODO (ywang96): Make this configurable.
|
||||
max_num_encoder_input_tokens: int = field(init=False)
|
||||
max_num_encoder_input_tokens: int = Field(init=False)
|
||||
"""Multimodal encoder compute budget, only used in V1.
|
||||
|
||||
NOTE: This is not currently configurable. It will be overridden by
|
||||
max_num_batched_tokens in case max multimodal embedding size is larger."""
|
||||
|
||||
# TODO (ywang96): Make this configurable.
|
||||
encoder_cache_size: int = field(init=False)
|
||||
encoder_cache_size: int = Field(init=False)
|
||||
"""Multimodal encoder cache size, only used in V1.
|
||||
|
||||
NOTE: This is not currently configurable. It will be overridden by
|
||||
@ -106,7 +107,7 @@ class SchedulerConfig:
|
||||
- "priority" means requests are handled based on given priority (lower
|
||||
value means earlier handling) and time of arrival deciding any ties."""
|
||||
|
||||
chunked_prefill_enabled: bool = field(init=False)
|
||||
chunked_prefill_enabled: bool = Field(init=False)
|
||||
"""True if chunked prefill is enabled."""
|
||||
|
||||
disable_chunked_mm_input: bool = False
|
||||
@ -155,6 +156,20 @@ class SchedulerConfig:
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator(
    "max_num_batched_tokens",
    "max_num_seqs",
    "max_model_len",
    "enable_chunked_prefill",
    mode="wrap",
)
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
    """Let `None` through untouched; validate every other value normally.

    `None` marks a field whose initialisation is delayed, so the wrapped
    validator (`handler`) is only invoked for non-`None` values.
    """
    return value if value is None else handler(value)
|
||||
|
||||
def __post_init__(self, is_encoder_decoder: bool) -> None:
|
||||
if self.max_model_len is None:
|
||||
self.max_model_len = 8192
|
||||
@ -260,19 +275,7 @@ class SchedulerConfig:
|
||||
self.max_num_seqs * self.max_model_len,
|
||||
)
|
||||
|
||||
if self.num_lookahead_slots < 0:
|
||||
raise ValueError(
|
||||
"num_lookahead_slots "
|
||||
f"({self.num_lookahead_slots}) must be greater than or "
|
||||
"equal to 0."
|
||||
)
|
||||
|
||||
if self.max_num_partial_prefills < 1:
|
||||
raise ValueError(
|
||||
f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
|
||||
"must be greater than or equal to 1."
|
||||
)
|
||||
elif self.max_num_partial_prefills > 1:
|
||||
if self.max_num_partial_prefills > 1:
|
||||
if not self.chunked_prefill_enabled:
|
||||
raise ValueError(
|
||||
"Chunked prefill must be enabled to set "
|
||||
@ -286,13 +289,10 @@ class SchedulerConfig:
|
||||
f"than the max_model_len ({self.max_model_len})."
|
||||
)
|
||||
|
||||
if (self.max_long_partial_prefills < 1) or (
|
||||
self.max_long_partial_prefills > self.max_num_partial_prefills
|
||||
):
|
||||
if self.max_long_partial_prefills > self.max_num_partial_prefills:
|
||||
raise ValueError(
|
||||
f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
|
||||
"must be greater than or equal to 1 and less than or equal to "
|
||||
f"max_num_partial_prefills ({self.max_num_partial_prefills})."
|
||||
f"{self.max_long_partial_prefills=} must be less than or equal to "
|
||||
f"{self.max_num_partial_prefills=}."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@ -78,10 +78,6 @@ class SpeculativeConfig:
|
||||
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
|
||||
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
||||
or the same as the target model's tensor parallel size."""
|
||||
disable_logprobs: bool = True
|
||||
"""If set to True, token log probabilities are not returned during
|
||||
speculative decoding. If set to False, token log probabilities are returned
|
||||
according to the log probability settings in SamplingParams."""
|
||||
|
||||
# Draft model configuration
|
||||
quantization: me_quant.QuantizationMethods | None = None
|
||||
@ -126,12 +122,6 @@ class SpeculativeConfig:
|
||||
"""The configuration of the target model."""
|
||||
target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
|
||||
"""The parallel configuration for the target model."""
|
||||
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
|
||||
"""Whether vLLM is configured to use chunked prefill or not. Used for
|
||||
raising an error since it's not yet compatible with speculative decode."""
|
||||
disable_log_stats: SkipValidation[bool] = None # type: ignore
|
||||
"""Whether to disable the periodic printing of stage times in speculative
|
||||
decoding."""
|
||||
|
||||
# params generated in the post-init stage
|
||||
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
||||
|
||||
@ -2,8 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Self
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
@ -56,7 +57,8 @@ class StructuredOutputsConfig:
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
@model_validator(mode="after")
|
||||
def _validate_structured_output_config(self) -> Self:
|
||||
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
|
||||
raise ValueError(
|
||||
"disable_any_whitespace is only supported for "
|
||||
@ -67,3 +69,4 @@ class StructuredOutputsConfig:
|
||||
"disable_additional_properties is only supported "
|
||||
"for the guidance backend."
|
||||
)
|
||||
return self
|
||||
|
||||
@ -289,6 +289,48 @@ class VllmConfig:
|
||||
|
||||
return replace(self, model_config=model_config)
|
||||
|
||||
def _post_init_kv_transfer_config(self) -> None:
|
||||
"""Update KVTransferConfig based on top-level configs in VllmConfig.
|
||||
|
||||
Right now, this function reads the offloading settings from
|
||||
CacheConfig and configures the KVTransferConfig accordingly.
|
||||
"""
|
||||
if (kv_offloading_backend := self.cache_config.kv_offloading_backend) is None:
|
||||
return
|
||||
|
||||
# If no KVTransferConfig is provided, create a default one.
|
||||
if self.kv_transfer_config is None:
|
||||
self.kv_transfer_config = KVTransferConfig()
|
||||
|
||||
if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
|
||||
raise ValueError(
|
||||
"You must set kv_offloading_size when kv_offloading_backend is set."
|
||||
)
|
||||
num_kv_ranks = (
|
||||
self.parallel_config.tensor_parallel_size
|
||||
* self.parallel_config.pipeline_parallel_size
|
||||
)
|
||||
|
||||
if kv_offloading_backend == "native":
|
||||
self.kv_transfer_config.kv_connector = "OffloadingConnector"
|
||||
kv_bytes_per_rank = kv_offloading_size * (1 << 30) / num_kv_ranks
|
||||
|
||||
# NOTE(ApostaC): the actual calculation for num_cpu_blocks should be
|
||||
# done after the model's KV cache is initialized
|
||||
self.kv_transfer_config.kv_connector_extra_config.update(
|
||||
{"kv_bytes_per_rank": kv_bytes_per_rank, "num_cpu_blocks": 0}
|
||||
)
|
||||
elif kv_offloading_backend == "lmcache":
|
||||
self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
|
||||
kv_gb_per_rank = kv_offloading_size / num_kv_ranks
|
||||
self.kv_transfer_config.kv_connector_extra_config = {
|
||||
"lmcache.local_cpu": True,
|
||||
"lmcache.max_local_cpu_size": kv_gb_per_rank,
|
||||
}
|
||||
|
||||
# This is the same for all backends
|
||||
self.kv_transfer_config.kv_role = "kv_both"
|
||||
|
||||
def __post_init__(self):
|
||||
"""Verify configs are valid & consistent with each other."""
|
||||
|
||||
@ -646,6 +688,9 @@ class VllmConfig:
|
||||
if "-quant_fp8" not in custom_ops:
|
||||
custom_ops.append("+quant_fp8")
|
||||
|
||||
# Handle the KV connector configs
|
||||
self._post_init_kv_transfer_config()
|
||||
|
||||
def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
|
||||
# remove the sizes that are not multiples of tp_size when
|
||||
# enable sequence parallelism
|
||||
|
||||
@ -6,7 +6,7 @@ KV cache helper for store.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import CancelledError, Future
|
||||
from typing import TYPE_CHECKING, Literal, cast
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import torch
|
||||
|
||||
@ -138,8 +138,11 @@ class KVOutputAggregator:
|
||||
return cls(connector.get_finished_count() or world_size)
|
||||
|
||||
def aggregate(
|
||||
self, outputs: list[ModelRunnerOutput], output_rank: int = 0
|
||||
) -> ModelRunnerOutput:
|
||||
self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
|
||||
) -> ModelRunnerOutput | None:
|
||||
if not outputs[output_rank]:
|
||||
return None
|
||||
|
||||
# Aggregate kv_connector_output from all workers
|
||||
|
||||
def update_finished_set(
|
||||
@ -161,6 +164,7 @@ class KVOutputAggregator:
|
||||
aggregated_kv_connector_stats = None
|
||||
invalid_block_ids = set[int]()
|
||||
for model_runner_output in outputs:
|
||||
assert model_runner_output is not None
|
||||
kv_output = model_runner_output.kv_connector_output
|
||||
if not kv_output:
|
||||
continue
|
||||
@ -204,6 +208,7 @@ class KVOutputAggregator:
|
||||
# select output of the worker specified by output_rank
|
||||
output = outputs[output_rank]
|
||||
|
||||
assert output is not None
|
||||
output.kv_connector_output = KVConnectorOutput(
|
||||
finished_sending=finished_sending or None,
|
||||
finished_recving=finished_recving or None,
|
||||
@ -215,13 +220,16 @@ class KVOutputAggregator:
|
||||
return output
|
||||
|
||||
def async_aggregate(
|
||||
self, output_futures: Sequence[Future[ModelRunnerOutput]], output_rank: int = 0
|
||||
) -> Future[ModelRunnerOutput]:
|
||||
self,
|
||||
output_futures: Sequence[Future[ModelRunnerOutput | None]],
|
||||
output_rank: int = 0,
|
||||
) -> Future[ModelRunnerOutput | None]:
|
||||
"""Takes a list of futures and returns a single future which resolves
|
||||
to the respective list of outputs."""
|
||||
result_future: Future[ModelRunnerOutput] = Future()
|
||||
result_future: Future[ModelRunnerOutput | None] = Future()
|
||||
|
||||
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
|
||||
remaining = len(output_futures)
|
||||
|
||||
def make_callback(idx):
|
||||
def callback(fut):
|
||||
@ -236,12 +244,10 @@ class KVOutputAggregator:
|
||||
result_future.set_exception(e)
|
||||
|
||||
# this check assumes io_thread_pool uses a single thread
|
||||
if all(outputs):
|
||||
result_future.set_result(
|
||||
self.aggregate(
|
||||
cast(list[ModelRunnerOutput], outputs), output_rank
|
||||
)
|
||||
)
|
||||
nonlocal remaining
|
||||
remaining -= 1
|
||||
if not remaining:
|
||||
result_future.set_result(self.aggregate(outputs, output_rank))
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
@ -122,6 +122,15 @@ class KVConnectorRole(enum.Enum):
|
||||
WORKER = 1
|
||||
|
||||
|
||||
class KVConnectorHandshakeMetadata(ABC):  # noqa: B024
    """
    Base type for the metadata exchanged during the out-of-band connector
    handshake between P/D workers. This needs to be serializable.
    """
|
||||
|
||||
|
||||
class KVConnectorMetadata(ABC): # noqa: B024
|
||||
"""
|
||||
Abstract Metadata used to communicate between the
|
||||
@ -320,6 +329,18 @@ class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """
    Return this connector's handshake metadata, which is used for the
    out-of-band connector handshake between P/D workers.

    This implementation has no metadata to offer and always returns None.

    Returns:
        KVConnectorHandshakeMetadata: the handshake metadata.
        None if no handshake metadata is available.
    """
    return None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
@ -477,6 +498,17 @@ class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def set_xfer_handshake_metadata(
    self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
    """
    Set the KV connector handshake metadata for this connector.

    This implementation ignores ``metadata`` and does nothing.

    Args:
        metadata (dict): the handshake metadata to set.
    """
    return
|
||||
|
||||
@classmethod
|
||||
def build_prom_metrics(
|
||||
cls,
|
||||
|
||||
@ -27,6 +27,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
CopyBlocksOp,
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorHandshakeMetadata,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
@ -93,15 +94,12 @@ _NIXL_SUPPORTED_DEVICE = {
|
||||
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
|
||||
|
||||
|
||||
class NixlAgentMetadata(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
# required for @cached_property.
|
||||
dict=True,
|
||||
):
|
||||
@dataclass
|
||||
class NixlAgentMetadata(KVConnectorHandshakeMetadata):
|
||||
engine_id: str
|
||||
agent_metadata: bytes
|
||||
kv_caches_base_addr: list[int]
|
||||
device_id: int
|
||||
num_blocks: int
|
||||
block_lens: list[int]
|
||||
attn_backend_name: str
|
||||
@ -223,6 +221,18 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
def set_xfer_handshake_metadata(
    self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
    """
    Forward the KV connector handshake metadata to the scheduler-side
    connector, which must already exist.

    Args:
        metadata (dict): the handshake metadata to set.
    """
    scheduler = self.connector_scheduler
    assert scheduler is not None
    scheduler.set_xfer_handshake_metadata(metadata)
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
@ -299,6 +309,21 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
def shutdown(self):
    """Shut down the worker-side connector, then the scheduler-side one.

    A side that is None is skipped.
    """
    if (worker := self.connector_worker) is not None:
        worker.shutdown()
    if (scheduler := self.connector_scheduler) is not None:
        scheduler.shutdown()
|
||||
|
||||
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """
    Return the handshake metadata held by the worker-side connector, which
    must already exist. The metadata is used for the out-of-band connector
    handshake between P/D workers.

    Returns:
        KVConnectorHandshakeMetadata: the handshake metadata.
        None if no handshake metadata is available.
    """
    worker = self.connector_worker
    assert worker is not None
    return worker.xfer_handshake_metadata
|
||||
|
||||
|
||||
class NixlConnectorScheduler:
|
||||
@ -312,12 +337,16 @@ class NixlConnectorScheduler:
|
||||
self.side_channel_port = (
|
||||
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
|
||||
+ vllm_config.parallel_config.data_parallel_rank
|
||||
* vllm_config.parallel_config.tensor_parallel_size
|
||||
)
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
|
||||
logger.info("Initializing NIXL Scheduler %s", engine_id)
|
||||
|
||||
# Background thread for handling new handshake requests.
|
||||
self._nixl_handshake_listener_t: threading.Thread | None = None
|
||||
self._encoded_xfer_handshake_metadata: dict[int, Any] = {}
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# Requests that need to start recv/send.
|
||||
# New requests are added by update_state_after_alloc in
|
||||
# the scheduler. Used to make metadata passed to Worker.
|
||||
@ -330,6 +359,89 @@ class NixlConnectorScheduler:
|
||||
# remote prefill or aborted.
|
||||
self._reqs_not_processed: set[ReqId] = set()
|
||||
|
||||
def shutdown(self):
    """Signal the handshake listener to stop and wait for it to exit."""
    self._stop_event.set()

    listener = self._nixl_handshake_listener_t
    if listener is None:
        # The listener was never started (or was already shut down).
        return

    listener.join()
    self._nixl_handshake_listener_t = None
|
||||
|
||||
def set_xfer_handshake_metadata(
    self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
    """
    Set the KV connector handshake metadata for this connector.

    Every entry is msgpack-encoded once up front. The background
    handshake listener is started on the first call; later calls refresh
    the payloads served by the already-running listener.

    Args:
        metadata (dict): the handshake metadata to set, keyed by TP rank.

    Raises:
        ValueError: if any value is not a NixlAgentMetadata. Nothing is
            changed in that case.
        RuntimeError: if the listener thread exits before signalling
            that it is ready.
    """
    encoder = msgspec.msgpack.Encoder()
    encoded_data: dict[int, bytes] = {}
    for tp_rank, rank_metadata in metadata.items():
        if not isinstance(rank_metadata, NixlAgentMetadata):
            raise ValueError(
                "NixlConnectorScheduler expects NixlAgentMetadata for "
                "handshake metadata."
            )
        encoded_data[tp_rank] = encoder.encode(rank_metadata)
        logger.debug(
            "Tp rank %d: encoded NixlAgentMetadata size: %s bytes",
            tp_rank,
            len(encoded_data[tp_rank]),
        )

    if self._nixl_handshake_listener_t is not None:
        # The running listener holds a reference to the dict it was
        # started with. Rebinding the attribute would leave it serving
        # stale payloads, so update that dict in place instead.
        served = self._encoded_xfer_handshake_metadata
        served.update(encoded_data)
        for stale_rank in [rank for rank in served if rank not in encoded_data]:
            del served[stale_rank]
        return

    self._encoded_xfer_handshake_metadata = encoded_data

    # Only start the listener when we have metadata to serve.
    ready_event = threading.Event()
    listener = threading.Thread(
        target=self._nixl_handshake_listener,
        args=(
            encoded_data,
            ready_event,
            self._stop_event,
            self.side_channel_port,
        ),
        daemon=True,
        name="nixl_handshake_listener",
    )
    self._nixl_handshake_listener_t = listener
    listener.start()

    # Wait for listener ZMQ socket to be ready. Poll instead of blocking
    # indefinitely so that a listener that died during startup surfaces
    # as an error rather than a hang.
    while not ready_event.wait(timeout=0.1):
        if not listener.is_alive() and not ready_event.is_set():
            self._nixl_handshake_listener_t = None
            raise RuntimeError(
                "NIXL handshake listener thread exited before becoming ready."
            )
|
||||
|
||||
@staticmethod
def _nixl_handshake_listener(
    encoded_data: dict[int, Any],
    ready_event: threading.Event,
    stop_event: threading.Event,
    port: int,
):
    """Background thread for getting new NIXL handshakes.

    Answers each (GET_META_MSG, tp_rank) request with the pre-encoded
    metadata stored for that rank. Requests that cannot be parsed, or
    that name a rank with no metadata, are logged and dropped so that a
    single bad request cannot terminate the listener.

    Args:
        encoded_data: encoded handshake payloads, looked up by the rank
            carried in the request.
        ready_event: set once the socket is configured and the loop is
            about to start.
        stop_event: when set, the loop exits at its next iteration.
        port: side-channel port used to build the ZMQ path.
    """
    # NOTE(rob): this is a simple implementation. We will move
    # to a better approach via HTTP endpoint soon.

    # Listen for new requests for metadata.
    host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
    path = make_zmq_path("tcp", host, port)
    logger.debug("Starting listening on path: %s", path)
    with zmq_ctx(zmq.ROUTER, path) as sock:
        # Bounded receive so the stop event is re-checked at least about
        # once per second even when no requests arrive.
        sock.setsockopt(zmq.RCVTIMEO, 1000)
        ready_event.set()
        # Checking the stop event on every iteration (not only after a
        # receive timeout) lets shutdown proceed under steady traffic.
        while not stop_event.is_set():
            try:
                frames = sock.recv_multipart()
            except zmq.Again:
                continue

            if len(frames) != 3:
                logger.warning(
                    "Connection listener got malformed request with %d frames",
                    len(frames),
                )
                continue
            identity, _, payload = frames

            # Decode the message which contains (GET_META_MSG, rank)
            try:
                msg, target_tp_rank = msgspec.msgpack.decode(payload)
            except (msgspec.DecodeError, TypeError, ValueError):
                # Undecodable payload, or not a two-element sequence.
                logger.warning(
                    "Connection listener could not decode handshake request"
                )
                continue

            logger.debug(
                "Received message for tp rank %s",
                target_tp_rank,
            )
            if msg != GET_META_MSG:
                logger.warning("Connection listener got unexpected message %s", msg)

            try:
                response = encoded_data[target_tp_rank]
            except (KeyError, TypeError):
                # Unknown (or unhashable) rank: nothing to serve.
                logger.warning(
                    "Connection listener has no handshake metadata for tp rank %s",
                    target_tp_rank,
                )
                continue

            sock.send_multipart((identity, b"", response))
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int, bool]:
|
||||
@ -537,8 +649,6 @@ class NixlConnectorScheduler:
|
||||
class NixlConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
_POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms
|
||||
|
||||
@dataclass
|
||||
class TpKVTopology:
|
||||
"""
|
||||
@ -651,16 +761,6 @@ class NixlConnectorWorker:
|
||||
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
||||
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
|
||||
|
||||
# NIXL handshake port.
|
||||
# NOTE(rob): Within a DP group, each DP rank gets its own
|
||||
# base port (which is sent in the KVTransferParams).
|
||||
# Each TP rank listens/queries on the base_port + tp_rank.
|
||||
self.side_channel_port: int = (
|
||||
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
|
||||
+ vllm_config.parallel_config.data_parallel_rank
|
||||
* vllm_config.parallel_config.tensor_parallel_size
|
||||
)
|
||||
|
||||
# Metadata.
|
||||
self.engine_id: EngineId = engine_id
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
@ -706,6 +806,7 @@ class NixlConnectorWorker:
|
||||
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
||||
# rank will still only pull from a single remote TP worker.
|
||||
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
|
||||
self.device_id: int = 0
|
||||
|
||||
# Number of NIXL regions. Currently one region per cache
|
||||
# (so 1 per layer for MLA, otherwise 2 per layer)
|
||||
@ -736,9 +837,8 @@ class NixlConnectorWorker:
|
||||
# requests that skipped transfer (handshake or transfer failures)
|
||||
self._failed_recv_reqs: set[ReqId] = set()
|
||||
|
||||
# Background thread for handling new handshake requests.
|
||||
self._nixl_handshake_listener_t: threading.Thread | None = None
|
||||
self._nixl_handshake_listener_stop_event: threading.Event | None = None
|
||||
# Handshake metadata of this worker for NIXL transfers.
|
||||
self.xfer_handshake_metadata: NixlAgentMetadata | None = None
|
||||
# Background thread for initializing new NIXL handshakes.
|
||||
self._handshake_initiation_executor = ThreadPoolExecutor(
|
||||
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
|
||||
@ -790,42 +890,6 @@ class NixlConnectorWorker:
|
||||
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _nixl_handshake_listener(
|
||||
metadata: NixlAgentMetadata,
|
||||
ready_event: threading.Event,
|
||||
stop_event: threading.Event,
|
||||
base_port: int,
|
||||
tp_rank: int,
|
||||
):
|
||||
"""Background thread for getting new NIXL handshakes."""
|
||||
# NOTE(rob): this is a simple implementation. We will move
|
||||
# to a better approach via HTTP endpoint soon.
|
||||
|
||||
encoder = msgspec.msgpack.Encoder()
|
||||
encoded_data = encoder.encode(metadata)
|
||||
size_in_bytes = len(encoded_data)
|
||||
logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes))
|
||||
|
||||
# Listen for new requests for metadata.
|
||||
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
||||
path = make_zmq_path("tcp", host, base_port + tp_rank)
|
||||
logger.debug("Starting listening on path: %s", path)
|
||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||
ready_event.set()
|
||||
poller = zmq.Poller()
|
||||
poller.register(sock, zmq.POLLIN)
|
||||
while not stop_event.is_set():
|
||||
events = dict(
|
||||
poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000)
|
||||
)
|
||||
if sock not in events:
|
||||
continue
|
||||
identity, _, msg = sock.recv_multipart()
|
||||
if msg != GET_META_MSG:
|
||||
logger.warning("Connection listener got unexpected message %s", msg)
|
||||
sock.send_multipart((identity, b"", encoded_data))
|
||||
|
||||
def _nixl_handshake(
|
||||
self,
|
||||
host: str,
|
||||
@ -844,16 +908,17 @@ class NixlConnectorWorker:
|
||||
# Handshake only with the remote TP rank that current local rank will
|
||||
# pull from. With homogeneous TP it happens to be the same rank_i.
|
||||
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
|
||||
path = make_zmq_path("tcp", host, port + p_remote_rank)
|
||||
path = make_zmq_path("tcp", host, port)
|
||||
logger.debug(
|
||||
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank
|
||||
"Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
|
||||
)
|
||||
|
||||
# Send query for the request.
|
||||
with zmq_ctx(zmq.REQ, path) as sock:
|
||||
msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank))
|
||||
# Set receive timeout to 5 seconds to avoid hanging on dead server
|
||||
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
|
||||
sock.send(GET_META_MSG)
|
||||
sock.send(msg)
|
||||
metadata_bytes = sock.recv()
|
||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
@ -1042,6 +1107,10 @@ class NixlConnectorWorker:
|
||||
assert tensor_size_bytes == curr_tensor_size_bytes, (
|
||||
"All kv cache tensors must have the same size"
|
||||
)
|
||||
# Need to make sure the device ID is non-negative for NIXL,
|
||||
# Torch uses -1 to indicate CPU tensors while NIXL uses explicit
|
||||
# memory type.
|
||||
self.device_id = max(cache.get_device(), 0)
|
||||
caches_data.append(
|
||||
(base_addr, curr_tensor_size_bytes, self.device_id, "")
|
||||
)
|
||||
@ -1139,10 +1208,11 @@ class NixlConnectorWorker:
|
||||
assert len(self.block_window_per_layer) == self.num_layers
|
||||
|
||||
# After KV Caches registered, listen for new connections.
|
||||
metadata = NixlAgentMetadata(
|
||||
self.xfer_handshake_metadata = NixlAgentMetadata(
|
||||
engine_id=self.engine_id,
|
||||
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
||||
device_id=self.device_id,
|
||||
num_blocks=self.num_blocks,
|
||||
block_lens=self.block_len_per_layer,
|
||||
attn_backend_name=self.backend_name,
|
||||
@ -1150,22 +1220,6 @@ class NixlConnectorWorker:
|
||||
if not self.use_host_buffer
|
||||
else self.host_buffer_kv_cache_layout,
|
||||
)
|
||||
ready_event, stop_event = threading.Event(), threading.Event()
|
||||
self._nixl_handshake_listener_t = threading.Thread(
|
||||
target=self._nixl_handshake_listener,
|
||||
args=(
|
||||
metadata,
|
||||
ready_event,
|
||||
stop_event,
|
||||
self.side_channel_port,
|
||||
self.tp_rank,
|
||||
),
|
||||
daemon=True,
|
||||
name="nixl_handshake_listener",
|
||||
)
|
||||
self._nixl_handshake_listener_t.start()
|
||||
self._nixl_handshake_listener_stop_event = stop_event
|
||||
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
||||
|
||||
def add_remote_agent(
|
||||
self,
|
||||
@ -1267,7 +1321,7 @@ class NixlConnectorWorker:
|
||||
# self.block_len == remote_block_len//tp_ratio bytes.
|
||||
addr = base_addr + block_offset + rank_offset
|
||||
# (addr, len, device id)
|
||||
blocks_data.append((addr, kv_block_len, remote_tp_rank))
|
||||
blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
|
||||
|
||||
if self._use_flashinfer:
|
||||
# With FlashInfer index V separately to allow head splitting.
|
||||
@ -1275,7 +1329,9 @@ class NixlConnectorWorker:
|
||||
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
||||
addr = base_addr + block_offset + rank_offset
|
||||
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
|
||||
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
|
||||
blocks_data.append(
|
||||
(v_addr, kv_block_len, nixl_agent_meta.device_id)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Created %s blocks for dst engine %s with remote rank %s and local rank %s",
|
||||
@ -1843,14 +1899,6 @@ class NixlConnectorWorker:
|
||||
def shutdown(self):
|
||||
"""Shutdown the connector worker."""
|
||||
self._handshake_initiation_executor.shutdown(wait=False)
|
||||
if self._nixl_handshake_listener_stop_event is not None:
|
||||
self._nixl_handshake_listener_stop_event.set()
|
||||
self._nixl_handshake_listener_stop_event = None
|
||||
if self._nixl_handshake_listener_t is not None:
|
||||
# Generous timeout to allow the thread to exit
|
||||
self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10)
|
||||
assert not self._nixl_handshake_listener_t.is_alive()
|
||||
self._nixl_handshake_listener_t = None
|
||||
for handles in self._recving_transfers.values():
|
||||
for handle, _ in handles:
|
||||
self.nixl_wrapper.release_xfer_handle(handle)
|
||||
|
||||
@ -54,7 +54,13 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
get_attr_docs,
|
||||
)
|
||||
from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo
|
||||
from vllm.config.cache import (
|
||||
BlockSize,
|
||||
CacheDType,
|
||||
KVOffloadingBackend,
|
||||
MambaDType,
|
||||
PrefixCachingHashAlgo,
|
||||
)
|
||||
from vllm.config.device import Device
|
||||
from vllm.config.model import (
|
||||
ConvertOption,
|
||||
@ -553,6 +559,11 @@ class EngineArgs:
|
||||
|
||||
kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
|
||||
|
||||
kv_offloading_size: float | None = CacheConfig.kv_offloading_size
|
||||
kv_offloading_backend: KVOffloadingBackend | None = (
|
||||
CacheConfig.kv_offloading_backend
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# support `EngineArgs(compilation_config={...})`
|
||||
# without having to manually construct a
|
||||
@ -896,6 +907,12 @@ class EngineArgs:
|
||||
cache_group.add_argument(
|
||||
"--mamba-block-size", **cache_kwargs["mamba_block_size"]
|
||||
)
|
||||
cache_group.add_argument(
|
||||
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
|
||||
)
|
||||
cache_group.add_argument(
|
||||
"--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
|
||||
)
|
||||
|
||||
# Multimodal related configs
|
||||
multimodal_kwargs = get_kwargs(MultiModalConfig)
|
||||
@ -1246,8 +1263,6 @@ class EngineArgs:
|
||||
self,
|
||||
target_model_config: ModelConfig,
|
||||
target_parallel_config: ParallelConfig,
|
||||
enable_chunked_prefill: bool,
|
||||
disable_log_stats: bool,
|
||||
) -> SpeculativeConfig | None:
|
||||
"""Initializes and returns a SpeculativeConfig object based on
|
||||
`speculative_config`.
|
||||
@ -1267,8 +1282,6 @@ class EngineArgs:
|
||||
{
|
||||
"target_model_config": target_model_config,
|
||||
"target_parallel_config": target_parallel_config,
|
||||
"enable_chunked_prefill": enable_chunked_prefill,
|
||||
"disable_log_stats": disable_log_stats,
|
||||
}
|
||||
)
|
||||
return SpeculativeConfig(**self.speculative_config)
|
||||
@ -1391,6 +1404,8 @@ class EngineArgs:
|
||||
mamba_cache_dtype=self.mamba_cache_dtype,
|
||||
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
|
||||
mamba_block_size=self.mamba_block_size,
|
||||
kv_offloading_size=self.kv_offloading_size,
|
||||
kv_offloading_backend=self.kv_offloading_backend,
|
||||
)
|
||||
|
||||
ray_runtime_env = None
|
||||
@ -1561,8 +1576,6 @@ class EngineArgs:
|
||||
speculative_config = self.create_speculative_config(
|
||||
target_model_config=model_config,
|
||||
target_parallel_config=parallel_config,
|
||||
enable_chunked_prefill=self.enable_chunked_prefill,
|
||||
disable_log_stats=self.disable_log_stats,
|
||||
)
|
||||
|
||||
# make sure num_lookahead_slots is set appropriately depending on
|
||||
@ -1813,7 +1826,7 @@ class EngineArgs:
|
||||
incremental_prefill_supported = (
|
||||
pooling_type is not None
|
||||
and pooling_type.lower() == "last"
|
||||
and is_causal
|
||||
and bool(is_causal)
|
||||
)
|
||||
|
||||
action = "Enabling" if incremental_prefill_supported else "Disabling"
|
||||
|
||||
@ -241,6 +241,7 @@ async def build_async_engine_client_from_engine_args(
|
||||
)
|
||||
|
||||
# Don't keep the dummy data in memory
|
||||
assert async_llm is not None
|
||||
await async_llm.reset_mm_cache()
|
||||
|
||||
yield async_llm
|
||||
|
||||
@ -345,22 +345,7 @@ class OpenAIServing:
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
processed_inputs = processor.input_preprocessor._prompt_to_llm_inputs(
|
||||
prompt
|
||||
)
|
||||
|
||||
if processed_inputs["type"] == "embeds":
|
||||
raise NotImplementedError
|
||||
|
||||
# This is a workaround to fix multimodal beam search; this is a
|
||||
# bandaid fix for 2 small problems:
|
||||
# 1. Multi_modal_data on the processed_inputs currently resolves to
|
||||
# `None`.
|
||||
# 2. preprocessing above expands the multimodal placeholders. However,
|
||||
# this happens again in generation, so the double expansion causes
|
||||
# a mismatch.
|
||||
# TODO - would be ideal to handle this more gracefully.
|
||||
prompt_text: str | None
|
||||
prompt_token_ids: list[int]
|
||||
multi_modal_data: MultiModalDataDict | None
|
||||
@ -373,9 +358,16 @@ class OpenAIServing:
|
||||
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
|
||||
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = processed_inputs.get(
|
||||
"mm_processor_kwargs"
|
||||
) # type: ignore
|
||||
mm_processor_kwargs: dict[str, Any] | None = None
|
||||
|
||||
# This is a workaround to fix multimodal beam search; this is a
|
||||
# bandaid fix for 2 small problems:
|
||||
# 1. Multi_modal_data on the processed_inputs currently resolves to
|
||||
# `None`.
|
||||
# 2. preprocessing above expands the multimodal placeholders. However,
|
||||
# this happens again in generation, so the double expansion causes
|
||||
# a mismatch.
|
||||
# TODO - would be ideal to handle this more gracefully.
|
||||
|
||||
tokenized_length = len(prompt_token_ids)
|
||||
|
||||
|
||||
@ -2,11 +2,12 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
|
||||
@ -15,9 +15,7 @@ from vllm.distributed.parallel_state import (
|
||||
from vllm.lora.layers.base import BaseLayerWithLoRA
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
_get_config_dtype_str,
|
||||
mxfp4_w4a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
modular_marlin_fused_moe,
|
||||
@ -26,13 +24,16 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
modular_triton_fused_moe,
|
||||
try_get_optimal_moe_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
|
||||
|
||||
|
||||
class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(self, base_layer: FusedMoE) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
|
||||
assert not self.base_layer.use_ep, (
|
||||
"EP support for Fused MoE LoRA is not implemented yet."
|
||||
)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.device = base_layer.w2_weight.device
|
||||
@ -42,17 +43,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
moe_state_dict = {}
|
||||
top_k = self.base_layer.top_k
|
||||
|
||||
if self.base_layer.quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
|
||||
quant_config = self.base_layer.quant_config
|
||||
else:
|
||||
quant_config = mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=self.base_layer.w13_bias,
|
||||
w2_bias=self.base_layer.w2_bias,
|
||||
w1_scale=self.base_layer.w13_weight_scale,
|
||||
w2_scale=self.base_layer.w2_weight_scale,
|
||||
)
|
||||
self.base_layer.ensure_moe_quant_config_init()
|
||||
quant_config = self.base_layer.quant_method.moe_quant_config
|
||||
|
||||
m_fused_moe_fn = (
|
||||
modular_triton_fused_moe(
|
||||
@ -69,7 +61,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
|
||||
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
|
||||
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
|
||||
moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
|
||||
moe_state_dict["expert_map"] = kwargs["expert_map"]
|
||||
moe_state_dict["apply_router_weight_on_input"] = kwargs[
|
||||
"apply_router_weight_on_input"
|
||||
@ -86,7 +77,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
hidden_states = moe_state_dict["hidden_states"]
|
||||
topk_weights = moe_state_dict["topk_weights"]
|
||||
curr_topk_ids = moe_state_dict["topk_ids"]
|
||||
global_num_experts = moe_state_dict["global_num_experts"]
|
||||
|
||||
expert_map = moe_state_dict["expert_map"]
|
||||
|
||||
config_dtype = _get_config_dtype_str(
|
||||
@ -118,7 +109,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
curr_topk_ids,
|
||||
num_tokens,
|
||||
config["BLOCK_SIZE_M"],
|
||||
global_num_experts,
|
||||
self.base_layer.local_num_experts,
|
||||
max_loras,
|
||||
expert_map,
|
||||
)
|
||||
@ -236,14 +227,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
) -> None:
|
||||
"""Initializes lora matrices."""
|
||||
|
||||
assert not self.base_layer.use_ep, (
|
||||
"EP support for Fused MoE LoRA is not implemented yet."
|
||||
)
|
||||
|
||||
self.w1_lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
self.base_layer.local_num_experts,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.hidden_size,
|
||||
),
|
||||
@ -253,7 +240,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
self.w1_lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
self.base_layer.local_num_experts,
|
||||
self.base_layer.intermediate_size_per_partition,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
@ -264,7 +251,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
self.w2_lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
self.base_layer.local_num_experts,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.intermediate_size_per_partition,
|
||||
),
|
||||
@ -274,7 +261,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
self.w2_lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
self.base_layer.local_num_experts,
|
||||
self.base_layer.hidden_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
@ -285,7 +272,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
self.w3_lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
self.base_layer.local_num_experts,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.hidden_size,
|
||||
),
|
||||
@ -295,7 +282,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
self.w3_lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
self.base_layer.local_num_experts,
|
||||
self.base_layer.intermediate_size_per_partition,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
@ -308,7 +295,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
self.lora_a_stacked = []
|
||||
self.lora_b_stacked = []
|
||||
for lora_id in range(max_loras):
|
||||
for experts_id in range(self.base_layer.global_num_experts):
|
||||
for experts_id in range(self.base_layer.local_num_experts):
|
||||
# gate_proj,down_proj,up_proj
|
||||
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
|
||||
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
|
||||
|
||||
@ -88,14 +88,17 @@ def _fused_moe_lora_kernel(
|
||||
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
|
||||
|
||||
# calculate pid_m,pid_n
|
||||
pid_sk = pid % SPLIT_K
|
||||
pid_m_n = pid // SPLIT_K
|
||||
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
group_id = pid_m_n // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
|
||||
|
||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
|
||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||
@ -113,7 +116,7 @@ def _fused_moe_lora_kernel(
|
||||
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
token_ind = stride_tl * lora_idx + offs_token_id
|
||||
@ -131,7 +134,8 @@ def _fused_moe_lora_kernel(
|
||||
cur_b_ptr
|
||||
+ lora_idx * stride_bl
|
||||
+ expert_id * stride_be
|
||||
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
+ offs_k[:, None] * stride_bk
|
||||
+ offs_bn[None, :] * stride_bn
|
||||
)
|
||||
|
||||
# accumulator
|
||||
|
||||
@ -56,6 +56,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
ep_size: int = 1,
|
||||
tp_rank: int = 0,
|
||||
tp_size: int = 1,
|
||||
use_dp: bool = False,
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
|
||||
@ -67,6 +68,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.out_dtype = out_dtype
|
||||
self.use_dp = use_dp
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
@ -117,7 +119,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
"""
|
||||
workspace1 = (M, K)
|
||||
workspace2 = (0,)
|
||||
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
|
||||
# For TP, the quantization is fused with fused_moe call.
|
||||
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
|
||||
# The workspace is determined by `aq`, since it comes after any
|
||||
# potential communication op and is involved in the expert computation.
|
||||
return (workspace1, workspace2, output_shape)
|
||||
@ -214,6 +217,7 @@ def flashinfer_cutlass_moe_fp4(
|
||||
FlashInferExperts(
|
||||
out_dtype=hidden_states.dtype,
|
||||
quant_config=quant_config,
|
||||
use_dp=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -170,6 +170,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
||||
self._apply_router_weight_on_input(
|
||||
a1, topk_weights, topk_ids, apply_router_weight_on_input
|
||||
)
|
||||
if not self.use_dp:
|
||||
return a1, None, None, topk_ids, topk_weights
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
@ -179,14 +181,13 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
||||
quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=not self.use_dp,
|
||||
)
|
||||
if self.use_dp:
|
||||
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
||||
[topk_weights, topk_ids, a1q, a1q_scale],
|
||||
dim=0,
|
||||
sizes=get_local_sizes(),
|
||||
)
|
||||
if quant_config.quant_dtype == "nvfp4":
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
||||
[topk_weights, topk_ids, a1q, a1q_scale],
|
||||
dim=0,
|
||||
sizes=get_local_sizes(),
|
||||
)
|
||||
if quant_config.quant_dtype == "nvfp4":
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
|
||||
@ -672,8 +672,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
elif self.fused_experts is not None:
|
||||
if self.moe.has_bias:
|
||||
raise ValueError("FusedMoEModularKernel does not support bias.")
|
||||
result = self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
|
||||
@ -40,18 +40,36 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
def kda_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
q_proj_states: torch.Tensor,
|
||||
k_proj_states: torch.Tensor,
|
||||
v_proj_states: torch.Tensor,
|
||||
g1: torch.Tensor,
|
||||
g2: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
core_attn_out: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self._forward(hidden_states=hidden_states, output=output)
|
||||
self._forward(
|
||||
q_proj_states=q_proj_states,
|
||||
k_proj_states=k_proj_states,
|
||||
v_proj_states=v_proj_states,
|
||||
g1=g1,
|
||||
g2=g2,
|
||||
beta=beta,
|
||||
core_attn_out=core_attn_out,
|
||||
)
|
||||
|
||||
|
||||
def kda_attention_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
q_proj_states: torch.Tensor,
|
||||
k_proj_states: torch.Tensor,
|
||||
v_proj_states: torch.Tensor,
|
||||
g1: torch.Tensor,
|
||||
g2: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
core_attn_out: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
@ -60,7 +78,7 @@ def kda_attention_fake(
|
||||
direct_register_custom_op(
|
||||
op_name="kda_attention",
|
||||
op_func=kda_attention,
|
||||
mutates_args=["output"],
|
||||
mutates_args=["core_attn_out"],
|
||||
fake_impl=kda_attention_fake,
|
||||
)
|
||||
|
||||
@ -242,36 +260,54 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
positions: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> None:
|
||||
return torch.ops.vllm.kda_attention(
|
||||
hidden_states,
|
||||
output,
|
||||
num_tokens = hidden_states.size(0)
|
||||
q = self.q_proj(hidden_states)[0]
|
||||
k = self.k_proj(hidden_states)[0]
|
||||
v = self.v_proj(hidden_states)[0]
|
||||
|
||||
beta = self.b_proj(hidden_states)[0].float().sigmoid()
|
||||
g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
|
||||
g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
|
||||
beta = beta.unsqueeze(0)
|
||||
g1 = g1.unsqueeze(0)
|
||||
|
||||
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
|
||||
g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
|
||||
|
||||
core_attn_out = torch.zeros(
|
||||
(1, num_tokens, self.local_num_heads, self.head_dim),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
torch.ops.vllm.kda_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g1,
|
||||
g2,
|
||||
beta,
|
||||
core_attn_out,
|
||||
self.prefix,
|
||||
)
|
||||
core_attn_out = self.o_norm(core_attn_out, g2)
|
||||
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
|
||||
output[:] = self.o_proj(core_attn_out)[0]
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
q_proj_states: torch.Tensor,
|
||||
k_proj_states: torch.Tensor,
|
||||
v_proj_states: torch.Tensor,
|
||||
g1: torch.Tensor,
|
||||
g2: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
core_attn_out: torch.Tensor,
|
||||
) -> None:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
|
||||
if attn_metadata is None:
|
||||
# V1 profile run
|
||||
# Mimic the memory allocation in the real run
|
||||
q = torch.empty_like(hidden_states)
|
||||
k = torch.empty_like(hidden_states)
|
||||
v = torch.empty_like(hidden_states)
|
||||
g = hidden_states.new_empty(
|
||||
hidden_states.size(0),
|
||||
self.local_num_heads,
|
||||
self.head_dim,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
beta = torch.empty(
|
||||
hidden_states.size(0), self.local_num_heads, dtype=torch.float32
|
||||
)
|
||||
core_attn_out = torch.empty_like(hidden_states)
|
||||
# # V1 profile run
|
||||
return
|
||||
|
||||
assert isinstance(attn_metadata, dict)
|
||||
@ -288,10 +324,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
conv_state_k = conv_state_k.transpose(-1, -2)
|
||||
conv_state_v = conv_state_v.transpose(-1, -2)
|
||||
|
||||
q_proj_states = self.q_proj(hidden_states)[0]
|
||||
k_proj_states = self.k_proj(hidden_states)[0]
|
||||
v_proj_states = self.v_proj(hidden_states)[0]
|
||||
|
||||
q_conv_weights = self.q_conv1d.weight.view(
|
||||
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
|
||||
)
|
||||
@ -374,14 +406,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
|
||||
)
|
||||
|
||||
beta = self.b_proj(hidden_states)[0].float().sigmoid()
|
||||
|
||||
g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
|
||||
g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
|
||||
|
||||
beta = beta.unsqueeze(0)
|
||||
g = g.unsqueeze(0)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
zero_idx = non_spec_state_indices_tensor[~has_initial_state]
|
||||
recurrent_state[zero_idx] = 0
|
||||
@ -393,7 +417,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
g=g1,
|
||||
beta=beta,
|
||||
initial_state=initial_state,
|
||||
output_final_state=True,
|
||||
@ -410,17 +434,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
g=g1,
|
||||
beta=beta,
|
||||
initial_state=recurrent_state,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
cu_seqlens=non_spec_query_start_loc,
|
||||
ssm_state_indices=non_spec_state_indices_tensor,
|
||||
)
|
||||
|
||||
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
|
||||
g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
|
||||
core_attn_out = self.o_norm(core_attn_out_non_spec, g)
|
||||
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
|
||||
|
||||
output[:] = self.o_proj(core_attn_out)[0]
|
||||
assert core_attn_out_non_spec.shape == core_attn_out.shape
|
||||
core_attn_out[:] = core_attn_out_non_spec
|
||||
|
||||
@ -1769,29 +1769,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
elif (
|
||||
self.allow_flashinfer
|
||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
||||
):
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
||||
flashinfer_cutlass_moe_fp4,
|
||||
)
|
||||
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
return flashinfer_cutlass_moe_fp4(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
|
||||
# only (no EP).
|
||||
|
||||
@ -79,6 +79,7 @@ def select_nvfp4_gemm_impl(
|
||||
ep_size=moe.moe_parallel_config.ep_size,
|
||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||
tp_size=moe.moe_parallel_config.tp_size,
|
||||
use_dp=moe.moe_parallel_config.dp_size > 1,
|
||||
)
|
||||
|
||||
# native cutlass experts currently don't support DP; TP case won't call this
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
|
||||
|
||||
import itertools
|
||||
import math
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
@ -36,7 +37,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import BatchFeature
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
|
||||
from transformers.models.glm4v.image_processing_glm4v import (
|
||||
Glm4vImageProcessor,
|
||||
@ -89,6 +90,7 @@ from ..layers.activation import SiluAndMul
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
@ -1386,7 +1388,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
|
||||
dummy_inputs=Glm4vDummyInputsBuilder,
|
||||
)
|
||||
class Glm4vForConditionalGeneration(
|
||||
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
|
||||
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||
):
|
||||
merge_by_field_config = True
|
||||
|
||||
@ -1613,6 +1615,149 @@ class Glm4vForConditionalGeneration(
|
||||
multimodal_embeddings += tuple(video_embeddings)
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
input_tokens: list[int],
|
||||
hf_config: "PretrainedConfig",
|
||||
image_grid_thw: list[list[int]] | torch.Tensor | None,
|
||||
video_grid_thw: list[list[int]] | torch.Tensor | None,
|
||||
second_per_grid_ts: list[float] | None = None,
|
||||
context_len: int = 0,
|
||||
seq_len: int | None = None,
|
||||
audio_feature_lengths: torch.Tensor | None = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get mrope input positions and delta value for GLM4V."""
|
||||
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_start_token_id = hf_config.video_start_token_id
|
||||
video_end_token_id = hf_config.video_end_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
if not (image_grid_thw is None and video_grid_thw is None):
|
||||
if isinstance(image_grid_thw, torch.Tensor):
|
||||
image_grid_thw = image_grid_thw.tolist()
|
||||
|
||||
input_token_type: list[str] = []
|
||||
video_check_flg = False
|
||||
for token in input_tokens:
|
||||
if token == video_start_token_id:
|
||||
video_check_flg = True
|
||||
elif token == video_end_token_id:
|
||||
video_check_flg = False
|
||||
|
||||
if (token == image_token_id) and (video_check_flg is False):
|
||||
input_token_type.append("image")
|
||||
elif (token == image_token_id) and (video_check_flg is True):
|
||||
input_token_type.append("video")
|
||||
else:
|
||||
input_token_type.append("text")
|
||||
|
||||
input_type_group: list[tuple[str, int, int]] = []
|
||||
for key, group_iter in itertools.groupby(
|
||||
enumerate(input_token_type), lambda x: x[1]
|
||||
):
|
||||
group_list = list(group_iter)
|
||||
start_index = group_list[0][0]
|
||||
end_index = group_list[-1][0] + 1
|
||||
input_type_group.append((key, start_index, end_index))
|
||||
|
||||
video_frame_num = 1
|
||||
mm_data_idx = 0
|
||||
for modality_type, start_idx, end_idx in input_type_group:
|
||||
st_idx = (
|
||||
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
)
|
||||
if modality_type == "image":
|
||||
t, h, w = (
|
||||
image_grid_thw[mm_data_idx][0],
|
||||
image_grid_thw[mm_data_idx][1],
|
||||
image_grid_thw[mm_data_idx][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
w // spatial_merge_size,
|
||||
)
|
||||
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||
)
|
||||
mm_data_idx += 1
|
||||
|
||||
elif modality_type == "video":
|
||||
t, h, w = (
|
||||
video_frame_num,
|
||||
image_grid_thw[mm_data_idx][1],
|
||||
image_grid_thw[mm_data_idx][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
w // spatial_merge_size,
|
||||
)
|
||||
|
||||
for t_idx in range(llm_grid_t):
|
||||
t_index = (
|
||||
torch.tensor(t_idx)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(1, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(1, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||
)
|
||||
|
||||
mm_data_idx += 1
|
||||
video_frame_num += 1
|
||||
|
||||
else:
|
||||
text_len = end_idx - start_idx
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
video_frame_num = 1
|
||||
|
||||
else:
|
||||
text_len = len(input_tokens)
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -17,7 +17,9 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
|
||||
from transformers.utils import torch_int
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
from vllm.attention.layer import (
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@ -56,12 +58,14 @@ from vllm.multimodal.processing import (
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt(
|
||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
if current_platform.is_cuda():
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
elif current_platform.is_rocm():
|
||||
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
|
||||
|
||||
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
|
||||
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
|
||||
@ -398,18 +405,28 @@ class KeyeSiglipAttention(nn.Module):
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
|
||||
self.use_upstream_fa = False
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
||||
torch.get_default_dtype()
|
||||
):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
self.use_upstream_fa = True
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
use_upstream_fa=False,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
|
||||
if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Keye-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -457,15 +474,10 @@ class KeyeSiglipAttention(nn.Module):
|
||||
self.head_dim,
|
||||
)
|
||||
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
if self.use_upstream_fa:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
else:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
if self.is_flash_attn_backend:
|
||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||
|
||||
output = flash_attn_varlen_func(
|
||||
output = self.flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@ -1542,7 +1554,7 @@ class BaseKeyeModule(nn.Module):
|
||||
dummy_inputs=KeyeDummyInputsBuilder,
|
||||
)
|
||||
class KeyeForConditionalGeneration(
|
||||
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
|
||||
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||
):
|
||||
def _build_projector(
|
||||
self,
|
||||
@ -1611,3 +1623,142 @@ class KeyeForConditionalGeneration(
|
||||
return tuple(
|
||||
self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
|
||||
)
|
||||
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: list[list[int]] | torch.Tensor,
|
||||
video_grid_thw: list[list[int]] | torch.Tensor,
|
||||
context_len: int = 0,
|
||||
seq_len: int | None = None,
|
||||
second_per_grid_ts: list[float] | None = None,
|
||||
audio_feature_lengths: torch.Tensor | None = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
|
||||
video_grid_thw = video_grid_thw[0]
|
||||
"""Get mrope input positions and delta value (Keye series)."""
|
||||
|
||||
def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
|
||||
"""
|
||||
Split grid_thw along the t dimension.
|
||||
|
||||
Args:
|
||||
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
|
||||
|
||||
Returns:
|
||||
List of [1, h, w] rows, repeated t times for each original row.
|
||||
"""
|
||||
|
||||
if isinstance(grid_thw, list):
|
||||
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
|
||||
|
||||
if grid_thw.numel() == 0:
|
||||
return []
|
||||
|
||||
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
|
||||
ones = torch.ones_like(hw[:, :1]) # [N,1]
|
||||
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
|
||||
return out.tolist()
|
||||
|
||||
video_grid_thw = split_thw(video_grid_thw)
|
||||
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_token_id = hf_config.video_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
|
||||
image_nums = len(image_grid_thw)
|
||||
frame_nums = len(video_grid_thw)
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
st = 0
|
||||
remain_images, remain_frames = image_nums, frame_nums
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
for _ in range(image_nums + frame_nums):
|
||||
if remain_images > 0:
|
||||
try:
|
||||
ed_image = input_tokens.index(image_token_id, st)
|
||||
except ValueError:
|
||||
ed_image = len(input_tokens) + 1
|
||||
else:
|
||||
ed_image = len(input_tokens) + 1
|
||||
if remain_frames > 0:
|
||||
try:
|
||||
ed_video = input_tokens.index(video_token_id, st)
|
||||
except ValueError:
|
||||
ed_video = len(input_tokens) + 1
|
||||
else:
|
||||
ed_video = len(input_tokens) + 1
|
||||
|
||||
if ed_image < ed_video:
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
ed = ed_image
|
||||
else:
|
||||
t, h, w = (
|
||||
video_grid_thw[video_index][0],
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
video_index += 1
|
||||
remain_frames -= 1
|
||||
ed = ed_video
|
||||
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
w // spatial_merge_size,
|
||||
)
|
||||
text_len = ed - st
|
||||
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
t_index = (
|
||||
(
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
)
|
||||
.long()
|
||||
.flatten()
|
||||
)
|
||||
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
||||
)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
@ -22,7 +22,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
@ -61,7 +60,7 @@ class KimiMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: QKVParallelLinear | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
@ -155,6 +154,7 @@ class KimiMoE(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@ -340,7 +340,7 @@ class KimiDecoderLayer(nn.Module):
|
||||
self.block_sparse_moe = KimiMoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
prefix=f"{prefix}.block_sparse_moe",
|
||||
)
|
||||
self.mlp = self.block_sparse_moe
|
||||
else:
|
||||
|
||||
@ -49,7 +49,7 @@ from functools import cached_property
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.activations import ACT2FN, PytorchGELUTanh
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
|
||||
@ -651,7 +651,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
|
||||
"num_heads": config.num_attention_heads,
|
||||
"hidden_dim": config.hidden_size,
|
||||
"mlp_dim": config.intermediate_size,
|
||||
"activation": PytorchGELUTanh(),
|
||||
"activation": ACT2FN["gelu_pytorch_tanh"],
|
||||
"attn_bias": True,
|
||||
"attn_implementation": config._attn_implementation,
|
||||
},
|
||||
|
||||
@ -364,6 +364,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
|
||||
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
|
||||
self.use_upstream_fa = True
|
||||
if current_platform.is_xpu():
|
||||
self.use_upstream_fa = False
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
@ -856,10 +858,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
||||
seqlens = torch.zeros(1, device=cu_seqlens.device)
|
||||
if (
|
||||
self.attn_backend == _Backend.FLASH_ATTN
|
||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
||||
):
|
||||
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
@ -34,7 +34,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from transformers import AutoConfig, BatchFeature, PretrainedConfig
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
|
||||
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
|
||||
Qwen2VLConfig,
|
||||
@ -789,10 +789,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
self, cu_seqlens: torch.Tensor
|
||||
) -> tuple[int | None, list[int] | None]:
|
||||
max_seqlen, seqlens = None, None
|
||||
if (
|
||||
self.attn_backend == _Backend.FLASH_ATTN
|
||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
||||
):
|
||||
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
@ -1654,9 +1651,7 @@ class Tarsier2Processor(Qwen2VLProcessor):
|
||||
class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
|
||||
def get_hf_config(self) -> Qwen2VLConfig:
|
||||
model_path = self.ctx.model_config.model
|
||||
original_config = AutoConfig.from_pretrained(model_path)
|
||||
config_dict = original_config.to_dict()
|
||||
correct_config = Qwen2VLConfig.from_dict(config_dict)
|
||||
correct_config = Qwen2VLConfig.from_pretrained(model_path)
|
||||
|
||||
return correct_config
|
||||
|
||||
|
||||
@ -115,6 +115,12 @@ class XPUPlatform(Platform):
|
||||
device_props = torch.xpu.get_device_properties(device_id)
|
||||
return device_props.total_memory
|
||||
|
||||
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
    """Return the attention backend reported for ViT attention on this platform.

    ``head_size`` and ``dtype`` are accepted but not consulted; the result
    is always ``_Backend.FLASH_ATTN``.
    """
    # Imported at call time rather than at module import time.
    from vllm.attention.backends.registry import _Backend

    return _Backend.FLASH_ATTN
|
||||
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
return torch.no_grad()
|
||||
|
||||
@ -896,6 +896,8 @@ def get_kernel_options(
|
||||
return kernel_options
|
||||
else:
|
||||
preferred_block = 32 if query.dtype == torch.float32 else 64
|
||||
block_lower_bound = 16
|
||||
|
||||
block_m_candidate = ensure_divisible(preferred_block, block_m)
|
||||
block_n_candidate = ensure_divisible(preferred_block, block_n)
|
||||
|
||||
@ -910,6 +912,9 @@ def get_kernel_options(
|
||||
max(1, block_n_candidate // 2), block_n
|
||||
)
|
||||
|
||||
block_m_candidate = max(block_m_candidate, block_lower_bound)
|
||||
block_n_candidate = max(block_n_candidate, block_lower_bound)
|
||||
|
||||
kernel_options["BLOCK_M"] = block_m_candidate
|
||||
kernel_options["BLOCK_N"] = block_n_candidate
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import ClassVar
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
@ -40,6 +40,10 @@ class FlashInferMLABackend(MLACommonBackend):
|
||||
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
|
||||
return FlashInferMLAMetadataBuilder
|
||||
|
||||
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
    """Return the kernel block sizes this backend reports as supported.

    Only the two fixed sizes 32 and 64 are listed; no ``MultipleOf``
    entry is returned.
    """
    return [32, 64]
|
||||
|
||||
|
||||
g_fi_workspace = torch.zeros(
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
|
||||
|
||||
@ -15,8 +15,12 @@ class AsyncScheduler(Scheduler):
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> None:
|
||||
super()._update_after_schedule(scheduler_output)
|
||||
pending_structured_output_tokens = False
|
||||
for req_id in scheduler_output.num_scheduled_tokens:
|
||||
request = self.requests[req_id]
|
||||
pending_structured_output_tokens |= (
|
||||
request.use_structured_output and request.num_output_placeholders > 0
|
||||
)
|
||||
if (
|
||||
request.num_computed_tokens
|
||||
== request.num_tokens + request.num_output_placeholders
|
||||
@ -25,6 +29,10 @@ class AsyncScheduler(Scheduler):
|
||||
# TODO(woosuk): Support speculative decoding.
|
||||
request.num_output_placeholders += 1
|
||||
|
||||
scheduler_output.pending_structured_output_tokens = (
|
||||
pending_structured_output_tokens
|
||||
)
|
||||
|
||||
def _update_request_with_output(
|
||||
self,
|
||||
request: Request,
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.engine import EngineCoreOutputs
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
@ -40,6 +40,12 @@ class SchedulerInterface(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
def get_grammar_bitmask(
    self, scheduler_output: "SchedulerOutput"
) -> "GrammarOutput | None":
    """Produce the structured-output grammar bitmask for a scheduled batch.

    Args:
        scheduler_output: The scheduling result the bitmask is computed for.

    Returns:
        A ``GrammarOutput``, or ``None`` when there is nothing to return
        (presumably when no scheduled request uses structured output;
        confirm against the concrete scheduler).

    Raises:
        NotImplementedError: Always, in this abstract base; subclasses
            must override.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update_from_output(
|
||||
self,
|
||||
|
||||
@ -181,12 +181,17 @@ class SchedulerOutput:
|
||||
# freed from the encoder cache.
|
||||
free_encoder_mm_hashes: list[str]
|
||||
|
||||
# ids of structured outputs requests included in the bitmask, in the
|
||||
# same order as the corresponding stacked rows of the bitmask.
|
||||
# There may be more than one row per request in the case of speculative decoding.
|
||||
structured_output_request_ids: list[str]
|
||||
# the bitmask for the whole batch
|
||||
grammar_bitmask: "npt.NDArray[np.int32] | None"
|
||||
# Whether the scheduled requests have all the output tokens they
|
||||
# need to perform grammar bitmask computation.
|
||||
pending_structured_output_tokens: bool = False
|
||||
|
||||
# KV Cache Connector metadata.
|
||||
kv_connector_metadata: KVConnectorMetadata | None = None
|
||||
|
||||
|
||||
@dataclass
class GrammarOutput:
    """Grammar bitmask together with the request ids its rows belong to."""

    # ids of structured output requests.
    structured_output_request_ids: list[str]
    # Bitmask ordered as structured_output_request_ids.
    grammar_bitmask: "npt.NDArray[np.int32]"
|
||||
|
||||
@ -5,7 +5,7 @@ import itertools
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
||||
@ -24,7 +24,12 @@ from vllm.v1.core.encoder_cache_manager import (
|
||||
)
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
|
||||
from vllm.v1.core.sched.output import (
|
||||
CachedRequestData,
|
||||
GrammarOutput,
|
||||
NewRequestData,
|
||||
SchedulerOutput,
|
||||
)
|
||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
||||
@ -35,10 +40,6 @@ from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -619,9 +620,6 @@ class Scheduler(SchedulerInterface):
|
||||
scheduled_spec_decode_tokens,
|
||||
req_to_new_blocks,
|
||||
)
|
||||
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
|
||||
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
|
||||
)
|
||||
|
||||
# Record the request ids that were scheduled in this step.
|
||||
self.prev_step_scheduled_req_ids.clear()
|
||||
@ -641,8 +639,6 @@ class Scheduler(SchedulerInterface):
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids,
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||
structured_output_request_ids=structured_output_request_ids,
|
||||
grammar_bitmask=grammar_bitmask,
|
||||
)
|
||||
|
||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||
@ -872,9 +868,8 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
def get_grammar_bitmask(
|
||||
self,
|
||||
scheduled_request_ids: Iterable[str],
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]],
|
||||
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> GrammarOutput | None:
|
||||
# Collect list of scheduled request ids that use structured output.
|
||||
# The corresponding rows of the bitmask will be in this order.
|
||||
# PERF: in case of chunked prefill,
|
||||
@ -883,18 +878,18 @@ class Scheduler(SchedulerInterface):
|
||||
# cycle to fill in the bitmask, which could be a big no-op.
|
||||
structured_output_request_ids = [
|
||||
req_id
|
||||
for req_id in scheduled_request_ids
|
||||
for req_id in scheduler_output.num_scheduled_tokens
|
||||
if (req := self.requests.get(req_id)) and req.use_structured_output
|
||||
]
|
||||
if not structured_output_request_ids:
|
||||
return structured_output_request_ids, None
|
||||
return None
|
||||
|
||||
bitmask = self.structured_output_manager.grammar_bitmask(
|
||||
self.requests,
|
||||
structured_output_request_ids,
|
||||
scheduled_spec_decode_tokens,
|
||||
scheduler_output.scheduled_spec_decode_tokens,
|
||||
)
|
||||
return structured_output_request_ids, bitmask
|
||||
return GrammarOutput(structured_output_request_ids, bitmask)
|
||||
|
||||
def update_from_output(
|
||||
self,
|
||||
|
||||
@ -12,7 +12,7 @@ from concurrent.futures import Future
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
@ -163,6 +163,27 @@ class EngineCore:
|
||||
vllm_config, mm_registry
|
||||
)
|
||||
|
||||
# If a KV connector is initialized for scheduler, we want to collect
|
||||
# handshake metadata from all workers so the connector in the scheduler
|
||||
# will have the full context
|
||||
kv_connector = self.scheduler.get_kv_connector()
|
||||
if kv_connector is not None:
|
||||
# Collect and store KV connector xfer metadata from workers
|
||||
# (after KV cache registration)
|
||||
xfer_handshake_metadata = (
|
||||
self.model_executor.get_kv_connector_handshake_metadata()
|
||||
)
|
||||
|
||||
if xfer_handshake_metadata:
|
||||
# xfer_handshake_metadata is list of dicts from workers
|
||||
# Each dict already has structure {tp_rank: metadata}
|
||||
# Merge all worker dicts into a single dict
|
||||
content: dict[int, Any] = {}
|
||||
for worker_dict in xfer_handshake_metadata:
|
||||
if worker_dict is not None:
|
||||
content.update(worker_dict)
|
||||
kv_connector.set_xfer_handshake_metadata(content)
|
||||
|
||||
# Setup batch queue for pipeline parallelism.
|
||||
# Batch queue for scheduled batches. This enables us to asynchronously
|
||||
# schedule and execute batches, and is required by pipeline parallelism
|
||||
@ -178,7 +199,7 @@ class EngineCore:
|
||||
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
|
||||
if (
|
||||
self.vllm_config.cache_config.enable_prefix_caching
|
||||
or self.scheduler.get_kv_connector() is not None
|
||||
or kv_connector is not None
|
||||
):
|
||||
caching_hash_fn = get_hash_fn_by_name(
|
||||
vllm_config.cache_config.prefix_caching_hash_algo
|
||||
@ -313,9 +334,12 @@ class EngineCore:
|
||||
if not self.scheduler.has_requests():
|
||||
return {}, False
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
|
||||
future = self.model_executor.execute_model(scheduler_output, non_block=True)
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = self.model_executor.execute_model(scheduler_output)
|
||||
model_output = future.result()
|
||||
if model_output is None:
|
||||
model_output = self.model_executor.sample_tokens(grammar_output)
|
||||
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
@ -355,20 +379,47 @@ class EngineCore:
|
||||
assert len(batch_queue) < self.batch_queue_size
|
||||
|
||||
model_executed = False
|
||||
deferred_scheduler_output = None
|
||||
if self.scheduler.has_requests():
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
future = self.model_executor.execute_model(scheduler_output, non_block=True)
|
||||
batch_queue.appendleft((future, scheduler_output))
|
||||
|
||||
exec_future = self.model_executor.execute_model(
|
||||
scheduler_output, non_block=True
|
||||
)
|
||||
model_executed = scheduler_output.total_num_scheduled_tokens > 0
|
||||
if (
|
||||
model_executed
|
||||
and len(batch_queue) < self.batch_queue_size
|
||||
and not batch_queue[-1][0].done()
|
||||
):
|
||||
# Don't block on next worker response unless the queue is full
|
||||
# or there are no more requests to schedule.
|
||||
return None, True
|
||||
|
||||
if scheduler_output.pending_structured_output_tokens:
|
||||
# We need to defer sampling until we have processed the model output
|
||||
# from the prior step.
|
||||
deferred_scheduler_output = scheduler_output
|
||||
# Block-wait for execute to return (continues running async on the GPU).
|
||||
with self.log_error_detail(scheduler_output):
|
||||
exec_result = exec_future.result()
|
||||
assert exec_result is None
|
||||
else:
|
||||
# We aren't waiting for any tokens, get any grammar output immediately.
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
|
||||
# Block-wait for execute to return (continues running async on the GPU).
|
||||
with self.log_error_detail(scheduler_output):
|
||||
exec_result = exec_future.result()
|
||||
|
||||
if exec_result is None:
|
||||
# Call sample tokens.
|
||||
future = self.model_executor.sample_tokens(
|
||||
grammar_output, non_block=True
|
||||
)
|
||||
else:
|
||||
# No sampling required (e.g. all requests finished).
|
||||
future = cast(Future[ModelRunnerOutput], exec_future)
|
||||
# Add this step's future to the queue.
|
||||
batch_queue.appendleft((future, scheduler_output))
|
||||
if (
|
||||
model_executed
|
||||
and len(batch_queue) < self.batch_queue_size
|
||||
and not batch_queue[-1][0].done()
|
||||
):
|
||||
# Don't block on next worker response unless the queue is full
|
||||
# or there are no more requests to schedule.
|
||||
return None, True
|
||||
|
||||
elif not batch_queue:
|
||||
# Queue is empty. We should not reach here since this method should
|
||||
@ -384,6 +435,19 @@ class EngineCore:
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
|
||||
# NOTE(nick): We can either handle the deferred tasks here or save
|
||||
# in a field and do it immediately once step_with_batch_queue is
|
||||
# re-called. The latter slightly favors TTFT over TPOT/throughput.
|
||||
if deferred_scheduler_output:
|
||||
# We now have the tokens needed to compute the bitmask for the
|
||||
# deferred request. Get the bitmask and call sample tokens.
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(
|
||||
deferred_scheduler_output
|
||||
)
|
||||
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
|
||||
batch_queue.appendleft((future, deferred_scheduler_output))
|
||||
|
||||
return engine_core_outputs, model_executed
|
||||
|
||||
def shutdown(self):
|
||||
|
||||
@ -9,11 +9,14 @@ from typing import TYPE_CHECKING, Literal, TypeVar, overload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorHandshakeMetadata,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
@ -177,30 +180,51 @@ class Executor(ABC):
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_kv_connector_handshake_metadata(
|
||||
self,
|
||||
) -> list[dict[int, KVConnectorHandshakeMetadata]]:
|
||||
return self.collective_rpc("get_kv_connector_handshake_metadata")
|
||||
|
||||
@overload
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
non_block: Literal[False] = False,
|
||||
) -> ModelRunnerOutput:
|
||||
self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False
|
||||
) -> ModelRunnerOutput | None:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
non_block: Literal[True] = True,
|
||||
) -> Future[ModelRunnerOutput]:
|
||||
self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True
|
||||
) -> Future[ModelRunnerOutput | None]:
|
||||
pass
|
||||
|
||||
def execute_model(
|
||||
self, scheduler_output: SchedulerOutput, non_block: bool = False
|
||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||
output = self.collective_rpc( # type: ignore[call-overload]
|
||||
"execute_model", args=(scheduler_output,), non_block=non_block
|
||||
)
|
||||
return output[0]
|
||||
|
||||
@overload
|
||||
def sample_tokens(
|
||||
self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False
|
||||
) -> ModelRunnerOutput:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def sample_tokens(
|
||||
self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True
|
||||
) -> Future[ModelRunnerOutput]:
|
||||
pass
|
||||
|
||||
def sample_tokens(
|
||||
self, grammar_output: GrammarOutput | None, non_block: bool = False
|
||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||
output = self.collective_rpc( # type: ignore[call-overload]
|
||||
"sample_tokens", args=(grammar_output,), non_block=non_block
|
||||
)
|
||||
return output[0]
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self.collective_rpc("execute_dummy_batch")
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ from vllm.utils.system_utils import (
|
||||
get_mp_context,
|
||||
set_process_title,
|
||||
)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.executor.abstract import Executor, FailureCallback
|
||||
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
@ -132,15 +132,12 @@ class MultiprocExecutor(Executor):
|
||||
uw.death_writer.close()
|
||||
self._ensure_worker_termination([uw.proc for uw in unready_workers])
|
||||
|
||||
# For pipeline parallel, we use a thread pool for asynchronous
|
||||
# execute_model.
|
||||
if self.max_concurrent_batches > 1:
|
||||
# Note: must use only 1 IO thread to keep dequeue sequence
|
||||
# from the response queue
|
||||
# _async_aggregate_workers_output also assumes a single IO thread
|
||||
self.io_thread_pool = ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="mp_exec_io"
|
||||
)
|
||||
# Note: must use only 1 IO thread to keep dequeue sequence
|
||||
# from the response queue.
|
||||
# _async_aggregate_workers_output also assumes a single IO thread.
|
||||
self.io_thread_pool = ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="mp_exec_io"
|
||||
)
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||
@ -180,15 +177,27 @@ class MultiprocExecutor(Executor):
|
||||
self.failure_callback = callback
|
||||
|
||||
def execute_model( # type: ignore[override]
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
non_block: bool = False,
|
||||
self, scheduler_output: SchedulerOutput, non_block: bool = False
|
||||
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||
return self._execute_with_aggregation(
|
||||
"execute_model", scheduler_output, non_block=non_block
|
||||
)
|
||||
|
||||
def sample_tokens( # type: ignore[override]
|
||||
self, grammar_output: GrammarOutput | None, non_block: bool = False
|
||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||
return self._execute_with_aggregation( # type: ignore[return-value]
|
||||
"sample_tokens", grammar_output, non_block=non_block
|
||||
)
|
||||
|
||||
def _execute_with_aggregation(
|
||||
self, method: str, *args, non_block: bool = False
|
||||
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||
if not self.has_connector:
|
||||
# get output only from a single worker (output_rank)
|
||||
(output,) = self.collective_rpc(
|
||||
"execute_model",
|
||||
args=(scheduler_output,),
|
||||
method,
|
||||
args=args,
|
||||
unique_reply_rank=self.output_rank,
|
||||
non_block=non_block,
|
||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||
@ -197,8 +206,8 @@ class MultiprocExecutor(Executor):
|
||||
|
||||
# get output from all workers
|
||||
outputs = self.collective_rpc(
|
||||
"execute_model",
|
||||
args=(scheduler_output,),
|
||||
method,
|
||||
args=args,
|
||||
non_block=non_block,
|
||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
@ -19,7 +19,7 @@ from vllm.utils.network_utils import (
|
||||
get_ip,
|
||||
get_open_port,
|
||||
)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.executor.ray_utils import (
|
||||
@ -41,6 +41,9 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
|
||||
COMPLETED_NONE_FUTURE.set_result(None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RayWorkerMetaData:
|
||||
@ -96,6 +99,8 @@ class RayDistributedExecutor(Executor):
|
||||
# KV connector setup
|
||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||
|
||||
self.scheduler_output: SchedulerOutput | None = None
|
||||
|
||||
@property
|
||||
def max_concurrent_batches(self) -> int:
|
||||
"""Ray distributed executor supports pipeline parallelism,
|
||||
@ -381,22 +386,46 @@ class RayDistributedExecutor(Executor):
|
||||
self.shutdown()
|
||||
|
||||
def execute_model( # type: ignore[override]
|
||||
self, scheduler_output: SchedulerOutput, non_block: bool = False
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
non_block: bool = False,
|
||||
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||
if self.scheduler_output is not None:
|
||||
raise RuntimeError(
|
||||
"State error: sample_tokens() must be called "
|
||||
"after execute_model() returns None."
|
||||
)
|
||||
self.scheduler_output = scheduler_output
|
||||
return COMPLETED_NONE_FUTURE if non_block else None
|
||||
|
||||
def sample_tokens( # type: ignore[override]
|
||||
self,
|
||||
grammar_output: "GrammarOutput | None",
|
||||
non_block: bool = False,
|
||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||
"""Execute the model on the Ray workers.
|
||||
|
||||
The scheduler output to use should have been provided in
|
||||
a prior call to execute_model().
|
||||
|
||||
Args:
|
||||
scheduler_output: The scheduler output to execute.
|
||||
grammar_output: The structured outputs grammar bitmask, if applicable.
|
||||
non_block: If True, the method will return a Future.
|
||||
|
||||
Returns:
|
||||
The model runner output.
|
||||
"""
|
||||
scheduler_output = self.scheduler_output
|
||||
if scheduler_output is None:
|
||||
return None # noqa
|
||||
|
||||
self.scheduler_output = None
|
||||
|
||||
# Build the compiled DAG for the first time.
|
||||
if self.forward_dag is None: # type: ignore
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||
|
||||
refs = self.forward_dag.execute(scheduler_output) # type: ignore
|
||||
refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
|
||||
|
||||
if not self.has_connector:
|
||||
# Get output only from a single worker (output_rank)
|
||||
|
||||
@ -19,7 +19,7 @@ from vllm.v1.outputs import AsyncModelRunnerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -82,36 +82,41 @@ try:
|
||||
|
||||
def execute_model_ray(
|
||||
self,
|
||||
scheduler_output: Union[
|
||||
"SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
|
||||
],
|
||||
execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
|
||||
| tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
|
||||
) -> Union[
|
||||
"ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
|
||||
"ModelRunnerOutput",
|
||||
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
|
||||
]:
|
||||
# This method is used by Ray Compiled Graph to execute the model,
|
||||
# and it needs a special logic of self.setup_device_if_necessary()
|
||||
self.setup_device_if_necessary()
|
||||
assert self.worker is not None, "Worker is not initialized"
|
||||
if isinstance(scheduler_output, tuple):
|
||||
scheduler_output, intermediate_tensors = scheduler_output
|
||||
if len(execute_model_input) == 3:
|
||||
scheduler_output, grammar_output, intermediate_tensors = (
|
||||
execute_model_input
|
||||
)
|
||||
else:
|
||||
scheduler_output, intermediate_tensors = scheduler_output, None
|
||||
scheduler_output, grammar_output = execute_model_input
|
||||
intermediate_tensors = None
|
||||
assert self.worker.model_runner is not None
|
||||
output = self.worker.model_runner.execute_model(
|
||||
scheduler_output, intermediate_tensors
|
||||
)
|
||||
if isinstance(output, IntermediateTensors):
|
||||
output = scheduler_output, output
|
||||
output = scheduler_output, grammar_output, output
|
||||
elif not get_pp_group().is_last_rank:
|
||||
# Case where there are no scheduled requests
|
||||
# but may still be finished requests.
|
||||
assert not output or not output.req_ids
|
||||
output = scheduler_output, None
|
||||
# Ensure outputs crossing Ray compiled DAG are serializable.
|
||||
# AsyncModelRunnerOutput holds CUDA events and cannot be
|
||||
# pickled.
|
||||
if isinstance(output, AsyncModelRunnerOutput):
|
||||
output = output.get_output()
|
||||
output = scheduler_output, grammar_output, None
|
||||
elif output is None:
|
||||
output = self.worker.model_runner.sample_tokens(grammar_output)
|
||||
# Ensure outputs crossing Ray compiled DAG are serializable.
|
||||
# AsyncModelRunnerOutput holds CUDA events and cannot be
|
||||
# pickled.
|
||||
if isinstance(output, AsyncModelRunnerOutput):
|
||||
output = output.get_output()
|
||||
return output
|
||||
|
||||
def override_env_vars(self, vars: dict[str, str]):
|
||||
|
||||
@ -16,6 +16,7 @@ from diskcache import Cache
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import outlines_core as oc
|
||||
@ -24,7 +25,6 @@ if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
@ -47,6 +47,7 @@ CACHE = None
|
||||
|
||||
def apply_grammar_bitmask(
|
||||
scheduler_output: SchedulerOutput,
|
||||
grammar_output: GrammarOutput,
|
||||
input_batch: InputBatch,
|
||||
logits: torch.Tensor,
|
||||
) -> None:
|
||||
@ -58,9 +59,9 @@ def apply_grammar_bitmask(
|
||||
input_batch (InputBatch): The input of model runner.
|
||||
logits (torch.Tensor): The output logits of model forward.
|
||||
"""
|
||||
grammar_bitmask = scheduler_output.grammar_bitmask
|
||||
if grammar_bitmask is None:
|
||||
return
|
||||
# Serialization of np.ndarray is much more efficient than a tensor,
|
||||
# so we receive it in that format.
|
||||
grammar_bitmask = grammar_output.grammar_bitmask
|
||||
|
||||
# We receive the structured output bitmask from the scheduler,
|
||||
# compacted to contain bitmasks only for structured output requests.
|
||||
@ -79,7 +80,7 @@ def apply_grammar_bitmask(
|
||||
cumulative_offset += len(
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
|
||||
)
|
||||
if req_id in scheduler_output.structured_output_request_ids:
|
||||
if req_id in grammar_output.structured_output_request_ids:
|
||||
struct_out_req_batch_indices[req_id] = logit_index
|
||||
|
||||
out_indices = []
|
||||
@ -91,7 +92,7 @@ def apply_grammar_bitmask(
|
||||
dtype=grammar_bitmask.dtype,
|
||||
)
|
||||
cumulative_index = 0
|
||||
for req_id in scheduler_output.structured_output_request_ids:
|
||||
for req_id in grammar_output.structured_output_request_ids:
|
||||
num_spec_tokens = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
|
||||
)
|
||||
@ -101,22 +102,28 @@ def apply_grammar_bitmask(
|
||||
sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
|
||||
out_indices.append(logit_index + i)
|
||||
cumulative_index += 1 + num_spec_tokens
|
||||
grammar_bitmask = sorted_bitmask
|
||||
|
||||
# Copy async to device as tensor.
|
||||
grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
|
||||
logits.device, non_blocking=True
|
||||
)
|
||||
|
||||
# If the length of out indices and the logits have the same shape
|
||||
# we don't need to pass indices to the kernel,
|
||||
# since the bitmask is already aligned with the logits.
|
||||
skip_out_indices = len(out_indices) == logits.shape[0]
|
||||
|
||||
# Serialization of np.ndarray is much more efficient than a tensor,
|
||||
# so we receive it in that format.
|
||||
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
|
||||
index_tensor = None
|
||||
if not skip_out_indices:
|
||||
# xgrammar expects a python list of indices but it will actually work with
|
||||
# a tensor. If we copy the tensor ourselves here we can do it in a non_blocking
|
||||
# manner and there should be no cpu sync within xgrammar.
|
||||
index_tensor = torch.tensor(
|
||||
out_indices, dtype=torch.int32, device="cpu", pin_memory=True
|
||||
)
|
||||
index_tensor = index_tensor.to(logits.device, non_blocking=True)
|
||||
|
||||
xgr.apply_token_bitmask_inplace(
|
||||
logits,
|
||||
grammar_bitmask.to(logits.device, non_blocking=True),
|
||||
indices=out_indices if not skip_out_indices else None,
|
||||
)
|
||||
xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
|
||||
|
||||
|
||||
class OutlinesVocabulary:
|
||||
|
||||
@ -204,7 +204,7 @@ class InputBatch:
|
||||
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32)
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
|
||||
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
||||
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
||||
|
||||
|
||||
@ -109,6 +109,7 @@ from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
AsyncModelRunnerOutput,
|
||||
DraftTokenIds,
|
||||
KVConnectorOutput,
|
||||
LogprobsLists,
|
||||
LogprobsTensors,
|
||||
ModelRunnerOutput,
|
||||
@ -150,7 +151,7 @@ from .utils import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -218,6 +219,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
|
||||
return output
|
||||
|
||||
|
||||
class ExecuteModelState(NamedTuple):
|
||||
"""Ephemeral cached state transferred between execute_model() and
|
||||
sample_tokens(), after execute_model() returns None."""
|
||||
|
||||
scheduler_output: "SchedulerOutput"
|
||||
logits: torch.Tensor
|
||||
spec_decode_metadata: SpecDecodeMetadata | None
|
||||
spec_decode_common_attn_metadata: CommonAttentionMetadata | None
|
||||
hidden_states: torch.Tensor
|
||||
sample_hidden_states: torch.Tensor
|
||||
aux_hidden_states: list[torch.Tensor] | None
|
||||
kv_connector_output: KVConnectorOutput | None
|
||||
|
||||
|
||||
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
def __init__(
|
||||
self,
|
||||
@ -509,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
# Ephemeral state transferred between execute_model() and sample_tokens().
|
||||
self.execute_model_state: ExecuteModelState | None = None
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
if self.mm_budget:
|
||||
self.mm_budget.reset_cache()
|
||||
@ -2113,7 +2131,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_input_tokens: int, # Padded
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
) -> tuple[
|
||||
int,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor,
|
||||
@ -2207,7 +2224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
model_kwargs.update(encoder_inputs)
|
||||
|
||||
return (
|
||||
num_scheduled_tokens,
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
positions,
|
||||
@ -2425,13 +2441,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
|
||||
) -> ModelRunnerOutput | IntermediateTensors | None:
|
||||
if self.execute_model_state is not None:
|
||||
raise RuntimeError(
|
||||
"State error: sample_tokens() must be called "
|
||||
"after execute_model() returns None."
|
||||
)
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
with record_function_or_nullcontext("Preprocess"):
|
||||
with self.synchronize_input_prep():
|
||||
# Update persistent batch states.
|
||||
self._update_states(scheduler_output)
|
||||
|
||||
if not scheduler_output.total_num_scheduled_tokens:
|
||||
if not num_scheduled_tokens:
|
||||
if not has_kv_transfer_group():
|
||||
# Return empty ModelRunnerOutput if no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
@ -2471,7 +2493,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
|
||||
(
|
||||
num_scheduled_tokens,
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
positions,
|
||||
@ -2559,6 +2580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Rare case.
|
||||
assert not self.is_pooling_model
|
||||
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
if not get_pp_group().is_last_rank:
|
||||
all_gather_tensors = {
|
||||
"residual": not is_residual_scattered_for_sp(
|
||||
@ -2572,7 +2594,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
logits = None
|
||||
else:
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
model_output_broadcast_data = {}
|
||||
@ -2585,9 +2606,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
assert model_output_broadcast_data is not None
|
||||
logits = model_output_broadcast_data["logits"]
|
||||
|
||||
# Apply structured output bitmasks if present
|
||||
if scheduler_output.structured_output_request_ids:
|
||||
apply_grammar_bitmask(scheduler_output, self.input_batch, logits)
|
||||
self.execute_model_state = ExecuteModelState(
|
||||
scheduler_output,
|
||||
logits,
|
||||
spec_decode_metadata,
|
||||
spec_decode_common_attn_metadata,
|
||||
hidden_states,
|
||||
sample_hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
)
|
||||
return None
|
||||
|
||||
@torch.inference_mode
|
||||
def sample_tokens(
|
||||
self, grammar_output: "GrammarOutput | None"
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
|
||||
if self.execute_model_state is None:
|
||||
# Nothing to do (PP non-final rank case), output isn't used.
|
||||
return None # noqa
|
||||
|
||||
# Unpack ephemeral state.
|
||||
(
|
||||
scheduler_output,
|
||||
logits,
|
||||
spec_decode_metadata,
|
||||
spec_decode_common_attn_metadata,
|
||||
hidden_states,
|
||||
sample_hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
) = self.execute_model_state
|
||||
# Clear ephemeral state.
|
||||
self.execute_model_state = None
|
||||
|
||||
# Apply structured output bitmasks if present.
|
||||
if grammar_output is not None:
|
||||
apply_grammar_bitmask(
|
||||
scheduler_output, grammar_output, self.input_batch, logits
|
||||
)
|
||||
|
||||
with record_function_or_nullcontext("Sample"):
|
||||
sampler_output = self._sample(logits, spec_decode_metadata)
|
||||
@ -2646,7 +2703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
sampler_output,
|
||||
logits,
|
||||
hidden_states,
|
||||
num_scheduled_tokens,
|
||||
scheduler_output.total_num_scheduled_tokens,
|
||||
spec_decode_metadata,
|
||||
)
|
||||
|
||||
@ -3978,6 +4035,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
def create_attn_groups(
|
||||
attn_backends_map: dict[AttentionGroupKey, list[str]],
|
||||
kv_cache_group_id: int,
|
||||
) -> list[AttentionGroup]:
|
||||
attn_groups: list[AttentionGroup] = []
|
||||
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
|
||||
@ -3987,6 +4045,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
kv_cache_spec,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
kv_cache_group_id,
|
||||
num_metadata_builders=1
|
||||
if not self.parallel_config.enable_dbo
|
||||
else 2,
|
||||
@ -4005,8 +4064,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Resolve cudagraph_mode before actually initialize metadata_builders
|
||||
self._check_and_update_cudagraph_mode(attention_backend_set)
|
||||
|
||||
for attn_backends_map in attention_backend_maps:
|
||||
self.attn_groups.append(create_attn_groups(attn_backends_map))
|
||||
for i, attn_backend_map in enumerate(attention_backend_maps):
|
||||
self.attn_groups.append(create_attn_groups(attn_backend_map, i))
|
||||
|
||||
# Calculate reorder batch threshold (if needed)
|
||||
self.calculate_reorder_batch_threshold()
|
||||
@ -4149,89 +4208,88 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
group.get_metadata_builder().reorder_batch_threshold
|
||||
for group in self._attn_group_iterator()
|
||||
]
|
||||
# If there are no attention groups (attention-free model) or no backend
|
||||
# reports a threshold, leave reordering disabled.
|
||||
if len(reorder_batch_thresholds) == 0:
|
||||
self.reorder_batch_threshold = None
|
||||
return
|
||||
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
|
||||
|
||||
def _find_compatible_block_sizes(
|
||||
self,
|
||||
kv_manager_block_size: int,
|
||||
backend_cls: type[AttentionBackend],
|
||||
return_all: bool = False,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Find compatible block sizes for a backend.
|
||||
|
||||
Args:
|
||||
kv_manager_block_size: Physical block size of KV cache
|
||||
backend_cls: Attention backend class
|
||||
return_all: Return all compatible sizes if True, max size if False
|
||||
|
||||
Returns:
|
||||
Compatible block size(s) based on return_all parameter
|
||||
|
||||
Raises:
|
||||
ValueError: If no compatible block size found
|
||||
"""
|
||||
supported_block_size = backend_cls.get_supported_kernel_block_size()
|
||||
compatible_sizes = []
|
||||
|
||||
for block_size in supported_block_size:
|
||||
if isinstance(block_size, int):
|
||||
if kv_manager_block_size % block_size == 0:
|
||||
compatible_sizes.append(block_size)
|
||||
elif (
|
||||
isinstance(block_size, MultipleOf)
|
||||
and kv_manager_block_size % block_size.base == 0
|
||||
):
|
||||
compatible_sizes.append(kv_manager_block_size)
|
||||
|
||||
if not compatible_sizes:
|
||||
raise ValueError(f"No compatible block size for {kv_manager_block_size}")
|
||||
|
||||
return compatible_sizes if return_all else [max(compatible_sizes)]
|
||||
|
||||
def _select_common_block_size(
|
||||
self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
|
||||
@staticmethod
|
||||
def select_common_block_size(
|
||||
kv_manager_block_size: int, attn_groups: list[AttentionGroup]
|
||||
) -> int:
|
||||
"""
|
||||
Select common block size for all backends.
|
||||
Select a block size that is supported by all backends and is a factor of
|
||||
kv_manager_block_size.
|
||||
|
||||
If kv_manager_block_size is supported by all backends, return it directly.
|
||||
Otherwise, return the max supported size.
|
||||
|
||||
Args:
|
||||
kv_manager_block_size: Block size of KV cache
|
||||
attn_groups: List of attention groups
|
||||
|
||||
Returns:
|
||||
Block size supported by all backends,
|
||||
prioritizing cache_config.block_size
|
||||
The selected block size
|
||||
|
||||
Raises:
|
||||
ValueError: If no common block size found
|
||||
ValueError: If no valid block size found
|
||||
"""
|
||||
all_backend_supports = []
|
||||
|
||||
for attn_group in attn_groups:
|
||||
compatible_sizes = self._find_compatible_block_sizes(
|
||||
kv_manager_block_size, attn_group.backend, return_all=True
|
||||
)
|
||||
supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
|
||||
all_backend_supports.append(set(supported_sizes))
|
||||
def block_size_is_supported(
|
||||
backends: list[type[AttentionBackend]], block_size: int
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the block size is supported by all backends.
|
||||
"""
|
||||
for backend in backends:
|
||||
is_supported = False
|
||||
for supported_size in backend.get_supported_kernel_block_size():
|
||||
if isinstance(supported_size, int):
|
||||
if block_size == supported_size:
|
||||
is_supported = True
|
||||
elif isinstance(supported_size, MultipleOf):
|
||||
if block_size % supported_size.base == 0:
|
||||
is_supported = True
|
||||
else:
|
||||
raise ValueError(f"Unknown supported size: {supported_size}")
|
||||
if not is_supported:
|
||||
return False
|
||||
return True
|
||||
|
||||
common_supported_sizes = set.intersection(*all_backend_supports)
|
||||
backends = [group.backend for group in attn_groups]
|
||||
|
||||
if not common_supported_sizes:
|
||||
error_msg = f"No common block size for {kv_manager_block_size}. "
|
||||
for i, attn_group in enumerate(attn_groups):
|
||||
supported = all_backend_supports[i]
|
||||
error_msg += (
|
||||
f"Backend {attn_group.backend} supports: {sorted(supported)}. "
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
# Case 1: if the block_size of kv cache manager is supported by all backends,
|
||||
# return it directly
|
||||
if block_size_is_supported(backends, kv_manager_block_size):
|
||||
return kv_manager_block_size
|
||||
|
||||
if self.cache_config.block_size in common_supported_sizes:
|
||||
return self.cache_config.block_size
|
||||
# Case 2: otherwise, the block_size must be an `int`-format supported size of
|
||||
# at least one backend. Iterate over all `int`-format supported sizes in
|
||||
# descending order and return the first one that is supported by all backends.
|
||||
# Simple proof:
|
||||
# If the supported size b is in MultipleOf(x_i) format for all attention
|
||||
# backends i, and b a factor of kv_manager_block_size, then
|
||||
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
|
||||
# return kv_manager_block_size in case 1.
|
||||
all_int_supported_sizes = set(
|
||||
supported_size
|
||||
for backend in backends
|
||||
for supported_size in backend.get_supported_kernel_block_size()
|
||||
if isinstance(supported_size, int)
|
||||
)
|
||||
|
||||
return max(common_supported_sizes)
|
||||
for supported_size in sorted(all_int_supported_sizes, reverse=True):
|
||||
if kv_manager_block_size % supported_size != 0:
|
||||
continue
|
||||
if block_size_is_supported(backends, supported_size):
|
||||
return supported_size
|
||||
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
|
||||
|
||||
def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
def may_reinitialize_input_batch(
|
||||
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
|
||||
) -> None:
|
||||
"""
|
||||
Re-initialize the input batch if the block sizes are different from
|
||||
`[self.cache_config.block_size]`. This usually happens when there
|
||||
@ -4239,6 +4297,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache configuration.
|
||||
kernel_block_sizes: The kernel block sizes for each KV cache group.
|
||||
"""
|
||||
block_sizes = [
|
||||
kv_cache_group.kv_cache_spec.block_size
|
||||
@ -4246,9 +4305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
|
||||
]
|
||||
|
||||
# Generate kernel_block_sizes that matches each block_size
|
||||
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
|
||||
|
||||
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
|
||||
self.cache_config.block_size
|
||||
]:
|
||||
@ -4349,7 +4405,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# all backends in the group.
|
||||
attn_groups = self.attn_groups[kv_cache_group_id]
|
||||
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
|
||||
selected_kernel_size = self._select_common_block_size(
|
||||
selected_kernel_size = self.select_common_block_size(
|
||||
kv_manager_block_size, attn_groups
|
||||
)
|
||||
kernel_block_sizes.append(selected_kernel_size)
|
||||
@ -4367,6 +4423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
kernel_block_sizes: list[int],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Reshape the KV cache tensors to the desired shape and dtype.
|
||||
@ -4375,6 +4432,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
kv_cache_config: The KV cache config
|
||||
kv_cache_raw_tensors: The KV cache buffer of each layer, with
|
||||
correct size but uninitialized shape.
|
||||
kernel_block_sizes: The kernel block sizes for each KV cache group.
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
@ -4384,6 +4442,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
for group in self._kv_cache_spec_attn_group_iterator():
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
attn_backend = group.backend
|
||||
if group.kv_cache_group_id == len(kernel_block_sizes):
|
||||
# There may be a last group for layers without kv cache.
|
||||
continue
|
||||
kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
@ -4392,24 +4454,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
has_attn = True
|
||||
kv_manager_block_size = kv_cache_spec.block_size
|
||||
kernel_size_list = self._find_compatible_block_sizes(
|
||||
kv_manager_block_size, attn_backend, return_all=False
|
||||
num_blocks_per_kv_block = (
|
||||
kv_cache_spec.block_size // kernel_block_size
|
||||
)
|
||||
kernel_size = kernel_size_list[0]
|
||||
num_blocks_per_kv_block = kv_manager_block_size // kernel_size
|
||||
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
||||
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
kernel_num_blocks,
|
||||
kernel_size,
|
||||
kernel_block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
cache_dtype_str=self.cache_config.cache_dtype,
|
||||
)
|
||||
dtype = kv_cache_spec.dtype
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
||||
@ -4492,13 +4551,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
|
||||
def initialize_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig
|
||||
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initialize the memory buffer for KV cache.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
kernel_block_sizes: The kernel block sizes for each KV cache group.
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
@ -4507,7 +4568,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
|
||||
# Change the memory buffer to the desired shape
|
||||
kv_caches = self._reshape_kv_cache_tensors(
|
||||
kv_cache_config, kv_cache_raw_tensors
|
||||
kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
|
||||
)
|
||||
|
||||
# Set up cross-layer KV cache sharing
|
||||
@ -4566,9 +4627,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.may_add_encoder_only_layers_to_kv_cache_config()
|
||||
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
|
||||
self.initialize_attn_backend(kv_cache_config)
|
||||
# The kernel block size for all KV cache groups. For example, if
|
||||
# kv_cache_manager uses block_size 256 for a given group, but the attention
|
||||
# backends for that group only supports block_size 64, we will return
|
||||
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
|
||||
# tokens each.
|
||||
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
|
||||
# Reinitialization needs to happen after initialize_attn_backend
|
||||
self.may_reinitialize_input_batch(kv_cache_config)
|
||||
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
||||
self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
|
||||
kv_caches = self.initialize_kv_cache_tensors(
|
||||
kv_cache_config, kernel_block_sizes
|
||||
)
|
||||
|
||||
if self.speculative_config and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
|
||||
@ -6,6 +6,7 @@ import copy
|
||||
import gc
|
||||
import os
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from types import NoneType
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
@ -19,7 +20,11 @@ from vllm.distributed import (
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce,
|
||||
)
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
from vllm.distributed.kv_transfer import (
|
||||
ensure_kv_transfer_initialized,
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
@ -33,6 +38,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
|
||||
from vllm.v1.core.sched.output import GrammarOutput
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import (
|
||||
@ -348,6 +354,21 @@ class Worker(WorkerBase):
|
||||
|
||||
return int(self.available_kv_cache_memory_bytes)
|
||||
|
||||
def get_kv_connector_handshake_metadata(self) -> dict | None:
    """Get KV connector metadata from this worker if available.

    Returns:
        None when no KV transfer group exists or when the connector reports
        no handshake metadata; otherwise a single-entry dict mapping this
        worker's rank within its tensor-parallel group to that metadata.
    """

    if not has_kv_transfer_group():
        return None

    connector = get_kv_transfer_group()
    # Return None for connectors that don't need to exchange handshake
    # metadata across workers.
    if (metadata := connector.get_handshake_metadata()) is None:
        return None

    tp_rank = get_tp_group().rank_in_group
    return {tp_rank: metadata}
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
    """Return the KV cache spec reported by this worker's model runner."""
    return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
@ -489,11 +510,16 @@ class Worker(WorkerBase):
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
    """Return the supported tasks reported by this worker's model runner."""
    return self.model_runner.get_supported_tasks()
|
||||
|
||||
@torch.inference_mode()
def sample_tokens(
    self, grammar_output: "GrammarOutput"
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
    """Delegate token sampling to the model runner.

    Runs under ``torch.inference_mode()`` and returns whatever
    ``self.model_runner.sample_tokens`` returns for ``grammar_output``.
    """
    return self.model_runner.sample_tokens(grammar_output)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
|
||||
self, scheduler_output: "SchedulerOutput"
|
||||
) -> ModelRunnerOutput | None:
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
@ -512,13 +538,13 @@ class Worker(WorkerBase):
|
||||
)
|
||||
|
||||
output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
|
||||
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
|
||||
if isinstance(output, (ModelRunnerOutput, NoneType)):
|
||||
return output
|
||||
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
assert (
|
||||
parallel_config.distributed_executor_backend != ("external_launcher")
|
||||
parallel_config.distributed_executor_backend != "external_launcher"
|
||||
and not get_pp_group().is_last_rank
|
||||
)
|
||||
|
||||
|
||||
@ -139,7 +139,7 @@ class InputBatch:
|
||||
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32)
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
|
||||
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
||||
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
||||
|
||||
|
||||
@ -92,7 +92,7 @@ from .utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -372,6 +372,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
else:
|
||||
self.sample_from_logits_func = self.sample_from_logits
|
||||
|
||||
# For passing scheduler_output between successive
|
||||
# execute_model() and sample_tokens() calls.
|
||||
self.scheduler_output: SchedulerOutput | None = None
|
||||
self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
if self.mm_budget:
|
||||
self.mm_budget.reset_cache()
|
||||
@ -1078,7 +1083,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
) -> ModelRunnerOutput:
|
||||
) -> ModelRunnerOutput | None:
|
||||
if self.scheduler_output is not None:
|
||||
raise RuntimeError(
|
||||
"State error: sample_tokens() must be called "
|
||||
"after execute_model() returns None."
|
||||
)
|
||||
# Update cached state
|
||||
self._update_states(scheduler_output)
|
||||
if not scheduler_output.total_num_scheduled_tokens:
|
||||
@ -1088,14 +1098,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
|
||||
|
||||
mm_embed_inputs = None
|
||||
if self.supports_mm_inputs:
|
||||
# Run the multimodal encoder if any.
|
||||
self._execute_mm_encoder(scheduler_output)
|
||||
mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
|
||||
else:
|
||||
mm_embed_inputs = None
|
||||
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
self.scheduler_output = scheduler_output
|
||||
self.mm_embed_inputs = mm_embed_inputs
|
||||
return None
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_tokens(
|
||||
self, grammar_output: "GrammarOutput | None"
|
||||
) -> ModelRunnerOutput:
|
||||
if self.scheduler_output is None:
|
||||
# Nothing to do (PP non-final rank case), output isn't used.
|
||||
return None # noqa
|
||||
scheduler_output = self.scheduler_output
|
||||
mm_embed_inputs = self.mm_embed_inputs
|
||||
self.scheduler_output = None
|
||||
self.mm_embed_inputs = None
|
||||
|
||||
# Prepare inputs, the requests might be split into multiple
|
||||
# executions, combine the result of each execution.
|
||||
start_index = 0
|
||||
@ -1131,9 +1157,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
|
||||
self.input_batch, padded_num_reqs, self.device
|
||||
)
|
||||
if scheduler_output.grammar_bitmask is not None:
|
||||
if grammar_output is not None:
|
||||
require_struct_decoding, grammar_bitmask_padded, arange = (
|
||||
self.prepare_structured_decoding_input(logits, scheduler_output)
|
||||
self.prepare_structured_decoding_input(logits, grammar_output)
|
||||
)
|
||||
logits = self.structured_decode(
|
||||
require_struct_decoding, grammar_bitmask_padded, logits, arange
|
||||
@ -1954,10 +1980,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
return self.model.get_input_embeddings(*args, **kwargs)
|
||||
|
||||
def prepare_structured_decoding_input(
|
||||
self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
|
||||
self, logits: torch.Tensor, grammar_output: "GrammarOutput"
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
grammar_bitmask = scheduler_output.grammar_bitmask
|
||||
assert grammar_bitmask is not None
|
||||
grammar_bitmask = grammar_output.grammar_bitmask
|
||||
num_reqs, _ = logits.shape
|
||||
|
||||
# Reset pre-allocated tensors
|
||||
@ -1965,7 +1990,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.require_structured_out_cpu.zero_()
|
||||
|
||||
cumulative_mask_idx = 0
|
||||
for req_id in scheduler_output.structured_output_request_ids:
|
||||
for req_id in grammar_output.structured_output_request_ids:
|
||||
if req_id not in self.input_batch.req_id_to_index:
|
||||
continue
|
||||
batch_index = self.input_batch.req_id_to_index[req_id]
|
||||
|
||||
@ -17,7 +17,6 @@ from vllm.distributed import (
|
||||
)
|
||||
from vllm.distributed.kv_transfer import (
|
||||
ensure_kv_transfer_initialized,
|
||||
has_kv_transfer_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -27,7 +26,7 @@ from vllm.platforms.tpu import USE_TPU_INFERENCE
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
@ -255,13 +254,13 @@ class TPUWorker:
|
||||
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
|
||||
return int(tpu_kv_cache_bytes)
|
||||
|
||||
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
    """Delegate token sampling to the model runner and return its output."""
    return self.model_runner.sample_tokens(grammar_output)
|
||||
|
||||
def execute_model(
    self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
    """Run the model runner on ``scheduler_output`` and return its result.

    The result of ``self.model_runner.execute_model`` is passed through
    unchanged, which may be None.
    """
    return self.model_runner.execute_model(scheduler_output)
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.rank < 1:
|
||||
|
||||
@ -140,6 +140,7 @@ class AttentionGroup:
|
||||
metadata_builders: list[AttentionMetadataBuilder]
|
||||
layer_names: list[str]
|
||||
kv_cache_spec: KVCacheSpec
|
||||
kv_cache_group_id: int
|
||||
|
||||
@staticmethod
|
||||
def create_with_metadata_builders(
|
||||
@ -148,13 +149,16 @@ class AttentionGroup:
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
kv_cache_group_id: int,
|
||||
num_metadata_builders: int = 1,
|
||||
) -> "AttentionGroup":
|
||||
metadata_builders = [
|
||||
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
|
||||
for _ in range(num_metadata_builders)
|
||||
]
|
||||
return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec)
|
||||
return AttentionGroup(
|
||||
backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
|
||||
)
|
||||
|
||||
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
|
||||
assert len(self.metadata_builders) > ubatch_id
|
||||
|
||||
@ -20,10 +20,12 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm.v1.serial_utils import run_method
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput
|
||||
else:
|
||||
SchedulerOutput = object
|
||||
GrammarOutput = object
|
||||
AsyncModelRunnerOutput = object
|
||||
ModelRunnerOutput = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -122,7 +124,21 @@ class WorkerBase:
|
||||
"""Load model onto target device."""
|
||||
raise NotImplementedError
|
||||
|
||||
def execute_model(
    self, scheduler_output: SchedulerOutput
) -> ModelRunnerOutput | None:
    """If this method returns None, sample_tokens should be called immediately after
    to obtain the ModelRunnerOutput.

    Note that this design may be changed in future if/when structured outputs
    parallelism is re-architected.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """
    raise NotImplementedError
|
||||
|
||||
def sample_tokens(
    self, grammar_output: GrammarOutput
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
    """Should be called immediately after execute_model iff it returned None.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """
    raise NotImplementedError
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
@ -344,7 +360,7 @@ class WorkerWrapperBase:
|
||||
scheduler_output: SchedulerOutput,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> ModelRunnerOutput:
|
||||
) -> ModelRunnerOutput | None:
|
||||
self._apply_mm_cache(scheduler_output)
|
||||
|
||||
assert self.worker is not None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user