mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-05 01:51:19 +08:00
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
This commit is contained in:
commit
8e8218ebac
@ -441,7 +441,7 @@ steps:
|
|||||||
--ignore=lora/test_llm_with_multi_loras.py \
|
--ignore=lora/test_llm_with_multi_loras.py \
|
||||||
--ignore=lora/test_olmoe_tp.py \
|
--ignore=lora/test_olmoe_tp.py \
|
||||||
--ignore=lora/test_deepseekv2_tp.py \
|
--ignore=lora/test_deepseekv2_tp.py \
|
||||||
--ignore=lora/test_gptoss.py \
|
--ignore=lora/test_gptoss_tp.py \
|
||||||
--ignore=lora/test_qwen3moe_tp.py
|
--ignore=lora/test_qwen3moe_tp.py
|
||||||
parallelism: 4
|
parallelism: 4
|
||||||
|
|
||||||
@ -1217,6 +1217,8 @@ steps:
|
|||||||
- pytest -v -s -x lora/test_llama_tp.py
|
- pytest -v -s -x lora/test_llama_tp.py
|
||||||
- pytest -v -s -x lora/test_llm_with_multi_loras.py
|
- pytest -v -s -x lora/test_llm_with_multi_loras.py
|
||||||
- pytest -v -s -x lora/test_olmoe_tp.py
|
- pytest -v -s -x lora/test_olmoe_tp.py
|
||||||
|
- pytest -v -s -x lora/test_gptoss_tp.py
|
||||||
|
|
||||||
|
|
||||||
- label: Weight Loading Multiple GPU Test # 33min
|
- label: Weight Loading Multiple GPU Test # 33min
|
||||||
timeout_in_minutes: 45
|
timeout_in_minutes: 45
|
||||||
|
|||||||
@ -340,6 +340,16 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -v -s v1/attention
|
- pytest -v -s v1/attention
|
||||||
|
|
||||||
|
- label: V1 Test attention (B200) # 10min
|
||||||
|
timeout_in_minutes: 30
|
||||||
|
gpu: b200
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/v1/attention
|
||||||
|
- tests/v1/attention
|
||||||
|
commands:
|
||||||
|
- export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
|
||||||
|
- pytest -v -s v1/attention
|
||||||
|
|
||||||
- label: V1 Test others (CPU) # 5 mins
|
- label: V1 Test others (CPU) # 5 mins
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -417,7 +427,7 @@ steps:
|
|||||||
--ignore=lora/test_llm_with_multi_loras.py \
|
--ignore=lora/test_llm_with_multi_loras.py \
|
||||||
--ignore=lora/test_olmoe_tp.py \
|
--ignore=lora/test_olmoe_tp.py \
|
||||||
--ignore=lora/test_deepseekv2_tp.py \
|
--ignore=lora/test_deepseekv2_tp.py \
|
||||||
--ignore=lora/test_gptoss.py \
|
--ignore=lora/test_gptoss_tp.py \
|
||||||
--ignore=lora/test_qwen3moe_tp.py
|
--ignore=lora/test_qwen3moe_tp.py
|
||||||
|
|
||||||
parallelism: 4
|
parallelism: 4
|
||||||
@ -1119,6 +1129,7 @@ steps:
|
|||||||
- pytest -v -s -x lora/test_llama_tp.py
|
- pytest -v -s -x lora/test_llama_tp.py
|
||||||
- pytest -v -s -x lora/test_llm_with_multi_loras.py
|
- pytest -v -s -x lora/test_llm_with_multi_loras.py
|
||||||
- pytest -v -s -x lora/test_olmoe_tp.py
|
- pytest -v -s -x lora/test_olmoe_tp.py
|
||||||
|
- pytest -v -s -x lora/test_gptoss_tp.py
|
||||||
|
|
||||||
|
|
||||||
- label: Weight Loading Multiple GPU Test # 33min
|
- label: Weight Loading Multiple GPU Test # 33min
|
||||||
|
|||||||
@ -1429,8 +1429,6 @@ async def main() -> None:
|
|||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
if not os.path.exists(args.model):
|
|
||||||
raise OSError(f"Path does not exist: {args.model}")
|
|
||||||
logger.info("Loading tokenizer")
|
logger.info("Loading tokenizer")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ This doc serves as a collection of handy tips for optimizing your vLLM on TPU wo
|
|||||||
|
|
||||||
## Get started
|
## Get started
|
||||||
|
|
||||||
Looking for setup and installation instructions? Find them [here](../getting_started/installation/google_tpu.md).
|
Looking for setup and installation instructions? Find them [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/).
|
||||||
|
|
||||||
### TPU workload sizing
|
### TPU workload sizing
|
||||||
|
|
||||||
|
|||||||
133
docs/features/batch_invariance.md
Normal file
133
docs/features/batch_invariance.md
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
# Batch Invariance
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
Batch invariance is currently in beta. Some features are still under active development.
|
||||||
|
Track progress and planned improvements at <https://github.com/vllm-project/vllm/issues/27433>
|
||||||
|
|
||||||
|
This document shows how to enable batch invariance in vLLM. Batch invariance ensures that the output of a model is deterministic and independent of the batch size or the order of requests in a batch.
|
||||||
|
|
||||||
|
## Motivation
|
||||||
|
|
||||||
|
Batch invariance is crucial for several use cases:
|
||||||
|
|
||||||
|
- **Framework debugging**: Deterministic outputs make it easier to debug issues in the inference framework, as the same input will always produce the same output regardless of batching.
|
||||||
|
- **Model debugging**: Helps identify issues in model implementations by ensuring consistent behavior across different batch configurations.
|
||||||
|
- **Reinforcement Learning (RL)**: RL training often requires deterministic rollouts for reproducibility and stable training.
|
||||||
|
- **Large-scale inference systems**: Systems that use vLLM as a component benefit from deterministic behavior for testing, validation, and consistency guarantees.
|
||||||
|
|
||||||
|
## Hardware Requirements
|
||||||
|
|
||||||
|
Batch invariance currently requires NVIDIA GPUs with compute capability 9.0 or higher:
|
||||||
|
|
||||||
|
- **H-series**: H100, H200
|
||||||
|
- **B-series**: B100, B200
|
||||||
|
|
||||||
|
## Enabling Batch Invariance
|
||||||
|
|
||||||
|
Batch invariance can be enabled by setting the `VLLM_BATCH_INVARIANT` environment variable to `1`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export VLLM_BATCH_INVARIANT=1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Online Inference (Server Mode)
|
||||||
|
|
||||||
|
To start a vLLM server with batch invariance enabled:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
VLLM_BATCH_INVARIANT=1 vllm serve meta-llama/Llama-3.1-8B-Instruct
|
||||||
|
```
|
||||||
|
|
||||||
|
Then use the OpenAI-compatible client:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
api_key="EMPTY",
|
||||||
|
base_url="http://localhost:8000/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
# These requests will produce deterministic outputs
|
||||||
|
# regardless of batch size or order
|
||||||
|
response = client.completions.create(
|
||||||
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
prompt="The future of AI is",
|
||||||
|
max_tokens=100,
|
||||||
|
temperature=0.7,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response.choices[0].text)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Offline Inference
|
||||||
|
|
||||||
|
For offline batch inference with batch invariance:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
os.environ["VLLM_BATCH_INVARIANT"] = "1"
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"The future of AI is",
|
||||||
|
"Machine learning enables",
|
||||||
|
"Deep learning models can",
|
||||||
|
]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.7,
|
||||||
|
top_p=0.95,
|
||||||
|
max_tokens=100,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Outputs will be deterministic regardless of batch size
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}")
|
||||||
|
print(f"Generated: {generated_text!r}\n")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Tested Models
|
||||||
|
|
||||||
|
Batch invariance has been tested and verified on the following models:
|
||||||
|
|
||||||
|
- **DeepSeek series**: `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-V3-0324`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`
|
||||||
|
- **Qwen3 (Dense)**: `Qwen/Qwen3-1.7B`, `Qwen/Qwen3-8B`
|
||||||
|
- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct`
|
||||||
|
- **Llama 3**: `meta-llama/Llama-3.1-8B-Instruct`, `meta-llama/Llama-3.2-1B-Instruct`
|
||||||
|
|
||||||
|
Other models may also work, but these have been explicitly validated. If you encounter issues with a specific model, please report them on the [GitHub issue tracker](https://github.com/vllm-project/vllm/issues/new/choose).
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
When batch invariance is enabled, vLLM:
|
||||||
|
|
||||||
|
1. Uses deterministic kernel implementations for attention and other operations
|
||||||
|
2. Ensures consistent numerical behavior across different batch sizes
|
||||||
|
3. Disables certain optimizations that may introduce non-determinism (such as custom all-reduce operations in tensor parallel mode)
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
Enabling batch invariance may impact performance compared to the default non-deterministic mode. This trade-off is intentional to guarantee reproducibility.
|
||||||
|
|
||||||
|
## Future Improvements
|
||||||
|
|
||||||
|
The batch invariance feature is under active development. Planned improvements include:
|
||||||
|
|
||||||
|
- Support for additional GPU architectures
|
||||||
|
- Expanded model coverage
|
||||||
|
- Performance optimizations
|
||||||
|
- Additional testing and validation
|
||||||
|
|
||||||
|
For the latest status and to contribute ideas, see the [tracking issue](https://github.com/vllm-project/vllm/issues/27433).
|
||||||
@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
|
|||||||
- Default: 5600
|
- Default: 5600
|
||||||
- **Required for both prefiller and decoder instances**
|
- **Required for both prefiller and decoder instances**
|
||||||
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
|
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
|
||||||
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
|
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use port 5600, 5601 on that node).
|
||||||
- Used for the initial NIXL handshake between the prefiller and the decoder
|
- Used for the initial NIXL handshake between the prefiller and the decoder
|
||||||
|
|
||||||
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
|
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
|
||||||
|
|||||||
@ -2,4 +2,4 @@ nav:
|
|||||||
- README.md
|
- README.md
|
||||||
- gpu.md
|
- gpu.md
|
||||||
- cpu.md
|
- cpu.md
|
||||||
- google_tpu.md
|
- TPU: https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/
|
||||||
|
|||||||
@ -11,7 +11,6 @@ vLLM supports the following hardware platforms:
|
|||||||
- [ARM AArch64](cpu.md#arm-aarch64)
|
- [ARM AArch64](cpu.md#arm-aarch64)
|
||||||
- [Apple silicon](cpu.md#apple-silicon)
|
- [Apple silicon](cpu.md#apple-silicon)
|
||||||
- [IBM Z (S390X)](cpu.md#ibm-z-s390x)
|
- [IBM Z (S390X)](cpu.md#ibm-z-s390x)
|
||||||
- [Google TPU](google_tpu.md)
|
|
||||||
|
|
||||||
## Hardware Plugins
|
## Hardware Plugins
|
||||||
|
|
||||||
@ -20,6 +19,7 @@ The backends below live **outside** the main `vllm` repository and follow the
|
|||||||
|
|
||||||
| Accelerator | PyPI / package | Repository |
|
| Accelerator | PyPI / package | Repository |
|
||||||
|-------------|----------------|------------|
|
|-------------|----------------|------------|
|
||||||
|
| Google TPU | `tpu-inference` | <https://github.com/vllm-project/tpu-inference> |
|
||||||
| Ascend NPU | `vllm-ascend` | <https://github.com/vllm-project/vllm-ascend> |
|
| Ascend NPU | `vllm-ascend` | <https://github.com/vllm-project/vllm-ascend> |
|
||||||
| Intel Gaudi (HPU) | N/A, install from source | <https://github.com/vllm-project/vllm-gaudi> |
|
| Intel Gaudi (HPU) | N/A, install from source | <https://github.com/vllm-project/vllm-gaudi> |
|
||||||
| MetaX MACA GPU | N/A, install from source | <https://github.com/MetaX-MACA/vLLM-metax> |
|
| MetaX MACA GPU | N/A, install from source | <https://github.com/MetaX-MACA/vLLM-metax> |
|
||||||
|
|||||||
@ -1,193 +0,0 @@
|
|||||||
# Google TPU
|
|
||||||
|
|
||||||
Tensor Processing Units (TPUs) are Google's custom-developed application-specific
|
|
||||||
integrated circuits (ASICs) used to accelerate machine learning workloads. TPUs
|
|
||||||
are available in different versions each with different hardware specifications.
|
|
||||||
For more information about TPUs, see [TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm).
|
|
||||||
For more information on the TPU versions supported with vLLM, see:
|
|
||||||
|
|
||||||
- [TPU v6e](https://cloud.google.com/tpu/docs/v6e)
|
|
||||||
- [TPU v5e](https://cloud.google.com/tpu/docs/v5e)
|
|
||||||
- [TPU v5p](https://cloud.google.com/tpu/docs/v5p)
|
|
||||||
- [TPU v4](https://cloud.google.com/tpu/docs/v4)
|
|
||||||
|
|
||||||
These TPU versions allow you to configure the physical arrangements of the TPU
|
|
||||||
chips. This can improve throughput and networking performance. For more
|
|
||||||
information see:
|
|
||||||
|
|
||||||
- [TPU v6e topologies](https://cloud.google.com/tpu/docs/v6e#configurations)
|
|
||||||
- [TPU v5e topologies](https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config)
|
|
||||||
- [TPU v5p topologies](https://cloud.google.com/tpu/docs/v5p#tpu-v5p-config)
|
|
||||||
- [TPU v4 topologies](https://cloud.google.com/tpu/docs/v4#tpu-v4-config)
|
|
||||||
|
|
||||||
In order for you to use Cloud TPUs you need to have TPU quota granted to your
|
|
||||||
Google Cloud Platform project. TPU quotas specify how many TPUs you can use in a
|
|
||||||
GPC project and are specified in terms of TPU version, the number of TPU you
|
|
||||||
want to use, and quota type. For more information, see [TPU quota](https://cloud.google.com/tpu/docs/quota#tpu_quota).
|
|
||||||
|
|
||||||
For TPU pricing information, see [Cloud TPU pricing](https://cloud.google.com/tpu/pricing).
|
|
||||||
|
|
||||||
You may need additional persistent storage for your TPU VMs. For more
|
|
||||||
information, see [Storage options for Cloud TPU data](https://cloud.devsite.corp.google.com/tpu/docs/storage-options).
|
|
||||||
|
|
||||||
!!! warning
|
|
||||||
There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
|
|
||||||
|
|
||||||
## Requirements
|
|
||||||
|
|
||||||
- Google Cloud TPU VM
|
|
||||||
- TPU versions: v6e, v5e, v5p, v4
|
|
||||||
- Python: 3.11 or newer
|
|
||||||
|
|
||||||
### Provision Cloud TPUs
|
|
||||||
|
|
||||||
You can provision Cloud TPUs using the [Cloud TPU API](https://cloud.google.com/tpu/docs/reference/rest)
|
|
||||||
or the [queued resources](https://cloud.google.com/tpu/docs/queued-resources)
|
|
||||||
API (preferred). This section shows how to create TPUs using the queued resource API. For
|
|
||||||
more information about using the Cloud TPU API, see [Create a Cloud TPU using the Create Node API](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#create-node-api).
|
|
||||||
Queued resources enable you to request Cloud TPU resources in a queued manner.
|
|
||||||
When you request queued resources, the request is added to a queue maintained by
|
|
||||||
the Cloud TPU service. When the requested resource becomes available, it's
|
|
||||||
assigned to your Google Cloud project for your immediate exclusive use.
|
|
||||||
|
|
||||||
!!! note
|
|
||||||
In all of the following commands, replace the ALL CAPS parameter names with
|
|
||||||
appropriate values. See the parameter descriptions table for more information.
|
|
||||||
|
|
||||||
### Provision Cloud TPUs with GKE
|
|
||||||
|
|
||||||
For more information about using TPUs with GKE, see:
|
|
||||||
|
|
||||||
- [About TPUs in GKE](https://cloud.google.com/kubernetes-engine/docs/concepts/tpus)
|
|
||||||
- [Deploy TPU workloads in GKE Standard](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus)
|
|
||||||
- [Plan for TPUs in GKE](https://cloud.google.com/kubernetes-engine/docs/concepts/plan-tpus)
|
|
||||||
|
|
||||||
## Configure a new environment
|
|
||||||
|
|
||||||
### Provision a Cloud TPU with the queued resource API
|
|
||||||
|
|
||||||
Create a TPU v5e with 4 TPU chips:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
|
|
||||||
--node-id TPU_NAME \
|
|
||||||
--project PROJECT_ID \
|
|
||||||
--zone ZONE \
|
|
||||||
--accelerator-type ACCELERATOR_TYPE \
|
|
||||||
--runtime-version RUNTIME_VERSION \
|
|
||||||
--service-account SERVICE_ACCOUNT
|
|
||||||
```
|
|
||||||
|
|
||||||
| Parameter name | Description |
|
|
||||||
|--------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
|
||||||
| QUEUED_RESOURCE_ID | The user-assigned ID of the queued resource request. |
|
|
||||||
| TPU_NAME | The user-assigned name of the TPU which is created when the queued resource request is allocated. |
|
|
||||||
| PROJECT_ID | Your Google Cloud project |
|
|
||||||
| ZONE | The GCP zone where you want to create your Cloud TPU. The value you use depends on the version of TPUs you are using. For more information, see [TPU regions and zones] |
|
|
||||||
| ACCELERATOR_TYPE | The TPU version you want to use. Specify the TPU version, for example `v5litepod-4` specifies a v5e TPU with 4 cores, `v6e-1` specifies a v6e TPU with 1 core. For more information, see [TPU versions]. |
|
|
||||||
| RUNTIME_VERSION | The TPU VM runtime version to use. For example, use `v2-alpha-tpuv6e` for a VM loaded with one or more v6e TPU(s). |
|
|
||||||
| SERVICE_ACCOUNT | The email address for your service account. You can find it in the IAM Cloud Console under *Service Accounts*. For example: `tpu-service-account@<your_project_ID>.iam.gserviceaccount.com` |
|
|
||||||
|
|
||||||
Connect to your TPU VM using SSH:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
gcloud compute tpus tpu-vm ssh TPU_NAME --project PROJECT_ID --zone ZONE
|
|
||||||
```
|
|
||||||
|
|
||||||
!!! note
|
|
||||||
When configuring `RUNTIME_VERSION` ("TPU software version") on GCP, ensure it matches the TPU generation you've selected by referencing the [TPU VM images] compatibility matrix. Using an incompatible version may prevent vLLM from running correctly.
|
|
||||||
|
|
||||||
[TPU versions]: https://cloud.google.com/tpu/docs/runtimes
|
|
||||||
[TPU VM images]: https://cloud.google.com/tpu/docs/runtimes
|
|
||||||
[TPU regions and zones]: https://cloud.google.com/tpu/docs/regions-zones
|
|
||||||
|
|
||||||
## Set up using Python
|
|
||||||
|
|
||||||
### Pre-built wheels
|
|
||||||
|
|
||||||
Currently, there are no pre-built TPU wheels.
|
|
||||||
|
|
||||||
### Build wheel from source
|
|
||||||
|
|
||||||
Install Miniconda:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
|
||||||
bash Miniconda3-latest-Linux-x86_64.sh
|
|
||||||
source ~/.bashrc
|
|
||||||
```
|
|
||||||
|
|
||||||
Create and activate a Conda environment for vLLM:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
conda create -n vllm python=3.12 -y
|
|
||||||
conda activate vllm
|
|
||||||
```
|
|
||||||
|
|
||||||
Clone the vLLM repository and go to the vLLM directory:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/vllm-project/vllm.git && cd vllm
|
|
||||||
```
|
|
||||||
|
|
||||||
Uninstall the existing `torch` and `torch_xla` packages:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip uninstall torch torch-xla -y
|
|
||||||
```
|
|
||||||
|
|
||||||
Install build dependencies:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -r requirements/tpu.txt
|
|
||||||
sudo apt-get install --no-install-recommends --yes libopenblas-base libopenmpi-dev libomp-dev
|
|
||||||
```
|
|
||||||
|
|
||||||
Run the setup script:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
VLLM_TARGET_DEVICE="tpu" python -m pip install -e .
|
|
||||||
```
|
|
||||||
|
|
||||||
## Set up using Docker
|
|
||||||
|
|
||||||
### Pre-built images
|
|
||||||
|
|
||||||
See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image, making sure to substitute the image name `vllm/vllm-openai` with `vllm/vllm-tpu`.
|
|
||||||
|
|
||||||
### Build image from source
|
|
||||||
|
|
||||||
You can use [docker/Dockerfile.tpu](../../../docker/Dockerfile.tpu) to build a Docker image with TPU support.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker build -f docker/Dockerfile.tpu -t vllm-tpu .
|
|
||||||
```
|
|
||||||
|
|
||||||
Run the Docker image with the following command:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Make sure to add `--privileged --net host --shm-size=16G`.
|
|
||||||
docker run --privileged --net host --shm-size=16G -it vllm-tpu
|
|
||||||
```
|
|
||||||
|
|
||||||
!!! note
|
|
||||||
Since TPU relies on XLA which requires static shapes, vLLM bucketizes the
|
|
||||||
possible input shapes and compiles an XLA graph for each shape. The
|
|
||||||
compilation time may take 20~30 minutes in the first run. However, the
|
|
||||||
compilation time reduces to ~5 minutes afterwards because the XLA graphs are
|
|
||||||
cached in the disk (in `VLLM_XLA_CACHE_PATH` or `~/.cache/vllm/xla_cache` by default).
|
|
||||||
|
|
||||||
!!! tip
|
|
||||||
If you encounter the following error:
|
|
||||||
|
|
||||||
```console
|
|
||||||
from torch._C import * # noqa: F403
|
|
||||||
ImportError: libopenblas.so.0: cannot open shared object file: No such
|
|
||||||
file or directory
|
|
||||||
```
|
|
||||||
|
|
||||||
Install OpenBLAS with the following command:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
sudo apt-get install --no-install-recommends --yes libopenblas-base libopenmpi-dev libomp-dev
|
|
||||||
```
|
|
||||||
@ -63,6 +63,17 @@ This guide will help you quickly get started with vLLM to perform:
|
|||||||
rocm/vllm-dev:nightly
|
rocm/vllm-dev:nightly
|
||||||
```
|
```
|
||||||
|
|
||||||
|
=== "Google TPU"
|
||||||
|
|
||||||
|
To run vLLM on Google TPUs, you need to install the `vllm-tpu` package.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install vllm-tpu
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
For more detailed instructions, including Docker, installing from source, and troubleshooting, please refer to the [vLLM on TPU documentation](https://docs.vllm.ai/projects/tpu/en/latest/).
|
||||||
|
|
||||||
!!! note
|
!!! note
|
||||||
For more detail and non-CUDA platforms, please refer [here](installation/README.md) for specific instructions on how to install vLLM.
|
For more detail and non-CUDA platforms, please refer [here](installation/README.md) for specific instructions on how to install vLLM.
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,7 @@ requests >= 2.26.0
|
|||||||
tqdm
|
tqdm
|
||||||
blake3
|
blake3
|
||||||
py-cpuinfo
|
py-cpuinfo
|
||||||
transformers >= 4.56.0
|
transformers >= 4.56.0, < 5
|
||||||
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
|
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
|
||||||
protobuf # Required by LlamaTokenizer.
|
protobuf # Required by LlamaTokenizer.
|
||||||
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
||||||
|
|||||||
@ -29,7 +29,7 @@ opencv-python-headless >= 4.11.0 # required for video test
|
|||||||
datamodel_code_generator # required for minicpm3 test
|
datamodel_code_generator # required for minicpm3 test
|
||||||
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
|
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
|
||||||
mteb>=1.38.11, <2 # required for mteb test
|
mteb>=1.38.11, <2 # required for mteb test
|
||||||
transformers==4.56.2
|
transformers==4.57.1
|
||||||
tokenizers==0.22.0
|
tokenizers==0.22.0
|
||||||
schemathesis>=3.39.15 # Required for openai schema test.
|
schemathesis>=3.39.15 # Required for openai schema test.
|
||||||
# quantization
|
# quantization
|
||||||
|
|||||||
@ -37,7 +37,7 @@ datamodel_code_generator # required for minicpm3 test
|
|||||||
# TODO: Use lm-eval[api]==0.4.10 once released
|
# TODO: Use lm-eval[api]==0.4.10 once released
|
||||||
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
|
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
|
||||||
mteb[bm25s]>=1.38.11, <2 # required for mteb test
|
mteb[bm25s]>=1.38.11, <2 # required for mteb test
|
||||||
transformers==4.56.2
|
transformers==4.57.1
|
||||||
tokenizers==0.22.0
|
tokenizers==0.22.0
|
||||||
schemathesis>=3.39.15 # Required for openai schema test.
|
schemathesis>=3.39.15 # Required for openai schema test.
|
||||||
# quantization
|
# quantization
|
||||||
|
|||||||
@ -1196,7 +1196,7 @@ tqdm==4.66.6
|
|||||||
# transformers
|
# transformers
|
||||||
tqdm-multiprocess==0.0.11
|
tqdm-multiprocess==0.0.11
|
||||||
# via lm-eval
|
# via lm-eval
|
||||||
transformers==4.56.2
|
transformers==4.57.1
|
||||||
# via
|
# via
|
||||||
# -r requirements/test.in
|
# -r requirements/test.in
|
||||||
# genai-perf
|
# genai-perf
|
||||||
|
|||||||
@ -6,6 +6,9 @@ from copy import deepcopy
|
|||||||
|
|
||||||
from tblib import pickling_support
|
from tblib import pickling_support
|
||||||
|
|
||||||
|
# Import fixture
|
||||||
|
from tests.v1.entrypoints.conftest import sample_json_schema # noqa
|
||||||
|
|
||||||
# ruff: noqa
|
# ruff: noqa
|
||||||
|
|
||||||
# Install support for pickling exceptions so that we can nicely propagate
|
# Install support for pickling exceptions so that we can nicely propagate
|
||||||
|
|||||||
@ -237,7 +237,7 @@ def deepseekv2_lora_files():
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def gptoss20b_lora_files():
|
def gptoss20b_lora_files():
|
||||||
return snapshot_download(repo_id="LevinZheng/gpt-oss-20b-lora-adapter")
|
return snapshot_download(repo_id="jeeejeee/gpt-oss-20b-lora-adapter-text2sql")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
|||||||
@ -1,6 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# NOTE To avoid overloading the CI pipeline, this test script will
|
||||||
|
# not be triggered on CI and is primarily intended for local testing
|
||||||
|
# and verification.
|
||||||
|
|
||||||
import vllm
|
import vllm
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
|||||||
@ -1,52 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import vllm
|
|
||||||
from vllm.lora.request import LoRARequest
|
|
||||||
|
|
||||||
MODEL_PATH = "openai/gpt-oss-20b"
|
|
||||||
|
|
||||||
PROMPT_TEMPLATE = "<|begin▁of▁sentence|>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501
|
|
||||||
|
|
||||||
|
|
||||||
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
|
|
||||||
prompts = [
|
|
||||||
PROMPT_TEMPLATE.format(context="Who are you?"),
|
|
||||||
]
|
|
||||||
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
|
|
||||||
outputs = llm.generate(
|
|
||||||
prompts,
|
|
||||||
sampling_params,
|
|
||||||
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
|
|
||||||
)
|
|
||||||
# Print the outputs.
|
|
||||||
generated_texts: list[str] = []
|
|
||||||
for output in outputs:
|
|
||||||
prompt = output.prompt
|
|
||||||
generated_text = output.outputs[0].text.strip()
|
|
||||||
generated_texts.append(generated_text)
|
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
|
||||||
return generated_texts
|
|
||||||
|
|
||||||
|
|
||||||
# FIXME: Load gpt-oss adapter
|
|
||||||
def test_gptoss20b_lora(gptoss20b_lora_files):
|
|
||||||
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
|
|
||||||
# Otherwise, the lora-test will fail due to CUDA OOM.
|
|
||||||
llm = vllm.LLM(
|
|
||||||
MODEL_PATH,
|
|
||||||
enable_lora=True,
|
|
||||||
max_loras=4,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_lora_output = [
|
|
||||||
"I am an AI language model developed by OpenAI. "
|
|
||||||
"I am here to help you with any questions or "
|
|
||||||
"tasks you may have."
|
|
||||||
]
|
|
||||||
|
|
||||||
output1 = do_sample(llm, gptoss20b_lora_files, lora_id=1)
|
|
||||||
print(output1)
|
|
||||||
for i in range(len(expected_lora_output)):
|
|
||||||
assert output1[i].startswith(expected_lora_output[i])
|
|
||||||
106
tests/lora/test_gptoss_tp.py
Normal file
106
tests/lora/test_gptoss_tp.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import vllm
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
from ..utils import multi_gpu_test
|
||||||
|
|
||||||
|
MODEL_PATH = "openai/gpt-oss-20b"
|
||||||
|
|
||||||
|
PROMPT_TEMPLATE = """<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
|
||||||
|
Knowledge cutoff: 2024-06
|
||||||
|
Current date: 2025-10-29
|
||||||
|
|
||||||
|
Reasoning: medium
|
||||||
|
|
||||||
|
# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>user<|message|>I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
|
||||||
|
"
|
||||||
|
##Instruction:
|
||||||
|
farm contains tables such as city, farm, farm_competition, competition_record. Table city has columns such as City_ID, Official_Name, Status, Area_km_2, Population, Census_Ranking. City_ID is the primary key.
|
||||||
|
Table farm has columns such as Farm_ID, Year, Total_Horses, Working_Horses, Total_Cattle, Oxen, Bulls, Cows, Pigs, Sheep_and_Goats. Farm_ID is the primary key.
|
||||||
|
Table farm_competition has columns such as Competition_ID, Year, Theme, Host_city_ID, Hosts. Competition_ID is the primary key.
|
||||||
|
Table competition_record has columns such as Competition_ID, Farm_ID, Rank. Competition_ID is the primary key.
|
||||||
|
The Host_city_ID of farm_competition is the foreign key of City_ID of city.
|
||||||
|
The Farm_ID of competition_record is the foreign key of Farm_ID of farm.
|
||||||
|
The Competition_ID of competition_record is the foreign key of Competition_ID of farm_competition.
|
||||||
|
|
||||||
|
|
||||||
|
###Input:
|
||||||
|
{context}
|
||||||
|
|
||||||
|
###Response:<|end|><|start|>assistant<|channel|>final<|message|>""" # noqa: E501
|
||||||
|
|
||||||
|
EXPECTED_LORA_OUTPUT = [
|
||||||
|
"SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;",
|
||||||
|
"SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;",
|
||||||
|
"SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
|
||||||
|
"SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
|
||||||
|
prompts = [
|
||||||
|
PROMPT_TEMPLATE.format(
|
||||||
|
context="What is the average number of working horses of farms with more than 5000 total number of horses?" # noqa: E501
|
||||||
|
), # noqa: E501
|
||||||
|
PROMPT_TEMPLATE.format(
|
||||||
|
context="Give the average number of working horses on farms with more than 5000 total horses." # noqa: E501
|
||||||
|
), # noqa: E501
|
||||||
|
PROMPT_TEMPLATE.format(
|
||||||
|
context="What are the maximum and minimum number of cows across all farms."
|
||||||
|
),
|
||||||
|
PROMPT_TEMPLATE.format(
|
||||||
|
context="Return the maximum and minimum number of cows across all farms."
|
||||||
|
),
|
||||||
|
]
|
||||||
|
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts,
|
||||||
|
sampling_params,
|
||||||
|
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
|
||||||
|
)
|
||||||
|
# Print the outputs.
|
||||||
|
generated_texts: list[str] = []
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text.strip()
|
||||||
|
generated_texts.append(generated_text)
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
for i in range(len(EXPECTED_LORA_OUTPUT)):
|
||||||
|
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt_oss_lora(gptoss20b_lora_files):
|
||||||
|
llm = vllm.LLM(
|
||||||
|
MODEL_PATH,
|
||||||
|
max_model_len=1024,
|
||||||
|
enable_lora=True,
|
||||||
|
max_loras=4,
|
||||||
|
max_lora_rank=8,
|
||||||
|
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
|
||||||
|
cudagraph_specialize_lora=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
|
||||||
|
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
|
||||||
|
|
||||||
|
|
||||||
|
@multi_gpu_test(num_gpus=2)
|
||||||
|
def test_gpt_oss_lora_tp2(gptoss20b_lora_files):
|
||||||
|
llm = vllm.LLM(
|
||||||
|
MODEL_PATH,
|
||||||
|
max_model_len=1024,
|
||||||
|
enable_lora=True,
|
||||||
|
max_loras=2,
|
||||||
|
max_lora_rank=8,
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
|
||||||
|
cudagraph_specialize_lora=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
|
||||||
|
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
|
||||||
@ -1,6 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE To avoid overloading the CI pipeline, this test script will not
|
||||||
|
# be triggered on CI and is primarily intended for local testing and verification.
|
||||||
|
|
||||||
import vllm
|
import vllm
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
|||||||
86
tests/models/multimodal/generation/test_keye.py
Normal file
86
tests/models/multimodal/generation/test_keye.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from PIL.Image import Image
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
|
from vllm import LLM, EngineArgs, SamplingParams
|
||||||
|
from vllm.multimodal.utils import encode_image_base64
|
||||||
|
|
||||||
|
MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"
|
||||||
|
|
||||||
|
QUESTION = "What is the content of each image?"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRequestData(NamedTuple):
|
||||||
|
engine_args: EngineArgs
|
||||||
|
prompt: str
|
||||||
|
image_data: list[Image]
|
||||||
|
stop_token_ids: list[int] | None = None
|
||||||
|
chat_template: str | None = None
|
||||||
|
sampling_params: SamplingParams | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.core_model
|
||||||
|
@pytest.mark.parametrize("question", [QUESTION])
|
||||||
|
def test_keye_vl(
|
||||||
|
image_assets,
|
||||||
|
question: str,
|
||||||
|
):
|
||||||
|
images = [asset.pil_image for asset in image_assets]
|
||||||
|
|
||||||
|
image_urls = [
|
||||||
|
f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images
|
||||||
|
]
|
||||||
|
|
||||||
|
engine_args = EngineArgs(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
trust_remote_code=True,
|
||||||
|
max_model_len=8192,
|
||||||
|
max_num_seqs=5,
|
||||||
|
limit_mm_per_prompt={"image": len(image_urls)},
|
||||||
|
)
|
||||||
|
|
||||||
|
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
*placeholders,
|
||||||
|
{"type": "text", "text": question},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||||
|
|
||||||
|
prompt = processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
engine_args = asdict(engine_args) | {"seed": 42}
|
||||||
|
llm = LLM(**engine_args)
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.0, max_tokens=256, stop_token_ids=None
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = llm.generate(
|
||||||
|
{
|
||||||
|
"prompt": prompt,
|
||||||
|
"multi_modal_data": {"image": images},
|
||||||
|
},
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("-" * 50)
|
||||||
|
for o in outputs:
|
||||||
|
generated_text = o.outputs[0].text
|
||||||
|
print(generated_text)
|
||||||
|
assert len(generated_text) > 10, (
|
||||||
|
f"Generated text is too short: {generated_text}"
|
||||||
|
)
|
||||||
|
print("-" * 50)
|
||||||
@ -186,6 +186,8 @@ def create_reduced_config(
|
|||||||
if "text_config" in config_dict:
|
if "text_config" in config_dict:
|
||||||
original_text_layers = config_dict["text_config"]["num_hidden_layers"]
|
original_text_layers = config_dict["text_config"]["num_hidden_layers"]
|
||||||
config_dict["text_config"]["num_hidden_layers"] = text_layers
|
config_dict["text_config"]["num_hidden_layers"] = text_layers
|
||||||
|
original_layer_types = config_dict["text_config"]["layer_types"]
|
||||||
|
config_dict["text_config"]["layer_types"] = original_layer_types[:text_layers]
|
||||||
print(f"Reduced text layers from {original_text_layers} to {text_layers}")
|
print(f"Reduced text layers from {original_text_layers} to {text_layers}")
|
||||||
|
|
||||||
original_num_experts = config_dict["text_config"]["num_local_experts"]
|
original_num_experts = config_dict["text_config"]["num_local_experts"]
|
||||||
|
|||||||
@ -882,27 +882,27 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
|||||||
|
|
||||||
_TRANSFORMERS_BACKEND_MODELS = {
|
_TRANSFORMERS_BACKEND_MODELS = {
|
||||||
"TransformersEmbeddingModel": _HfExamplesInfo(
|
"TransformersEmbeddingModel": _HfExamplesInfo(
|
||||||
"BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"
|
"BAAI/bge-base-en-v1.5", min_transformers_version="5.0.0"
|
||||||
),
|
),
|
||||||
"TransformersForSequenceClassification": _HfExamplesInfo(
|
"TransformersForSequenceClassification": _HfExamplesInfo(
|
||||||
"papluca/xlm-roberta-base-language-detection",
|
"papluca/xlm-roberta-base-language-detection",
|
||||||
min_transformers_version="4.57.0.dev0",
|
min_transformers_version="5.0.0",
|
||||||
),
|
),
|
||||||
"TransformersForCausalLM": _HfExamplesInfo(
|
"TransformersForCausalLM": _HfExamplesInfo(
|
||||||
"hmellor/Ilama-3.2-1B", trust_remote_code=True
|
"hmellor/Ilama-3.2-1B", trust_remote_code=True
|
||||||
),
|
),
|
||||||
"TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
|
"TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
|
||||||
"TransformersMoEForCausalLM": _HfExamplesInfo(
|
"TransformersMoEForCausalLM": _HfExamplesInfo(
|
||||||
"allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"
|
"allenai/OLMoE-1B-7B-0924", min_transformers_version="5.0.0"
|
||||||
),
|
),
|
||||||
"TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
|
"TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
|
||||||
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"
|
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="5.0.0"
|
||||||
),
|
),
|
||||||
"TransformersMoEEmbeddingModel": _HfExamplesInfo(
|
"TransformersMoEEmbeddingModel": _HfExamplesInfo(
|
||||||
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
|
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
|
||||||
),
|
),
|
||||||
"TransformersMoEForSequenceClassification": _HfExamplesInfo(
|
"TransformersMoEForSequenceClassification": _HfExamplesInfo(
|
||||||
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
|
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
|
||||||
),
|
),
|
||||||
"TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"),
|
"TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"),
|
||||||
"TransformersMultiModalForSequenceClassification": _HfExamplesInfo(
|
"TransformersMultiModalForSequenceClassification": _HfExamplesInfo(
|
||||||
|
|||||||
@ -82,7 +82,7 @@ def test_models(
|
|||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
installed = Version(transformers.__version__)
|
installed = Version(transformers.__version__)
|
||||||
required = Version("4.57.0.dev0")
|
required = Version("5.0.0")
|
||||||
if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
|
if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"MoE models with the Transformers backend require "
|
"MoE models with the Transformers backend require "
|
||||||
|
|||||||
@ -14,16 +14,19 @@ import torch
|
|||||||
from tests.v1.attention.utils import (
|
from tests.v1.attention.utils import (
|
||||||
BatchSpec,
|
BatchSpec,
|
||||||
create_common_attn_metadata,
|
create_common_attn_metadata,
|
||||||
create_standard_kv_cache_spec,
|
|
||||||
create_vllm_config,
|
create_vllm_config,
|
||||||
try_get_attention_backend,
|
try_get_attention_backend,
|
||||||
)
|
)
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend, backend_to_class_str
|
||||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||||
|
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
|
||||||
from vllm.config.vllm import set_current_vllm_config
|
from vllm.config.vllm import set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
|
from vllm.v1.attention.backends.mla.common import QueryLenSupport
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||||
|
|
||||||
@ -31,17 +34,46 @@ BACKENDS_TO_TEST = [
|
|||||||
_Backend.CUTLASS_MLA,
|
_Backend.CUTLASS_MLA,
|
||||||
_Backend.FLASHMLA,
|
_Backend.FLASHMLA,
|
||||||
_Backend.FLASH_ATTN_MLA,
|
_Backend.FLASH_ATTN_MLA,
|
||||||
|
_Backend.FLASHINFER_MLA,
|
||||||
_Backend.TRITON_MLA,
|
_Backend.TRITON_MLA,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Remove CUTLASS_MLA from the list if not using sm100
|
# Remove sm100 backends from the list if not using sm100
|
||||||
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
|
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
|
||||||
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
|
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
|
||||||
|
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)
|
||||||
|
|
||||||
|
# Remove FLASH_ATTN_MLA from the list if not supported
|
||||||
|
if not flash_attn_supports_mla():
|
||||||
|
BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)
|
||||||
|
|
||||||
# Remove FLASHMLA from the list if not supported
|
# Remove FLASHMLA from the list if not supported
|
||||||
if not is_flashmla_dense_supported()[0]:
|
if not is_flashmla_dense_supported()[0]:
|
||||||
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
|
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
|
||||||
|
|
||||||
|
SPEC_DECODE_BACKENDS = []
|
||||||
|
for backend in BACKENDS_TO_TEST:
|
||||||
|
builder_cls, _ = try_get_attention_backend(backend)
|
||||||
|
query_len_support = getattr(
|
||||||
|
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
|
||||||
|
)
|
||||||
|
if query_len_support != QueryLenSupport.SINGLE_ONLY:
|
||||||
|
SPEC_DECODE_BACKENDS.append(backend)
|
||||||
|
|
||||||
|
BACKEND_BLOCK_SIZES = {}
|
||||||
|
for backend in BACKENDS_TO_TEST:
|
||||||
|
backend_class_str = backend_to_class_str(backend)
|
||||||
|
backend_class = resolve_obj_by_qualname(backend_class_str)
|
||||||
|
supported_sizes = backend_class.get_supported_kernel_block_size()
|
||||||
|
if supported_sizes:
|
||||||
|
default_size = supported_sizes[0]
|
||||||
|
block_size = (
|
||||||
|
default_size if isinstance(default_size, int) else default_size.base
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
block_size = 16
|
||||||
|
BACKEND_BLOCK_SIZES[backend] = block_size
|
||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
|
||||||
@ -236,6 +268,26 @@ class MockAttentionLayer:
|
|||||||
self._q_scale = torch.tensor(1.0, device=device)
|
self._q_scale = torch.tensor(1.0, device=device)
|
||||||
self._k_scale = torch.tensor(1.0, device=device)
|
self._k_scale = torch.tensor(1.0, device=device)
|
||||||
self._v_scale = torch.tensor(1.0, device=device)
|
self._v_scale = torch.tensor(1.0, device=device)
|
||||||
|
self._prob_scale = torch.tensor(1.0, device=device)
|
||||||
|
self._q_scale_float = 1.0
|
||||||
|
self._k_scale_float = 1.0
|
||||||
|
self._v_scale_float = 1.0
|
||||||
|
|
||||||
|
def forward(self, *_args, **_kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class MockMLAAttentionLayer(AttentionLayerBase):
|
||||||
|
"""A mock MLA attention layer for populating static_forward_context."""
|
||||||
|
|
||||||
|
def __init__(self, impl):
|
||||||
|
self.impl = impl
|
||||||
|
|
||||||
|
def get_attn_backend(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_kv_cache_spec(self, vllm_config):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def run_attention_backend(
|
def run_attention_backend(
|
||||||
@ -262,13 +314,6 @@ def run_attention_backend(
|
|||||||
# Set the current vllm config so that get_current_vllm_config() works
|
# Set the current vllm config so that get_current_vllm_config() works
|
||||||
# in the backend implementations
|
# in the backend implementations
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
# Build metadata
|
|
||||||
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
|
|
||||||
attn_metadata = builder.build(
|
|
||||||
common_prefix_len=0,
|
|
||||||
common_attn_metadata=common_attn_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Instantiate MLA implementation
|
# Instantiate MLA implementation
|
||||||
num_heads = vllm_config.model_config.get_num_attention_heads(
|
num_heads = vllm_config.model_config.get_num_attention_heads(
|
||||||
vllm_config.parallel_config
|
vllm_config.parallel_config
|
||||||
@ -302,6 +347,19 @@ def run_attention_backend(
|
|||||||
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
||||||
impl.process_weights_after_loading(act_dtype)
|
impl.process_weights_after_loading(act_dtype)
|
||||||
|
|
||||||
|
# Populate static_forward_context with mock attention layers
|
||||||
|
for layer_name in layer_names:
|
||||||
|
vllm_config.compilation_config.static_forward_context[layer_name] = (
|
||||||
|
MockMLAAttentionLayer(impl)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build metadata
|
||||||
|
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
|
attn_metadata = builder.build(
|
||||||
|
common_prefix_len=0,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
# Create mock layer and output buffer
|
# Create mock layer and output buffer
|
||||||
mock_layer = MockAttentionLayer(device)
|
mock_layer = MockAttentionLayer(device)
|
||||||
num_tokens = query.shape[0]
|
num_tokens = query.shape[0]
|
||||||
@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
|||||||
simulated paged KV cache.
|
simulated paged KV cache.
|
||||||
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
|
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
|
||||||
"""
|
"""
|
||||||
from vllm.v1.attention.backends.mla.common import QueryLenSupport
|
|
||||||
|
|
||||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||||
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
|
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
|
||||||
spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
|
unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
|
||||||
|
default_block_size = unique_block_sizes[0]
|
||||||
block_size = 16
|
|
||||||
required_blocks = sum(
|
required_blocks = sum(
|
||||||
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
|
(seq_len + default_block_size - 1) // default_block_size
|
||||||
|
for seq_len in batch_spec.seq_lens
|
||||||
)
|
)
|
||||||
# Add 1 for null block at index 0, and some buffer
|
# Add 1 for null block at index 0, and some buffer
|
||||||
num_gpu_blocks = required_blocks + 1 + 100
|
num_gpu_blocks = required_blocks + 1 + 100
|
||||||
@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
|||||||
model_name=model,
|
model_name=model,
|
||||||
max_model_len=max(batch_spec.seq_lens),
|
max_model_len=max(batch_spec.seq_lens),
|
||||||
num_gpu_blocks=num_gpu_blocks,
|
num_gpu_blocks=num_gpu_blocks,
|
||||||
block_size=block_size,
|
block_size=default_block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
|
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
|
||||||
@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
|||||||
|
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
|
||||||
|
|
||||||
# 1. Setup
|
# 1. Setup
|
||||||
batch_size = batch_spec.batch_size
|
batch_size = batch_spec.batch_size
|
||||||
seq_lens = batch_spec.seq_lens
|
seq_lens = batch_spec.seq_lens
|
||||||
@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
|||||||
)
|
)
|
||||||
head_size = vllm_config.model_config.get_head_size()
|
head_size = vllm_config.model_config.get_head_size()
|
||||||
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
||||||
block_size = vllm_config.cache_config.block_size
|
|
||||||
kv_lora_rank = 512
|
kv_lora_rank = 512
|
||||||
qk_rope_head_dim = 64
|
qk_rope_head_dim = 64
|
||||||
qk_nope_head_dim = 128
|
qk_nope_head_dim = 128
|
||||||
@ -598,33 +652,83 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
|||||||
)
|
)
|
||||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)
|
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)
|
||||||
|
|
||||||
# Create metadata using original batch spec
|
# 3. Create metadata and KV caches for each block size
|
||||||
common_attn_metadata = create_common_attn_metadata(
|
# Group backends by block size and test each group
|
||||||
batch_spec, vllm_config.cache_config.block_size, device
|
metadata_per_block_size = {}
|
||||||
)
|
kv_cache_per_block_size = {}
|
||||||
|
|
||||||
# 3. Simulate Paged KV Cache and a realistic slot_mapping
|
for block_size in unique_block_sizes:
|
||||||
kv_cache = create_and_prepopulate_kv_cache(
|
# Create metadata for this block size
|
||||||
kv_c_contexts=kv_c_contexts,
|
common_attn_metadata = create_common_attn_metadata(
|
||||||
k_pe_contexts=k_pe_contexts,
|
batch_spec, block_size, device
|
||||||
block_size=block_size,
|
)
|
||||||
head_size=head_size,
|
|
||||||
dtype=dtype,
|
# Pad block table to meet requirement:
|
||||||
device=device,
|
# block_num % (128 / block_size) == 0
|
||||||
num_blocks=vllm_config.cache_config.num_gpu_blocks,
|
required_divisor = int(128 / block_size)
|
||||||
common_attn_metadata=common_attn_metadata,
|
current_block_num = common_attn_metadata.block_table_tensor.shape[1]
|
||||||
randomize_blocks=True,
|
if current_block_num % required_divisor != 0:
|
||||||
)
|
# Pad to next multiple of required_divisor
|
||||||
|
padded_block_num = (
|
||||||
|
(current_block_num + required_divisor - 1) // required_divisor
|
||||||
|
) * required_divisor
|
||||||
|
padding_cols = padded_block_num - current_block_num
|
||||||
|
padding = torch.zeros(
|
||||||
|
(common_attn_metadata.block_table_tensor.shape[0], padding_cols),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
common_attn_metadata.block_table_tensor = torch.cat(
|
||||||
|
[common_attn_metadata.block_table_tensor, padding], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata_per_block_size[block_size] = common_attn_metadata
|
||||||
|
|
||||||
|
# Create KV cache for this block size
|
||||||
|
required_blocks_for_size = sum(
|
||||||
|
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
|
||||||
|
)
|
||||||
|
num_blocks_for_size = required_blocks_for_size + 1 + 100
|
||||||
|
|
||||||
|
kv_cache = create_and_prepopulate_kv_cache(
|
||||||
|
kv_c_contexts=kv_c_contexts,
|
||||||
|
k_pe_contexts=k_pe_contexts,
|
||||||
|
block_size=block_size,
|
||||||
|
head_size=head_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
num_blocks=num_blocks_for_size,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
randomize_blocks=True,
|
||||||
|
)
|
||||||
|
kv_cache_per_block_size[block_size] = kv_cache
|
||||||
|
|
||||||
# 4. Run vLLM backends and compare
|
# 4. Run vLLM backends and compare
|
||||||
|
failures = []
|
||||||
for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
|
for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
|
||||||
# Skip backends that don't support spec decode for spec decode tests
|
# Skip backends that don't support spec decode for spec decode tests
|
||||||
if is_spec_decode_test and backend_name not in spec_decode_backends:
|
if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Get the appropriate block_size, metadata, and cache for this backend
|
||||||
|
block_size = BACKEND_BLOCK_SIZES[backend_name]
|
||||||
|
common_attn_metadata = metadata_per_block_size[block_size]
|
||||||
|
kv_cache = kv_cache_per_block_size[block_size]
|
||||||
|
|
||||||
|
# Create kv_cache_spec with the correct block_size for this backend
|
||||||
|
backend_kv_cache_spec = FullAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
|
||||||
|
vllm_config.parallel_config
|
||||||
|
),
|
||||||
|
head_size=vllm_config.model_config.get_head_size(),
|
||||||
|
dtype=vllm_config.model_config.dtype,
|
||||||
|
sliding_window=vllm_config.model_config.get_sliding_window(),
|
||||||
|
)
|
||||||
|
|
||||||
backend_output = run_attention_backend(
|
backend_output = run_attention_backend(
|
||||||
backend_name,
|
backend_name,
|
||||||
kv_cache_spec,
|
backend_kv_cache_spec,
|
||||||
["placeholder"],
|
["placeholder"],
|
||||||
vllm_config,
|
vllm_config,
|
||||||
device,
|
device,
|
||||||
@ -644,32 +748,48 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
|||||||
expected_output = sdpa_outputs[backend_name]
|
expected_output = sdpa_outputs[backend_name]
|
||||||
|
|
||||||
# Check shape and dtype consistency
|
# Check shape and dtype consistency
|
||||||
assert backend_output.shape == expected_output.shape, (
|
try:
|
||||||
f"[{backend_name}] shape {backend_output.shape} != "
|
assert backend_output.shape == expected_output.shape, (
|
||||||
f"SDPA shape {expected_output.shape}"
|
f"[{backend_name}] shape {backend_output.shape} != "
|
||||||
)
|
f"SDPA shape {expected_output.shape}"
|
||||||
assert backend_output.dtype == expected_output.dtype, (
|
)
|
||||||
f"[{backend_name}] dtype {backend_output.dtype} != "
|
assert backend_output.dtype == expected_output.dtype, (
|
||||||
f"SDPA dtype {expected_output.dtype}"
|
f"[{backend_name}] dtype {backend_output.dtype} != "
|
||||||
)
|
f"SDPA dtype {expected_output.dtype}"
|
||||||
|
)
|
||||||
|
|
||||||
assert torch.isfinite(backend_output).all(), (
|
assert torch.isfinite(backend_output).all(), (
|
||||||
f"[{backend_name}] produced non-finite values"
|
f"[{backend_name}] produced non-finite values"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check numerical similarity
|
# Check numerical similarity
|
||||||
rtol = 1e-2
|
rtol = 1e-2
|
||||||
atol = 5e-1
|
atol = 5e-1
|
||||||
|
|
||||||
max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
|
max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
|
||||||
max_rel_diff = torch.max(
|
max_rel_diff = torch.max(
|
||||||
torch.abs(backend_output - expected_output) / torch.abs(expected_output)
|
torch.abs(backend_output - expected_output) / torch.abs(expected_output)
|
||||||
).item()
|
).item()
|
||||||
all_close = torch.allclose(
|
all_close = torch.allclose(
|
||||||
backend_output, expected_output, rtol=rtol, atol=atol
|
backend_output, expected_output, rtol=rtol, atol=atol
|
||||||
)
|
)
|
||||||
|
|
||||||
assert all_close, (
|
assert all_close, (
|
||||||
f"[{backend_name}] output differs from SDPA baseline. "
|
f"[{backend_name}] output differs from SDPA baseline. "
|
||||||
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
|
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
|
||||||
)
|
)
|
||||||
|
except AssertionError as e:
|
||||||
|
failures.append(str(e))
|
||||||
|
|
||||||
|
# Report all failures at once
|
||||||
|
if failures:
|
||||||
|
# Create a summary for the single-line failure message
|
||||||
|
backend_names = []
|
||||||
|
for f in failures:
|
||||||
|
if "[_Backend." in f:
|
||||||
|
backend_name = f.split("[")[1].split("]")[0]
|
||||||
|
backend_names.append(backend_name)
|
||||||
|
|
||||||
|
summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
|
||||||
|
detailed_msg = "\n".join(failures)
|
||||||
|
pytest.fail(f"{summary}\n{detailed_msg}")
|
||||||
|
|||||||
@ -285,7 +285,17 @@ full_cg_backend_configs = {
|
|||||||
name="CutlassMLA",
|
name="CutlassMLA",
|
||||||
env_vars={
|
env_vars={
|
||||||
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
|
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
|
||||||
"FORCE_NUM_KV_SPLITS": "1", # TODO: remove this when hang issue is fixed
|
},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||||
|
},
|
||||||
|
specific_gpu_arch=(10, 0),
|
||||||
|
),
|
||||||
|
# FlashInfer MLA on Blackwell
|
||||||
|
"FlashInferMLA": BackendConfig(
|
||||||
|
name="FlashInferMLA",
|
||||||
|
env_vars={
|
||||||
|
"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
|
||||||
},
|
},
|
||||||
comp_config={
|
comp_config={
|
||||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||||
|
|||||||
@ -337,8 +337,6 @@ def test_stop_via_update_from_output():
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_output = ModelRunnerOutput(
|
model_output = ModelRunnerOutput(
|
||||||
@ -385,8 +383,6 @@ def test_stop_via_update_from_output():
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_output = ModelRunnerOutput(
|
model_output = ModelRunnerOutput(
|
||||||
@ -431,8 +427,6 @@ def test_stop_via_update_from_output():
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_output = ModelRunnerOutput(
|
model_output = ModelRunnerOutput(
|
||||||
@ -472,8 +466,6 @@ def test_stop_via_update_from_output():
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_output = ModelRunnerOutput(
|
model_output = ModelRunnerOutput(
|
||||||
@ -1988,7 +1980,6 @@ def test_schedule_skip_tokenizer_init():
|
|||||||
scheduler.add_request(request)
|
scheduler.add_request(request)
|
||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
assert len(output.scheduled_new_reqs) == len(requests)
|
assert len(output.scheduled_new_reqs) == len(requests)
|
||||||
assert output.grammar_bitmask is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_schedule_skip_tokenizer_init_structured_output_request():
|
def test_schedule_skip_tokenizer_init_structured_output_request():
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import torch._dynamo.config as dynamo_config
|
|||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.logprobs import Logprob
|
from vllm.logprobs import Logprob
|
||||||
|
from vllm.sampling_params import StructuredOutputsParams
|
||||||
|
|
||||||
from ...conftest import VllmRunner
|
from ...conftest import VllmRunner
|
||||||
from ...models.utils import check_outputs_equal
|
from ...models.utils import check_outputs_equal
|
||||||
@ -15,9 +16,12 @@ MODEL = "Qwen/Qwen3-0.6B"
|
|||||||
|
|
||||||
|
|
||||||
@dynamo_config.patch(cache_size_limit=16)
|
@dynamo_config.patch(cache_size_limit=16)
|
||||||
def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
|
def test_preempt_and_async_scheduling_e2e(
|
||||||
|
sample_json_schema, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
"""Test consistency of combos of async scheduling, preemption,
|
"""Test consistency of combos of async scheduling, preemption,
|
||||||
uni/multiproc executor, and various sampling parameters."""
|
uni/multiproc executor, and various sampling parameters
|
||||||
|
including structured outputs."""
|
||||||
|
|
||||||
first_prompt = (
|
first_prompt = (
|
||||||
"The following numbers of the sequence "
|
"The following numbers of the sequence "
|
||||||
@ -35,6 +39,12 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
|
|||||||
dict(bad_words=["the", " the"]),
|
dict(bad_words=["the", " the"]),
|
||||||
dict(logprobs=2),
|
dict(logprobs=2),
|
||||||
dict(logprobs=2, presence_penalty=-1.0),
|
dict(logprobs=2, presence_penalty=-1.0),
|
||||||
|
dict(structured_outputs=StructuredOutputsParams(json=sample_json_schema)),
|
||||||
|
dict(
|
||||||
|
structured_outputs=StructuredOutputsParams(json=sample_json_schema),
|
||||||
|
logprobs=2,
|
||||||
|
presence_penalty=-1.0,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
default_params = dict(
|
default_params = dict(
|
||||||
@ -248,7 +248,7 @@ def test_engine_core_concurrent_batches():
|
|||||||
self,
|
self,
|
||||||
scheduler_output,
|
scheduler_output,
|
||||||
non_block=False,
|
non_block=False,
|
||||||
) -> Future[ModelRunnerOutput]:
|
) -> Future[ModelRunnerOutput | None]:
|
||||||
"""Make execute_model non-blocking."""
|
"""Make execute_model non-blocking."""
|
||||||
|
|
||||||
# DummyExecutor used only for testing async case.
|
# DummyExecutor used only for testing async case.
|
||||||
@ -263,6 +263,23 @@ def test_engine_core_concurrent_batches():
|
|||||||
# Use the thread pool instead of creating a new thread
|
# Use the thread pool instead of creating a new thread
|
||||||
return self.thread_pool.submit(_execute)
|
return self.thread_pool.submit(_execute)
|
||||||
|
|
||||||
|
def sample_tokens(
|
||||||
|
self, grammar_output, non_block=False
|
||||||
|
) -> Future[ModelRunnerOutput]:
|
||||||
|
"""Make sample_tokens non-blocking."""
|
||||||
|
|
||||||
|
# DummyExecutor used only for testing async case.
|
||||||
|
assert non_block
|
||||||
|
|
||||||
|
def _execute():
|
||||||
|
output = self.collective_rpc("sample_tokens", args=(grammar_output,))
|
||||||
|
# Make a copy because output[0] may be reused
|
||||||
|
# by the next batch.
|
||||||
|
return copy.deepcopy(output[0])
|
||||||
|
|
||||||
|
# Use the thread pool instead of creating a new thread
|
||||||
|
return self.thread_pool.submit(_execute)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_concurrent_batches(self) -> int:
|
def max_concurrent_batches(self) -> int:
|
||||||
return 2
|
return 2
|
||||||
|
|||||||
@ -31,7 +31,9 @@ class CustomMultiprocExecutor(MultiprocExecutor):
|
|||||||
# Drop marker to show that this was run
|
# Drop marker to show that this was run
|
||||||
with open(".marker", "w"):
|
with open(".marker", "w"):
|
||||||
...
|
...
|
||||||
return super().collective_rpc(method, timeout, args, kwargs)
|
return super().collective_rpc(
|
||||||
|
method, timeout, args, kwargs, non_block, unique_reply_rank
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
CustomMultiprocExecutorAsync = CustomMultiprocExecutor
|
CustomMultiprocExecutorAsync = CustomMultiprocExecutor
|
||||||
|
|||||||
65
tests/v1/kv_connector/unit/test_config.py
Normal file
65
tests/v1/kv_connector/unit/test_config.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""Tests for KV cache offloading configuration."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.cpu_test
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
|
||||||
|
[
|
||||||
|
("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
|
||||||
|
# bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
|
||||||
|
("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30) / 4),
|
||||||
|
("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
|
||||||
|
# size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
|
||||||
|
("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
|
||||||
|
(None, None, 1, 1, None, None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_kv_connector(
|
||||||
|
kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
|
||||||
|
):
|
||||||
|
kv_transfer_config = (
|
||||||
|
KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
|
||||||
|
if expected_backend is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
vllm_config = VllmConfig(
|
||||||
|
cache_config=CacheConfig(
|
||||||
|
kv_offloading_backend=kv_offloading_backend,
|
||||||
|
kv_offloading_size=kv_offloading_size,
|
||||||
|
),
|
||||||
|
kv_transfer_config=kv_transfer_config,
|
||||||
|
parallel_config=ParallelConfig(
|
||||||
|
tensor_parallel_size=tp, pipeline_parallel_size=pp
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# No KV transfer config expected
|
||||||
|
if expected_backend is None:
|
||||||
|
assert vllm_config.kv_transfer_config is expected_backend
|
||||||
|
return
|
||||||
|
|
||||||
|
kv_transfer_config = vllm_config.kv_transfer_config
|
||||||
|
kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config
|
||||||
|
|
||||||
|
assert kv_transfer_config.kv_connector == expected_backend
|
||||||
|
assert kv_transfer_config.kv_role == "kv_both"
|
||||||
|
|
||||||
|
if kv_offloading_backend == "native":
|
||||||
|
assert kv_connector_extra_config["kv_bytes_per_rank"] == expected_bytes
|
||||||
|
assert kv_connector_extra_config["num_cpu_blocks"] == 0
|
||||||
|
# Existing config should be preserved
|
||||||
|
assert kv_connector_extra_config["existing_key"] == "existing_value"
|
||||||
|
elif kv_offloading_backend == "lmcache":
|
||||||
|
assert kv_connector_extra_config["lmcache.local_cpu"] is True
|
||||||
|
assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes
|
||||||
|
# Existing config should be replaced
|
||||||
|
assert "existing_key" not in kv_connector_extra_config
|
||||||
@ -26,8 +26,6 @@ def _make_empty_scheduler_output():
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
kv_connector_metadata=SharedStorageConnectorMetadata(),
|
kv_connector_metadata=SharedStorageConnectorMetadata(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
|||||||
NixlAgentMetadata,
|
NixlAgentMetadata,
|
||||||
NixlConnector,
|
NixlConnector,
|
||||||
NixlConnectorMetadata,
|
NixlConnectorMetadata,
|
||||||
|
NixlConnectorScheduler,
|
||||||
NixlConnectorWorker,
|
NixlConnectorWorker,
|
||||||
NixlKVConnectorStats,
|
NixlKVConnectorStats,
|
||||||
)
|
)
|
||||||
@ -283,6 +284,92 @@ def test_prompt_less_than_block_size():
|
|||||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||||
|
FakeNixlWrapper,
|
||||||
|
)
|
||||||
|
def test_kv_transfer_handshake(dist_init):
|
||||||
|
"""Unit test for basic NixlConnector interface functionality."""
|
||||||
|
|
||||||
|
# Test setup, we creates a scheduler that contains a NixlConnector
|
||||||
|
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
|
||||||
|
# all workers of the instance.
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
# in case the test runs on non-GPU machine
|
||||||
|
vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
|
||||||
|
scheduler = create_scheduler(vllm_config)
|
||||||
|
|
||||||
|
# Create two NixlConnector of role WORKER, one is the worker of
|
||||||
|
# the scheduler (prefill), the other is a worker of decode instance.
|
||||||
|
|
||||||
|
# Prefill connector will register KV cache to populate proper handshake
|
||||||
|
# metadata.
|
||||||
|
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||||
|
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
|
||||||
|
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
|
||||||
|
)
|
||||||
|
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||||
|
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||||
|
kv_caches = {
|
||||||
|
"layer0": shared_tensor,
|
||||||
|
"layer1": unique_tensor,
|
||||||
|
"layer2": shared_tensor,
|
||||||
|
}
|
||||||
|
prefill_connector.register_kv_caches(kv_caches)
|
||||||
|
|
||||||
|
# Simulate EngineCore initialization that would
|
||||||
|
# gather connector metadata from all workers, the scheduler connector
|
||||||
|
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
|
||||||
|
# where the first key is the dp_rank, the second key is the tp_rank.
|
||||||
|
metadata = {0: prefill_connector.get_handshake_metadata()}
|
||||||
|
scheduler_connector = scheduler.get_kv_connector()
|
||||||
|
scheduler_connector.set_xfer_handshake_metadata(metadata)
|
||||||
|
|
||||||
|
# Simulate a request that finishes prefill, which returns
|
||||||
|
# corresponding NixlConnectorMetadata for decode instance.
|
||||||
|
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||||
|
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||||
|
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||||
|
|
||||||
|
request = create_request(
|
||||||
|
request_id=1,
|
||||||
|
block_size=BLOCK_SIZE,
|
||||||
|
num_tokens=NUM_TOKENS,
|
||||||
|
do_remote_decode=True,
|
||||||
|
)
|
||||||
|
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||||
|
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
|
||||||
|
request, [0, 1, 2]
|
||||||
|
)
|
||||||
|
assert delay
|
||||||
|
|
||||||
|
# Decode connector will be able to create handshake with the prefill connector.
|
||||||
|
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||||
|
|
||||||
|
# Here we are testing the retrieval of NIXLAgentMetadata.
|
||||||
|
# Knowing the implementation detail, we override the add_remote_agent
|
||||||
|
# to validate the metadata received is the same as the one in prefill_connector.
|
||||||
|
with patch.object(
|
||||||
|
decode_connector.connector_worker, "add_remote_agent"
|
||||||
|
) as mock_add_remote_agent:
|
||||||
|
mock_add_remote_agent.return_type = "remote_agent"
|
||||||
|
|
||||||
|
decode_connector.connector_worker._nixl_handshake(
|
||||||
|
kv_connector_metadata["remote_host"],
|
||||||
|
kv_connector_metadata["remote_port"],
|
||||||
|
kv_connector_metadata["tp_size"],
|
||||||
|
kv_connector_metadata["remote_engine_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
received_metadata = mock_add_remote_agent.call_args.args
|
||||||
|
assert received_metadata[1] == 0 # remote_tp_rank
|
||||||
|
assert received_metadata[2] == 1 # remote_tp_size
|
||||||
|
assert metadata[0] == received_metadata[0]
|
||||||
|
|
||||||
|
# Need to shutdown the background thread to release NIXL side channel port
|
||||||
|
scheduler_connector.shutdown()
|
||||||
|
|
||||||
|
|
||||||
class FakeNixlConnectorWorker(NixlConnectorWorker):
|
class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||||
REMOTE_ENGINE_ID = "remote_engine"
|
REMOTE_ENGINE_ID = "remote_engine"
|
||||||
|
|
||||||
@ -313,6 +400,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
|||||||
engine_id=self.REMOTE_ENGINE_ID,
|
engine_id=self.REMOTE_ENGINE_ID,
|
||||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||||
kv_caches_base_addr=[0],
|
kv_caches_base_addr=[0],
|
||||||
|
device_id=0,
|
||||||
num_blocks=1,
|
num_blocks=1,
|
||||||
block_lens=self.block_len_per_layer,
|
block_lens=self.block_len_per_layer,
|
||||||
attn_backend_name=self.backend_name,
|
attn_backend_name=self.backend_name,
|
||||||
@ -559,6 +647,7 @@ class TestNixlHandshake:
|
|||||||
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||||
kv_caches_base_addr=[0],
|
kv_caches_base_addr=[0],
|
||||||
|
device_id=0,
|
||||||
num_blocks=1,
|
num_blocks=1,
|
||||||
block_lens=worker.block_len_per_layer,
|
block_lens=worker.block_len_per_layer,
|
||||||
attn_backend_name=worker.backend_name,
|
attn_backend_name=worker.backend_name,
|
||||||
@ -611,6 +700,7 @@ class TestNixlHandshake:
|
|||||||
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||||
kv_caches_base_addr=[0],
|
kv_caches_base_addr=[0],
|
||||||
|
device_id=0,
|
||||||
num_blocks=1,
|
num_blocks=1,
|
||||||
# prefill TP=1, decode TP=2, remote block_lens is double to local
|
# prefill TP=1, decode TP=2, remote block_lens is double to local
|
||||||
block_lens=[i * 2 for i in worker.block_len_per_layer],
|
block_lens=[i * 2 for i in worker.block_len_per_layer],
|
||||||
@ -891,9 +981,7 @@ def test_scheduler_kv_connector_stats_aggregation():
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=[0],
|
num_common_prefix_blocks=[0],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=set(),
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output)
|
engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output)
|
||||||
@ -1005,6 +1093,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
|
|||||||
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
|
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
|
||||||
# Request-0 times out and is cleared!
|
# Request-0 times out and is cleared!
|
||||||
assert "0" not in req_to_blocks
|
assert "0" not in req_to_blocks
|
||||||
|
# Need to shutdown the background thread to release NIXL side channel port
|
||||||
|
llm.llm_engine.engine_core.shutdown()
|
||||||
|
|
||||||
|
|
||||||
def test_register_kv_caches(dist_init):
|
def test_register_kv_caches(dist_init):
|
||||||
@ -1177,13 +1267,15 @@ def test_shutdown_cleans_up_resources(dist_init):
|
|||||||
"""Test that shutdown() properly cleans up all resources."""
|
"""Test that shutdown() properly cleans up all resources."""
|
||||||
vllm_config = create_vllm_config()
|
vllm_config = create_vllm_config()
|
||||||
|
|
||||||
|
scheduler = NixlConnectorScheduler(
|
||||||
|
vllm_config, vllm_config.kv_transfer_config.engine_id
|
||||||
|
)
|
||||||
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
|
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
|
||||||
nixl_wrapper = worker.nixl_wrapper
|
nixl_wrapper = worker.nixl_wrapper
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
|
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
|
||||||
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
|
patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
|
||||||
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
|
|
||||||
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
|
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
|
||||||
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
|
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
|
||||||
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
|
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
|
||||||
@ -1204,8 +1296,12 @@ def test_shutdown_cleans_up_resources(dist_init):
|
|||||||
worker.shutdown()
|
worker.shutdown()
|
||||||
|
|
||||||
mock_exec.shutdown.assert_called_with(wait=False)
|
mock_exec.shutdown.assert_called_with(wait=False)
|
||||||
mock_event.set.assert_called_once()
|
|
||||||
mock_listener.join.assert_called_once_with(timeout=1.0)
|
# Same sequence on scheduler.shutdown()
|
||||||
|
scheduler.shutdown()
|
||||||
|
scheduler.shutdown()
|
||||||
|
scheduler.shutdown()
|
||||||
|
mock_listener.join.assert_called_once()
|
||||||
|
|
||||||
mock_rel_xfer.assert_called_once_with(123)
|
mock_rel_xfer.assert_called_once_with(123)
|
||||||
assert mock_rel_dlist.call_count == 2
|
assert mock_rel_dlist.call_count == 2
|
||||||
|
|||||||
@ -92,8 +92,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -171,8 +169,6 @@ def test_update_states_request_finished(model_runner):
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids={req_id},
|
finished_req_ids={req_id},
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
model_runner._update_states(scheduler_output)
|
||||||
@ -201,8 +197,6 @@ def test_update_states_request_resumed(model_runner):
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
model_runner._update_states(scheduler_output)
|
||||||
@ -230,8 +224,6 @@ def test_update_states_request_resumed(model_runner):
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
model_runner._update_states(scheduler_output)
|
||||||
@ -261,8 +253,6 @@ def test_update_states_no_changes(model_runner):
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
model_runner._update_states(scheduler_output)
|
||||||
@ -296,8 +286,6 @@ def test_update_states_request_unscheduled(model_runner):
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
model_runner._update_states(scheduler_output)
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention import Attention
|
from vllm.attention import Attention
|
||||||
|
from vllm.attention.backends.abstract import MultipleOf
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
CacheConfig,
|
CacheConfig,
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
@ -34,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
|
|||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||||
|
from vllm.v1.worker.utils import AttentionGroup
|
||||||
|
|
||||||
BLOCK_SIZE = 16
|
BLOCK_SIZE = 16
|
||||||
NUM_BLOCKS = 10
|
NUM_BLOCKS = 10
|
||||||
@ -150,8 +152,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -181,6 +181,57 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
|||||||
).all()
|
).all()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_backend_for_kernel_block_size(
|
||||||
|
supported_sizes: list[int | MultipleOf],
|
||||||
|
):
|
||||||
|
class _MockBackend:
|
||||||
|
@staticmethod
|
||||||
|
def get_supported_kernel_block_size():
|
||||||
|
return supported_sizes
|
||||||
|
|
||||||
|
return _MockBackend()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_kv_cache_spec() -> FullAttentionSpec:
|
||||||
|
return FullAttentionSpec(block_size=1, num_kv_heads=1, head_size=1, dtype="float16")
|
||||||
|
|
||||||
|
|
||||||
|
def test_select_common_block_size_prefers_manager_block_size():
|
||||||
|
backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
|
||||||
|
backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
|
||||||
|
attn_groups = [
|
||||||
|
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
|
||||||
|
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
|
||||||
|
]
|
||||||
|
|
||||||
|
selected_size = GPUModelRunner.select_common_block_size(128, attn_groups)
|
||||||
|
assert selected_size == 128
|
||||||
|
|
||||||
|
|
||||||
|
def test_select_common_block_size_uses_largest_shared_int():
|
||||||
|
backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
|
||||||
|
backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
|
||||||
|
attn_groups = [
|
||||||
|
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
|
||||||
|
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
|
||||||
|
]
|
||||||
|
|
||||||
|
selected_size = GPUModelRunner.select_common_block_size(256, attn_groups)
|
||||||
|
assert selected_size == 64
|
||||||
|
|
||||||
|
|
||||||
|
def test_select_common_block_size_no_valid_option():
|
||||||
|
backend_a = _make_mock_backend_for_kernel_block_size([64])
|
||||||
|
backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
|
||||||
|
attn_groups = [
|
||||||
|
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
|
||||||
|
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
GPUModelRunner.select_common_block_size(48, attn_groups)
|
||||||
|
|
||||||
|
|
||||||
def test_update_states_new_request(model_runner, dist_init):
|
def test_update_states_new_request(model_runner, dist_init):
|
||||||
req_id = "req_0"
|
req_id = "req_0"
|
||||||
|
|
||||||
@ -216,8 +267,6 @@ def test_update_states_request_finished(model_runner, dist_init):
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids={req_id},
|
finished_req_ids={req_id},
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_before = model_runner.input_batch.sampling_metadata
|
metadata_before = model_runner.input_batch.sampling_metadata
|
||||||
@ -248,8 +297,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_runner._update_states(scheduler_output)
|
model_runner._update_states(scheduler_output)
|
||||||
@ -277,8 +324,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_before = model_runner.input_batch.sampling_metadata
|
metadata_before = model_runner.input_batch.sampling_metadata
|
||||||
@ -370,8 +415,6 @@ def test_update_states_no_changes(model_runner, dist_init):
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_before = model_runner.input_batch.sampling_metadata
|
metadata_before = model_runner.input_batch.sampling_metadata
|
||||||
@ -407,8 +450,6 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids=[],
|
|
||||||
grammar_bitmask=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_before = model_runner._update_states(scheduler_output)
|
metadata_before = model_runner._update_states(scheduler_output)
|
||||||
|
|||||||
@ -270,21 +270,23 @@ class ipex_ops:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def flash_attn_varlen_func(
|
def flash_attn_varlen_func(
|
||||||
out: torch.Tensor,
|
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
cu_seqlens_q: torch.Tensor,
|
cu_seqlens_q: torch.Tensor,
|
||||||
seqused_k: torch.Tensor, # we don't support this in ipex kernel
|
|
||||||
max_seqlen_q: int,
|
max_seqlen_q: int,
|
||||||
max_seqlen_k: int,
|
max_seqlen_k: int,
|
||||||
softmax_scale: float,
|
softmax_scale: float | None = None,
|
||||||
causal: bool,
|
causal: bool = False,
|
||||||
block_table: torch.Tensor,
|
out: torch.Tensor | None = None,
|
||||||
alibi_slopes: torch.Tensor | None,
|
block_table: torch.Tensor | None = None,
|
||||||
|
alibi_slopes: torch.Tensor | None = None,
|
||||||
window_size: list[int] | None = None,
|
window_size: list[int] | None = None,
|
||||||
softcap: float | None = 0.0,
|
softcap: float | None = 0.0,
|
||||||
|
seqused_k: torch.Tensor | None = None,
|
||||||
cu_seqlens_k: torch.Tensor | None = None,
|
cu_seqlens_k: torch.Tensor | None = None,
|
||||||
|
# passed in qwen vl
|
||||||
|
dropout_p: float = 0.0,
|
||||||
# The following parameters are not used in ipex kernel currently,
|
# The following parameters are not used in ipex kernel currently,
|
||||||
# we keep API compatible to CUDA's.
|
# we keep API compatible to CUDA's.
|
||||||
scheduler_metadata=None,
|
scheduler_metadata=None,
|
||||||
@ -295,31 +297,63 @@ class ipex_ops:
|
|||||||
num_splits=0,
|
num_splits=0,
|
||||||
s_aux: torch.Tensor | None = None,
|
s_aux: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
|
if out is None:
|
||||||
|
out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
|
||||||
real_window_size: tuple[int, int]
|
real_window_size: tuple[int, int]
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
real_window_size = (-1, -1)
|
real_window_size = (-1, -1)
|
||||||
else:
|
else:
|
||||||
assert len(window_size) == 2
|
assert len(window_size) == 2
|
||||||
real_window_size = (window_size[0], window_size[1])
|
real_window_size = (window_size[0], window_size[1])
|
||||||
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
|
||||||
out,
|
if block_table is None:
|
||||||
q.contiguous(),
|
assert cu_seqlens_k is not None, (
|
||||||
k,
|
"cu_seqlens_k can't be None when calling varlen_attention."
|
||||||
v,
|
)
|
||||||
cu_seqlens_q,
|
if softmax_scale is None:
|
||||||
seqused_k,
|
softmax_scale = q.shape[-1] ** (-0.5)
|
||||||
max_seqlen_q,
|
ipex_ops.varlen_attention(
|
||||||
max_seqlen_k,
|
q.contiguous(),
|
||||||
softmax_scale,
|
k.contiguous(),
|
||||||
causal,
|
v.contiguous(),
|
||||||
block_table,
|
out,
|
||||||
alibi_slopes,
|
cu_seqlens_q,
|
||||||
softcap=softcap,
|
cu_seqlens_k,
|
||||||
window_size_left=real_window_size[0],
|
None,
|
||||||
window_size_right=real_window_size[1],
|
max_seqlen_q,
|
||||||
k_scale=1.0,
|
max_seqlen_k,
|
||||||
v_scale=1.0,
|
0.0,
|
||||||
)
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
causal,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
real_window_size[0],
|
||||||
|
real_window_size[1],
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||||
|
out,
|
||||||
|
q.contiguous(),
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_q,
|
||||||
|
seqused_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
softmax_scale,
|
||||||
|
causal,
|
||||||
|
block_table,
|
||||||
|
alibi_slopes,
|
||||||
|
sink=s_aux,
|
||||||
|
softcap=softcap,
|
||||||
|
window_size_left=real_window_size[0],
|
||||||
|
window_size_right=real_window_size[1],
|
||||||
|
k_scale=1.0,
|
||||||
|
v_scale=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_scheduler_metadata(
|
def get_scheduler_metadata(
|
||||||
|
|||||||
@ -123,6 +123,11 @@ def maybe_get_vit_flash_attn_backend(
|
|||||||
):
|
):
|
||||||
attn_backend = _Backend.FLASH_ATTN
|
attn_backend = _Backend.FLASH_ATTN
|
||||||
use_upstream_fa = True
|
use_upstream_fa = True
|
||||||
|
elif current_platform.is_xpu():
|
||||||
|
assert attn_backend == _Backend.FLASH_ATTN, (
|
||||||
|
"XPU platform only supports FLASH_ATTN as vision attention backend."
|
||||||
|
)
|
||||||
|
use_upstream_fa = False
|
||||||
else:
|
else:
|
||||||
return _Backend.TORCH_SDPA, None
|
return _Backend.TORCH_SDPA, None
|
||||||
|
|
||||||
@ -133,7 +138,7 @@ def maybe_get_vit_flash_attn_backend(
|
|||||||
if use_upstream_fa:
|
if use_upstream_fa:
|
||||||
from flash_attn import flash_attn_varlen_func
|
from flash_attn import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
flash_attn_varlen_func = None
|
flash_attn_varlen_func = None
|
||||||
|
|
||||||
@ -521,22 +526,18 @@ class MultiHeadAttention(nn.Module):
|
|||||||
# If vllm native fa is selected, we use it directly.
|
# If vllm native fa is selected, we use it directly.
|
||||||
use_upstream_fa = False
|
use_upstream_fa = False
|
||||||
|
|
||||||
if current_platform.is_xpu():
|
self.attn_backend = (
|
||||||
# currently, only torch_sdpa is supported on xpu
|
backend
|
||||||
self.attn_backend = _Backend.TORCH_SDPA
|
if backend
|
||||||
else:
|
in {
|
||||||
self.attn_backend = (
|
_Backend.TORCH_SDPA,
|
||||||
backend
|
_Backend.XFORMERS,
|
||||||
if backend
|
_Backend.PALLAS,
|
||||||
in {
|
_Backend.ROCM_AITER_FA,
|
||||||
_Backend.TORCH_SDPA,
|
_Backend.FLASH_ATTN,
|
||||||
_Backend.XFORMERS,
|
}
|
||||||
_Backend.PALLAS,
|
else _Backend.TORCH_SDPA
|
||||||
_Backend.ROCM_AITER_FA,
|
)
|
||||||
_Backend.FLASH_ATTN,
|
|
||||||
}
|
|
||||||
else _Backend.TORCH_SDPA
|
|
||||||
)
|
|
||||||
|
|
||||||
self.attn_backend, self._flash_attn_varlen_func = (
|
self.attn_backend, self._flash_attn_varlen_func = (
|
||||||
maybe_get_vit_flash_attn_backend(
|
maybe_get_vit_flash_attn_backend(
|
||||||
|
|||||||
@ -70,7 +70,7 @@ def flash_attn_maxseqlen_wrapper(
|
|||||||
if use_upstream_fa:
|
if use_upstream_fa:
|
||||||
from flash_attn import flash_attn_varlen_func
|
from flash_attn import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||||
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
output = flash_attn_varlen_func(
|
output = flash_attn_varlen_func(
|
||||||
q,
|
q,
|
||||||
|
|||||||
@ -24,6 +24,7 @@ BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
|
|||||||
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
||||||
MambaDType = Literal["auto", "float32"]
|
MambaDType = Literal["auto", "float32"]
|
||||||
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
|
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
|
||||||
|
KVOffloadingBackend = Literal["native", "lmcache"]
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@ -128,6 +129,17 @@ class CacheConfig:
|
|||||||
gpu_memory_utilization. Note that kv_cache_memory_bytes
|
gpu_memory_utilization. Note that kv_cache_memory_bytes
|
||||||
(when not-None) ignores gpu_memory_utilization"""
|
(when not-None) ignores gpu_memory_utilization"""
|
||||||
|
|
||||||
|
kv_offloading_size: float | None = None
|
||||||
|
"""Size of the KV cache offloading buffer in GiB. When TP > 1, this is
|
||||||
|
the total buffer size summed across all TP ranks. By default, this is set
|
||||||
|
to None, which means no KV offloading is enabled. When set with
|
||||||
|
kv_offloading_backend, vLLM will enable KV cache offloading to CPU"""
|
||||||
|
|
||||||
|
kv_offloading_backend: KVOffloadingBackend | None = None
|
||||||
|
"""The backend to use for KV cache offloading. Supported backends include
|
||||||
|
'native' (vLLM native CPU offloading), 'lmcache' This option must be used
|
||||||
|
together with kv_offloading_size."""
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
WARNING: Whenever a new field is added to this config,
|
WARNING: Whenever a new field is added to this config,
|
||||||
|
|||||||
@ -2,10 +2,11 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
from dataclasses import InitVar, field
|
from collections.abc import Callable
|
||||||
|
from dataclasses import InitVar
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import SkipValidation, model_validator
|
from pydantic import Field, field_validator, model_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
@ -31,28 +32,28 @@ class SchedulerConfig:
|
|||||||
runner_type: RunnerType = "generate"
|
runner_type: RunnerType = "generate"
|
||||||
"""The runner type to launch for the model."""
|
"""The runner type to launch for the model."""
|
||||||
|
|
||||||
max_num_batched_tokens: SkipValidation[int] = None # type: ignore
|
max_num_batched_tokens: int = Field(default=None, ge=1)
|
||||||
"""Maximum number of tokens to be processed in a single iteration.
|
"""Maximum number of tokens to be processed in a single iteration.
|
||||||
|
|
||||||
This config has no static default. If left unspecified by the user, it will
|
This config has no static default. If left unspecified by the user, it will
|
||||||
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
||||||
|
|
||||||
max_num_seqs: SkipValidation[int] = None # type: ignore
|
max_num_seqs: int = Field(default=None, ge=1)
|
||||||
"""Maximum number of sequences to be processed in a single iteration.
|
"""Maximum number of sequences to be processed in a single iteration.
|
||||||
|
|
||||||
This config has no static default. If left unspecified by the user, it will
|
This config has no static default. If left unspecified by the user, it will
|
||||||
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
||||||
|
|
||||||
max_model_len: SkipValidation[int] = None # type: ignore
|
max_model_len: int = Field(default=None, ge=1)
|
||||||
"""Maximum length of a sequence (including prompt and generated text). This
|
"""Maximum length of a sequence (including prompt and generated text). This
|
||||||
is primarily set in `ModelConfig` and that value should be manually
|
is primarily set in `ModelConfig` and that value should be manually
|
||||||
duplicated here."""
|
duplicated here."""
|
||||||
|
|
||||||
max_num_partial_prefills: int = 1
|
max_num_partial_prefills: int = Field(default=1, ge=1)
|
||||||
"""For chunked prefill, the maximum number of sequences that can be
|
"""For chunked prefill, the maximum number of sequences that can be
|
||||||
partially prefilled concurrently."""
|
partially prefilled concurrently."""
|
||||||
|
|
||||||
max_long_partial_prefills: int = 1
|
max_long_partial_prefills: int = Field(default=1, ge=1)
|
||||||
"""For chunked prefill, the maximum number of prompts longer than
|
"""For chunked prefill, the maximum number of prompts longer than
|
||||||
long_prefill_token_threshold that will be prefilled concurrently. Setting
|
long_prefill_token_threshold that will be prefilled concurrently. Setting
|
||||||
this less than max_num_partial_prefills will allow shorter prompts to jump
|
this less than max_num_partial_prefills will allow shorter prompts to jump
|
||||||
@ -62,7 +63,7 @@ class SchedulerConfig:
|
|||||||
"""For chunked prefill, a request is considered long if the prompt is
|
"""For chunked prefill, a request is considered long if the prompt is
|
||||||
longer than this number of tokens."""
|
longer than this number of tokens."""
|
||||||
|
|
||||||
num_lookahead_slots: int = 0
|
num_lookahead_slots: int = Field(default=0, ge=0)
|
||||||
"""The number of slots to allocate per sequence per
|
"""The number of slots to allocate per sequence per
|
||||||
step, beyond the known token ids. This is used in speculative
|
step, beyond the known token ids. This is used in speculative
|
||||||
decoding to store KV activations of tokens which may or may not be
|
decoding to store KV activations of tokens which may or may not be
|
||||||
@ -71,7 +72,7 @@ class SchedulerConfig:
|
|||||||
NOTE: This will be replaced by speculative config in the future; it is
|
NOTE: This will be replaced by speculative config in the future; it is
|
||||||
present to enable correctness tests until then."""
|
present to enable correctness tests until then."""
|
||||||
|
|
||||||
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
|
enable_chunked_prefill: bool = Field(default=None)
|
||||||
"""If True, prefill requests can be chunked based
|
"""If True, prefill requests can be chunked based
|
||||||
on the remaining max_num_batched_tokens."""
|
on the remaining max_num_batched_tokens."""
|
||||||
|
|
||||||
@ -86,14 +87,14 @@ class SchedulerConfig:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO (ywang96): Make this configurable.
|
# TODO (ywang96): Make this configurable.
|
||||||
max_num_encoder_input_tokens: int = field(init=False)
|
max_num_encoder_input_tokens: int = Field(init=False)
|
||||||
"""Multimodal encoder compute budget, only used in V1.
|
"""Multimodal encoder compute budget, only used in V1.
|
||||||
|
|
||||||
NOTE: This is not currently configurable. It will be overridden by
|
NOTE: This is not currently configurable. It will be overridden by
|
||||||
max_num_batched_tokens in case max multimodal embedding size is larger."""
|
max_num_batched_tokens in case max multimodal embedding size is larger."""
|
||||||
|
|
||||||
# TODO (ywang96): Make this configurable.
|
# TODO (ywang96): Make this configurable.
|
||||||
encoder_cache_size: int = field(init=False)
|
encoder_cache_size: int = Field(init=False)
|
||||||
"""Multimodal encoder cache size, only used in V1.
|
"""Multimodal encoder cache size, only used in V1.
|
||||||
|
|
||||||
NOTE: This is not currently configurable. It will be overridden by
|
NOTE: This is not currently configurable. It will be overridden by
|
||||||
@ -106,7 +107,7 @@ class SchedulerConfig:
|
|||||||
- "priority" means requests are handled based on given priority (lower
|
- "priority" means requests are handled based on given priority (lower
|
||||||
value means earlier handling) and time of arrival deciding any ties)."""
|
value means earlier handling) and time of arrival deciding any ties)."""
|
||||||
|
|
||||||
chunked_prefill_enabled: bool = field(init=False)
|
chunked_prefill_enabled: bool = Field(init=False)
|
||||||
"""True if chunked prefill is enabled."""
|
"""True if chunked prefill is enabled."""
|
||||||
|
|
||||||
disable_chunked_mm_input: bool = False
|
disable_chunked_mm_input: bool = False
|
||||||
@ -155,6 +156,20 @@ class SchedulerConfig:
|
|||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
|
@field_validator(
|
||||||
|
"max_num_batched_tokens",
|
||||||
|
"max_num_seqs",
|
||||||
|
"max_model_len",
|
||||||
|
"enable_chunked_prefill",
|
||||||
|
mode="wrap",
|
||||||
|
)
|
||||||
|
@classmethod
|
||||||
|
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
|
||||||
|
"""Skip validation if the value is `None` when initialisation is delayed."""
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return handler(value)
|
||||||
|
|
||||||
def __post_init__(self, is_encoder_decoder: bool) -> None:
|
def __post_init__(self, is_encoder_decoder: bool) -> None:
|
||||||
if self.max_model_len is None:
|
if self.max_model_len is None:
|
||||||
self.max_model_len = 8192
|
self.max_model_len = 8192
|
||||||
@ -260,19 +275,7 @@ class SchedulerConfig:
|
|||||||
self.max_num_seqs * self.max_model_len,
|
self.max_num_seqs * self.max_model_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.num_lookahead_slots < 0:
|
if self.max_num_partial_prefills > 1:
|
||||||
raise ValueError(
|
|
||||||
"num_lookahead_slots "
|
|
||||||
f"({self.num_lookahead_slots}) must be greater than or "
|
|
||||||
"equal to 0."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.max_num_partial_prefills < 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
|
|
||||||
"must be greater than or equal to 1."
|
|
||||||
)
|
|
||||||
elif self.max_num_partial_prefills > 1:
|
|
||||||
if not self.chunked_prefill_enabled:
|
if not self.chunked_prefill_enabled:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Chunked prefill must be enabled to set "
|
"Chunked prefill must be enabled to set "
|
||||||
@ -286,13 +289,10 @@ class SchedulerConfig:
|
|||||||
f"than the max_model_len ({self.max_model_len})."
|
f"than the max_model_len ({self.max_model_len})."
|
||||||
)
|
)
|
||||||
|
|
||||||
if (self.max_long_partial_prefills < 1) or (
|
if self.max_long_partial_prefills > self.max_num_partial_prefills:
|
||||||
self.max_long_partial_prefills > self.max_num_partial_prefills
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
|
f"{self.max_long_partial_prefills=} must be less than or equal to "
|
||||||
"must be greater than or equal to 1 and less than or equal to "
|
f"{self.max_num_partial_prefills=}."
|
||||||
f"max_num_partial_prefills ({self.max_num_partial_prefills})."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|||||||
@ -78,10 +78,6 @@ class SpeculativeConfig:
|
|||||||
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
|
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
|
||||||
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
||||||
or the same as the target model's tensor parallel size."""
|
or the same as the target model's tensor parallel size."""
|
||||||
disable_logprobs: bool = True
|
|
||||||
"""If set to True, token log probabilities are not returned during
|
|
||||||
speculative decoding. If set to False, token log probabilities are returned
|
|
||||||
according to the log probability settings in SamplingParams."""
|
|
||||||
|
|
||||||
# Draft model configuration
|
# Draft model configuration
|
||||||
quantization: me_quant.QuantizationMethods | None = None
|
quantization: me_quant.QuantizationMethods | None = None
|
||||||
@ -126,12 +122,6 @@ class SpeculativeConfig:
|
|||||||
"""The configuration of the target model."""
|
"""The configuration of the target model."""
|
||||||
target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
|
target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
|
||||||
"""The parallel configuration for the target model."""
|
"""The parallel configuration for the target model."""
|
||||||
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
|
|
||||||
"""Whether vLLM is configured to use chunked prefill or not. Used for
|
|
||||||
raising an error since it's not yet compatible with speculative decode."""
|
|
||||||
disable_log_stats: SkipValidation[bool] = None # type: ignore
|
|
||||||
"""Whether to disable the periodic printing of stage times in speculative
|
|
||||||
decoding."""
|
|
||||||
|
|
||||||
# params generated in the post-init stage
|
# params generated in the post-init stage
|
||||||
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
||||||
|
|||||||
@ -2,8 +2,9 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, Self
|
||||||
|
|
||||||
|
from pydantic import model_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
@ -56,7 +57,8 @@ class StructuredOutputsConfig:
|
|||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def __post_init__(self):
|
@model_validator(mode="after")
|
||||||
|
def _validate_structured_output_config(self) -> Self:
|
||||||
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
|
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"disable_any_whitespace is only supported for "
|
"disable_any_whitespace is only supported for "
|
||||||
@ -67,3 +69,4 @@ class StructuredOutputsConfig:
|
|||||||
"disable_additional_properties is only supported "
|
"disable_additional_properties is only supported "
|
||||||
"for the guidance backend."
|
"for the guidance backend."
|
||||||
)
|
)
|
||||||
|
return self
|
||||||
|
|||||||
@ -289,6 +289,48 @@ class VllmConfig:
|
|||||||
|
|
||||||
return replace(self, model_config=model_config)
|
return replace(self, model_config=model_config)
|
||||||
|
|
||||||
|
def _post_init_kv_transfer_config(self) -> None:
|
||||||
|
"""Update KVTransferConfig based on top-level configs in VllmConfig.
|
||||||
|
|
||||||
|
Right now, this function reads the offloading settings from
|
||||||
|
CacheConfig and configures the KVTransferConfig accordingly.
|
||||||
|
"""
|
||||||
|
if (kv_offloading_backend := self.cache_config.kv_offloading_backend) is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# If no KVTransferConfig is provided, create a default one.
|
||||||
|
if self.kv_transfer_config is None:
|
||||||
|
self.kv_transfer_config = KVTransferConfig()
|
||||||
|
|
||||||
|
if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
|
||||||
|
raise ValueError(
|
||||||
|
"You must set kv_offloading_size when kv_offloading_backend is set."
|
||||||
|
)
|
||||||
|
num_kv_ranks = (
|
||||||
|
self.parallel_config.tensor_parallel_size
|
||||||
|
* self.parallel_config.pipeline_parallel_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if kv_offloading_backend == "native":
|
||||||
|
self.kv_transfer_config.kv_connector = "OffloadingConnector"
|
||||||
|
kv_bytes_per_rank = kv_offloading_size * (1 << 30) / num_kv_ranks
|
||||||
|
|
||||||
|
# NOTE(ApostaC): the actual calculation for num_cpu_blocks should be
|
||||||
|
# done after the model's KV cache is initialized
|
||||||
|
self.kv_transfer_config.kv_connector_extra_config.update(
|
||||||
|
{"kv_bytes_per_rank": kv_bytes_per_rank, "num_cpu_blocks": 0}
|
||||||
|
)
|
||||||
|
elif kv_offloading_backend == "lmcache":
|
||||||
|
self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
|
||||||
|
kv_gb_per_rank = kv_offloading_size / num_kv_ranks
|
||||||
|
self.kv_transfer_config.kv_connector_extra_config = {
|
||||||
|
"lmcache.local_cpu": True,
|
||||||
|
"lmcache.max_local_cpu_size": kv_gb_per_rank,
|
||||||
|
}
|
||||||
|
|
||||||
|
# This is the same for all backends
|
||||||
|
self.kv_transfer_config.kv_role = "kv_both"
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Verify configs are valid & consistent with each other."""
|
"""Verify configs are valid & consistent with each other."""
|
||||||
|
|
||||||
@ -646,6 +688,9 @@ class VllmConfig:
|
|||||||
if "-quant_fp8" not in custom_ops:
|
if "-quant_fp8" not in custom_ops:
|
||||||
custom_ops.append("+quant_fp8")
|
custom_ops.append("+quant_fp8")
|
||||||
|
|
||||||
|
# Handle the KV connector configs
|
||||||
|
self._post_init_kv_transfer_config()
|
||||||
|
|
||||||
def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
|
def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
|
||||||
# remove the sizes that not multiple of tp_size when
|
# remove the sizes that not multiple of tp_size when
|
||||||
# enable sequence parallelism
|
# enable sequence parallelism
|
||||||
|
|||||||
@ -6,7 +6,7 @@ KV cache helper for store.
|
|||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from concurrent.futures import CancelledError, Future
|
from concurrent.futures import CancelledError, Future
|
||||||
from typing import TYPE_CHECKING, Literal, cast
|
from typing import TYPE_CHECKING, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -138,8 +138,11 @@ class KVOutputAggregator:
|
|||||||
return cls(connector.get_finished_count() or world_size)
|
return cls(connector.get_finished_count() or world_size)
|
||||||
|
|
||||||
def aggregate(
|
def aggregate(
|
||||||
self, outputs: list[ModelRunnerOutput], output_rank: int = 0
|
self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
|
||||||
) -> ModelRunnerOutput:
|
) -> ModelRunnerOutput | None:
|
||||||
|
if not outputs[output_rank]:
|
||||||
|
return None
|
||||||
|
|
||||||
# Aggregate kv_connector_output from all workers
|
# Aggregate kv_connector_output from all workers
|
||||||
|
|
||||||
def update_finished_set(
|
def update_finished_set(
|
||||||
@ -161,6 +164,7 @@ class KVOutputAggregator:
|
|||||||
aggregated_kv_connector_stats = None
|
aggregated_kv_connector_stats = None
|
||||||
invalid_block_ids = set[int]()
|
invalid_block_ids = set[int]()
|
||||||
for model_runner_output in outputs:
|
for model_runner_output in outputs:
|
||||||
|
assert model_runner_output is not None
|
||||||
kv_output = model_runner_output.kv_connector_output
|
kv_output = model_runner_output.kv_connector_output
|
||||||
if not kv_output:
|
if not kv_output:
|
||||||
continue
|
continue
|
||||||
@ -204,6 +208,7 @@ class KVOutputAggregator:
|
|||||||
# select output of the worker specified by output_rank
|
# select output of the worker specified by output_rank
|
||||||
output = outputs[output_rank]
|
output = outputs[output_rank]
|
||||||
|
|
||||||
|
assert output is not None
|
||||||
output.kv_connector_output = KVConnectorOutput(
|
output.kv_connector_output = KVConnectorOutput(
|
||||||
finished_sending=finished_sending or None,
|
finished_sending=finished_sending or None,
|
||||||
finished_recving=finished_recving or None,
|
finished_recving=finished_recving or None,
|
||||||
@ -215,13 +220,16 @@ class KVOutputAggregator:
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def async_aggregate(
|
def async_aggregate(
|
||||||
self, output_futures: Sequence[Future[ModelRunnerOutput]], output_rank: int = 0
|
self,
|
||||||
) -> Future[ModelRunnerOutput]:
|
output_futures: Sequence[Future[ModelRunnerOutput | None]],
|
||||||
|
output_rank: int = 0,
|
||||||
|
) -> Future[ModelRunnerOutput | None]:
|
||||||
"""Takes a list of futures and returns a single future which resolves
|
"""Takes a list of futures and returns a single future which resolves
|
||||||
to the respective list of outputs."""
|
to the respective list of outputs."""
|
||||||
result_future: Future[ModelRunnerOutput] = Future()
|
result_future: Future[ModelRunnerOutput | None] = Future()
|
||||||
|
|
||||||
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
|
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
|
||||||
|
remaining = len(output_futures)
|
||||||
|
|
||||||
def make_callback(idx):
|
def make_callback(idx):
|
||||||
def callback(fut):
|
def callback(fut):
|
||||||
@ -236,12 +244,10 @@ class KVOutputAggregator:
|
|||||||
result_future.set_exception(e)
|
result_future.set_exception(e)
|
||||||
|
|
||||||
# this check assumes io_thread_pool uses a single thread
|
# this check assumes io_thread_pool uses a single thread
|
||||||
if all(outputs):
|
nonlocal remaining
|
||||||
result_future.set_result(
|
remaining -= 1
|
||||||
self.aggregate(
|
if not remaining:
|
||||||
cast(list[ModelRunnerOutput], outputs), output_rank
|
result_future.set_result(self.aggregate(outputs, output_rank))
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
|
|||||||
@ -122,6 +122,15 @@ class KVConnectorRole(enum.Enum):
|
|||||||
WORKER = 1
|
WORKER = 1
|
||||||
|
|
||||||
|
|
||||||
|
class KVConnectorHandshakeMetadata(ABC): # noqa: B024
|
||||||
|
"""
|
||||||
|
Metadata used for out of band connector handshake between
|
||||||
|
P/D workers. This needs to serializeable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class KVConnectorMetadata(ABC): # noqa: B024
|
class KVConnectorMetadata(ABC): # noqa: B024
|
||||||
"""
|
"""
|
||||||
Abstract Metadata used to communicate between the
|
Abstract Metadata used to communicate between the
|
||||||
@ -320,6 +329,18 @@ class KVConnectorBase_V1(ABC):
|
|||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||||
|
"""
|
||||||
|
Get the KVConnector handshake metadata for this connector.
|
||||||
|
This metadata is used for out-of-band connector handshake
|
||||||
|
between P/D workers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KVConnectorHandshakeMetadata: the handshake metadata.
|
||||||
|
None if no handshake metadata is available.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Scheduler-side methods
|
# Scheduler-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
@ -477,6 +498,17 @@ class KVConnectorBase_V1(ABC):
|
|||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def set_xfer_handshake_metadata(
|
||||||
|
self, metadata: dict[int, KVConnectorHandshakeMetadata]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set the KV connector handshake metadata for this connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata (KVConnectorHandshakeMetadata): the handshake metadata to set.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_prom_metrics(
|
def build_prom_metrics(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
CopyBlocksOp,
|
CopyBlocksOp,
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
|
KVConnectorHandshakeMetadata,
|
||||||
KVConnectorMetadata,
|
KVConnectorMetadata,
|
||||||
KVConnectorRole,
|
KVConnectorRole,
|
||||||
)
|
)
|
||||||
@ -93,15 +94,12 @@ _NIXL_SUPPORTED_DEVICE = {
|
|||||||
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
|
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
|
||||||
|
|
||||||
|
|
||||||
class NixlAgentMetadata(
|
@dataclass
|
||||||
msgspec.Struct,
|
class NixlAgentMetadata(KVConnectorHandshakeMetadata):
|
||||||
omit_defaults=True, # type: ignore[call-arg]
|
|
||||||
# required for @cached_property.
|
|
||||||
dict=True,
|
|
||||||
):
|
|
||||||
engine_id: str
|
engine_id: str
|
||||||
agent_metadata: bytes
|
agent_metadata: bytes
|
||||||
kv_caches_base_addr: list[int]
|
kv_caches_base_addr: list[int]
|
||||||
|
device_id: int
|
||||||
num_blocks: int
|
num_blocks: int
|
||||||
block_lens: list[int]
|
block_lens: list[int]
|
||||||
attn_backend_name: str
|
attn_backend_name: str
|
||||||
@ -223,6 +221,18 @@ class NixlConnector(KVConnectorBase_V1):
|
|||||||
assert self.connector_scheduler is not None
|
assert self.connector_scheduler is not None
|
||||||
return self.connector_scheduler.request_finished(request, block_ids)
|
return self.connector_scheduler.request_finished(request, block_ids)
|
||||||
|
|
||||||
|
def set_xfer_handshake_metadata(
|
||||||
|
self, metadata: dict[int, KVConnectorHandshakeMetadata]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set the KV connector handshake metadata for this connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata (dict): the handshake metadata to set.
|
||||||
|
"""
|
||||||
|
assert self.connector_scheduler is not None
|
||||||
|
self.connector_scheduler.set_xfer_handshake_metadata(metadata)
|
||||||
|
|
||||||
############################################################
|
############################################################
|
||||||
# Worker Side Methods
|
# Worker Side Methods
|
||||||
############################################################
|
############################################################
|
||||||
@ -299,6 +309,21 @@ class NixlConnector(KVConnectorBase_V1):
|
|||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
if self.connector_worker is not None:
|
if self.connector_worker is not None:
|
||||||
self.connector_worker.shutdown()
|
self.connector_worker.shutdown()
|
||||||
|
if self.connector_scheduler is not None:
|
||||||
|
self.connector_scheduler.shutdown()
|
||||||
|
|
||||||
|
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||||
|
"""
|
||||||
|
Get the KVConnector handshake metadata for this connector.
|
||||||
|
This metadata is used for out-of-band connector handshake
|
||||||
|
between P/D workers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KVConnectorHandshakeMetadata: the handshake metadata.
|
||||||
|
None if no handshake metadata is available.
|
||||||
|
"""
|
||||||
|
assert self.connector_worker is not None
|
||||||
|
return self.connector_worker.xfer_handshake_metadata
|
||||||
|
|
||||||
|
|
||||||
class NixlConnectorScheduler:
|
class NixlConnectorScheduler:
|
||||||
@ -312,12 +337,16 @@ class NixlConnectorScheduler:
|
|||||||
self.side_channel_port = (
|
self.side_channel_port = (
|
||||||
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
|
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
|
||||||
+ vllm_config.parallel_config.data_parallel_rank
|
+ vllm_config.parallel_config.data_parallel_rank
|
||||||
* vllm_config.parallel_config.tensor_parallel_size
|
|
||||||
)
|
)
|
||||||
assert vllm_config.kv_transfer_config is not None
|
assert vllm_config.kv_transfer_config is not None
|
||||||
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
|
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
|
||||||
logger.info("Initializing NIXL Scheduler %s", engine_id)
|
logger.info("Initializing NIXL Scheduler %s", engine_id)
|
||||||
|
|
||||||
|
# Background thread for handling new handshake requests.
|
||||||
|
self._nixl_handshake_listener_t: threading.Thread | None = None
|
||||||
|
self._encoded_xfer_handshake_metadata: dict[int, Any] = {}
|
||||||
|
self._stop_event = threading.Event()
|
||||||
|
|
||||||
# Requests that need to start recv/send.
|
# Requests that need to start recv/send.
|
||||||
# New requests are added by update_state_after_alloc in
|
# New requests are added by update_state_after_alloc in
|
||||||
# the scheduler. Used to make metadata passed to Worker.
|
# the scheduler. Used to make metadata passed to Worker.
|
||||||
@ -330,6 +359,89 @@ class NixlConnectorScheduler:
|
|||||||
# remote prefill or aborted.
|
# remote prefill or aborted.
|
||||||
self._reqs_not_processed: set[ReqId] = set()
|
self._reqs_not_processed: set[ReqId] = set()
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
self._stop_event.set()
|
||||||
|
if self._nixl_handshake_listener_t is not None:
|
||||||
|
self._nixl_handshake_listener_t.join()
|
||||||
|
self._nixl_handshake_listener_t = None
|
||||||
|
|
||||||
|
def set_xfer_handshake_metadata(
|
||||||
|
self, metadata: dict[int, KVConnectorHandshakeMetadata]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set the KV connector handshake metadata for this connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata (dict): the handshake metadata to set.
|
||||||
|
"""
|
||||||
|
encoded_data: dict[int, bytes] = {}
|
||||||
|
encoder = msgspec.msgpack.Encoder()
|
||||||
|
for tp_rank, rank_metadata in metadata.items():
|
||||||
|
if not isinstance(rank_metadata, NixlAgentMetadata):
|
||||||
|
raise ValueError(
|
||||||
|
"NixlConnectorScheduler expects NixlAgentMetadata for "
|
||||||
|
"handshake metadata."
|
||||||
|
)
|
||||||
|
encoded_data[tp_rank] = encoder.encode(rank_metadata)
|
||||||
|
logger.debug(
|
||||||
|
"Tp rank %d: encoded NixlAgentMetadata size: %s bytes",
|
||||||
|
tp_rank,
|
||||||
|
str(len(encoded_data[tp_rank])),
|
||||||
|
)
|
||||||
|
self._encoded_xfer_handshake_metadata = encoded_data
|
||||||
|
|
||||||
|
# Only start the listener when we have metadata to serve.
|
||||||
|
if self._nixl_handshake_listener_t is None:
|
||||||
|
ready_event = threading.Event()
|
||||||
|
self._nixl_handshake_listener_t = threading.Thread(
|
||||||
|
target=self._nixl_handshake_listener,
|
||||||
|
args=(
|
||||||
|
encoded_data,
|
||||||
|
ready_event,
|
||||||
|
self._stop_event,
|
||||||
|
self.side_channel_port,
|
||||||
|
),
|
||||||
|
daemon=True,
|
||||||
|
name="nixl_handshake_listener",
|
||||||
|
)
|
||||||
|
self._nixl_handshake_listener_t.start()
|
||||||
|
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _nixl_handshake_listener(
|
||||||
|
encoded_data: dict[int, Any],
|
||||||
|
ready_event: threading.Event,
|
||||||
|
stop_event: threading.Event,
|
||||||
|
port: int,
|
||||||
|
):
|
||||||
|
"""Background thread for getting new NIXL handshakes."""
|
||||||
|
# NOTE(rob): this is a simple implementation. We will move
|
||||||
|
# to a better approach via HTTP endpoint soon.
|
||||||
|
|
||||||
|
# Listen for new requests for metadata.
|
||||||
|
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
||||||
|
path = make_zmq_path("tcp", host, port)
|
||||||
|
logger.debug("Starting listening on path: %s", path)
|
||||||
|
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||||
|
sock.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
|
ready_event.set()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
identity, _, msg = sock.recv_multipart()
|
||||||
|
except zmq.Again:
|
||||||
|
if stop_event.is_set():
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
# Decode the message which contains (GET_META_MSG, rank)
|
||||||
|
msg, target_tp_rank = msgspec.msgpack.decode(msg)
|
||||||
|
logger.debug(
|
||||||
|
"Received message for tp rank %s",
|
||||||
|
target_tp_rank,
|
||||||
|
)
|
||||||
|
if msg != GET_META_MSG:
|
||||||
|
logger.warning("Connection listener got unexpected message %s", msg)
|
||||||
|
sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))
|
||||||
|
|
||||||
def get_num_new_matched_tokens(
|
def get_num_new_matched_tokens(
|
||||||
self, request: "Request", num_computed_tokens: int
|
self, request: "Request", num_computed_tokens: int
|
||||||
) -> tuple[int, bool]:
|
) -> tuple[int, bool]:
|
||||||
@ -537,8 +649,6 @@ class NixlConnectorScheduler:
|
|||||||
class NixlConnectorWorker:
|
class NixlConnectorWorker:
|
||||||
"""Implementation of Worker side methods"""
|
"""Implementation of Worker side methods"""
|
||||||
|
|
||||||
_POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TpKVTopology:
|
class TpKVTopology:
|
||||||
"""
|
"""
|
||||||
@ -651,16 +761,6 @@ class NixlConnectorWorker:
|
|||||||
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
||||||
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
|
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
|
||||||
|
|
||||||
# NIXL handshake port.
|
|
||||||
# NOTE(rob): Within a DP group, each DP rank gets its own
|
|
||||||
# base port (which is sent in the KVTransferParams).
|
|
||||||
# Each TP rank listens/queries on the base_port + tp_rank.
|
|
||||||
self.side_channel_port: int = (
|
|
||||||
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
|
|
||||||
+ vllm_config.parallel_config.data_parallel_rank
|
|
||||||
* vllm_config.parallel_config.tensor_parallel_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Metadata.
|
# Metadata.
|
||||||
self.engine_id: EngineId = engine_id
|
self.engine_id: EngineId = engine_id
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
@ -706,6 +806,7 @@ class NixlConnectorWorker:
|
|||||||
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
||||||
# rank will still only pull from a single remote TP worker.
|
# rank will still only pull from a single remote TP worker.
|
||||||
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
|
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
|
||||||
|
self.device_id: int = 0
|
||||||
|
|
||||||
# Number of NIXL regions. Currently one region per cache
|
# Number of NIXL regions. Currently one region per cache
|
||||||
# (so 1 per layer for MLA, otherwise 2 per layer)
|
# (so 1 per layer for MLA, otherwise 2 per layer)
|
||||||
@ -736,9 +837,8 @@ class NixlConnectorWorker:
|
|||||||
# requests that skipped transfer (handshake or transfer failures)
|
# requests that skipped transfer (handshake or transfer failures)
|
||||||
self._failed_recv_reqs: set[ReqId] = set()
|
self._failed_recv_reqs: set[ReqId] = set()
|
||||||
|
|
||||||
# Background thread for handling new handshake requests.
|
# Handshake metadata of this worker for NIXL transfers.
|
||||||
self._nixl_handshake_listener_t: threading.Thread | None = None
|
self.xfer_handshake_metadata: NixlAgentMetadata | None = None
|
||||||
self._nixl_handshake_listener_stop_event: threading.Event | None = None
|
|
||||||
# Background thread for initializing new NIXL handshakes.
|
# Background thread for initializing new NIXL handshakes.
|
||||||
self._handshake_initiation_executor = ThreadPoolExecutor(
|
self._handshake_initiation_executor = ThreadPoolExecutor(
|
||||||
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
|
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
|
||||||
@ -790,42 +890,6 @@ class NixlConnectorWorker:
|
|||||||
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _nixl_handshake_listener(
|
|
||||||
metadata: NixlAgentMetadata,
|
|
||||||
ready_event: threading.Event,
|
|
||||||
stop_event: threading.Event,
|
|
||||||
base_port: int,
|
|
||||||
tp_rank: int,
|
|
||||||
):
|
|
||||||
"""Background thread for getting new NIXL handshakes."""
|
|
||||||
# NOTE(rob): this is a simple implementation. We will move
|
|
||||||
# to a better approach via HTTP endpoint soon.
|
|
||||||
|
|
||||||
encoder = msgspec.msgpack.Encoder()
|
|
||||||
encoded_data = encoder.encode(metadata)
|
|
||||||
size_in_bytes = len(encoded_data)
|
|
||||||
logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes))
|
|
||||||
|
|
||||||
# Listen for new requests for metadata.
|
|
||||||
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
|
||||||
path = make_zmq_path("tcp", host, base_port + tp_rank)
|
|
||||||
logger.debug("Starting listening on path: %s", path)
|
|
||||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
|
||||||
ready_event.set()
|
|
||||||
poller = zmq.Poller()
|
|
||||||
poller.register(sock, zmq.POLLIN)
|
|
||||||
while not stop_event.is_set():
|
|
||||||
events = dict(
|
|
||||||
poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000)
|
|
||||||
)
|
|
||||||
if sock not in events:
|
|
||||||
continue
|
|
||||||
identity, _, msg = sock.recv_multipart()
|
|
||||||
if msg != GET_META_MSG:
|
|
||||||
logger.warning("Connection listener got unexpected message %s", msg)
|
|
||||||
sock.send_multipart((identity, b"", encoded_data))
|
|
||||||
|
|
||||||
def _nixl_handshake(
|
def _nixl_handshake(
|
||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
@ -844,16 +908,17 @@ class NixlConnectorWorker:
|
|||||||
# Handshake only with the remote TP rank that current local rank will
|
# Handshake only with the remote TP rank that current local rank will
|
||||||
# pull from. With homogeneous TP it happens to be the same rank_i.
|
# pull from. With homogeneous TP it happens to be the same rank_i.
|
||||||
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
|
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
|
||||||
path = make_zmq_path("tcp", host, port + p_remote_rank)
|
path = make_zmq_path("tcp", host, port)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank
|
"Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send query for the request.
|
# Send query for the request.
|
||||||
with zmq_ctx(zmq.REQ, path) as sock:
|
with zmq_ctx(zmq.REQ, path) as sock:
|
||||||
|
msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank))
|
||||||
# Set receive timeout to 5 seconds to avoid hanging on dead server
|
# Set receive timeout to 5 seconds to avoid hanging on dead server
|
||||||
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
|
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
|
||||||
sock.send(GET_META_MSG)
|
sock.send(msg)
|
||||||
metadata_bytes = sock.recv()
|
metadata_bytes = sock.recv()
|
||||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||||
metadata = decoder.decode(metadata_bytes)
|
metadata = decoder.decode(metadata_bytes)
|
||||||
@ -1042,6 +1107,10 @@ class NixlConnectorWorker:
|
|||||||
assert tensor_size_bytes == curr_tensor_size_bytes, (
|
assert tensor_size_bytes == curr_tensor_size_bytes, (
|
||||||
"All kv cache tensors must have the same size"
|
"All kv cache tensors must have the same size"
|
||||||
)
|
)
|
||||||
|
# Need to make sure the device ID is non-negative for NIXL,
|
||||||
|
# Torch uses -1 to indicate CPU tensors while NIXL uses explicit
|
||||||
|
# memory type.
|
||||||
|
self.device_id = max(cache.get_device(), 0)
|
||||||
caches_data.append(
|
caches_data.append(
|
||||||
(base_addr, curr_tensor_size_bytes, self.device_id, "")
|
(base_addr, curr_tensor_size_bytes, self.device_id, "")
|
||||||
)
|
)
|
||||||
@ -1139,10 +1208,11 @@ class NixlConnectorWorker:
|
|||||||
assert len(self.block_window_per_layer) == self.num_layers
|
assert len(self.block_window_per_layer) == self.num_layers
|
||||||
|
|
||||||
# After KV Caches registered, listen for new connections.
|
# After KV Caches registered, listen for new connections.
|
||||||
metadata = NixlAgentMetadata(
|
self.xfer_handshake_metadata = NixlAgentMetadata(
|
||||||
engine_id=self.engine_id,
|
engine_id=self.engine_id,
|
||||||
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
||||||
|
device_id=self.device_id,
|
||||||
num_blocks=self.num_blocks,
|
num_blocks=self.num_blocks,
|
||||||
block_lens=self.block_len_per_layer,
|
block_lens=self.block_len_per_layer,
|
||||||
attn_backend_name=self.backend_name,
|
attn_backend_name=self.backend_name,
|
||||||
@ -1150,22 +1220,6 @@ class NixlConnectorWorker:
|
|||||||
if not self.use_host_buffer
|
if not self.use_host_buffer
|
||||||
else self.host_buffer_kv_cache_layout,
|
else self.host_buffer_kv_cache_layout,
|
||||||
)
|
)
|
||||||
ready_event, stop_event = threading.Event(), threading.Event()
|
|
||||||
self._nixl_handshake_listener_t = threading.Thread(
|
|
||||||
target=self._nixl_handshake_listener,
|
|
||||||
args=(
|
|
||||||
metadata,
|
|
||||||
ready_event,
|
|
||||||
stop_event,
|
|
||||||
self.side_channel_port,
|
|
||||||
self.tp_rank,
|
|
||||||
),
|
|
||||||
daemon=True,
|
|
||||||
name="nixl_handshake_listener",
|
|
||||||
)
|
|
||||||
self._nixl_handshake_listener_t.start()
|
|
||||||
self._nixl_handshake_listener_stop_event = stop_event
|
|
||||||
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
|
||||||
|
|
||||||
def add_remote_agent(
|
def add_remote_agent(
|
||||||
self,
|
self,
|
||||||
@ -1267,7 +1321,7 @@ class NixlConnectorWorker:
|
|||||||
# self.block_len == remote_block_len//tp_ratio bytes.
|
# self.block_len == remote_block_len//tp_ratio bytes.
|
||||||
addr = base_addr + block_offset + rank_offset
|
addr = base_addr + block_offset + rank_offset
|
||||||
# (addr, len, device id)
|
# (addr, len, device id)
|
||||||
blocks_data.append((addr, kv_block_len, remote_tp_rank))
|
blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
|
||||||
|
|
||||||
if self._use_flashinfer:
|
if self._use_flashinfer:
|
||||||
# With FlashInfer index V separately to allow head splitting.
|
# With FlashInfer index V separately to allow head splitting.
|
||||||
@ -1275,7 +1329,9 @@ class NixlConnectorWorker:
|
|||||||
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
||||||
addr = base_addr + block_offset + rank_offset
|
addr = base_addr + block_offset + rank_offset
|
||||||
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
|
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
|
||||||
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
|
blocks_data.append(
|
||||||
|
(v_addr, kv_block_len, nixl_agent_meta.device_id)
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Created %s blocks for dst engine %s with remote rank %s and local rank %s",
|
"Created %s blocks for dst engine %s with remote rank %s and local rank %s",
|
||||||
@ -1843,14 +1899,6 @@ class NixlConnectorWorker:
|
|||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
"""Shutdown the connector worker."""
|
"""Shutdown the connector worker."""
|
||||||
self._handshake_initiation_executor.shutdown(wait=False)
|
self._handshake_initiation_executor.shutdown(wait=False)
|
||||||
if self._nixl_handshake_listener_stop_event is not None:
|
|
||||||
self._nixl_handshake_listener_stop_event.set()
|
|
||||||
self._nixl_handshake_listener_stop_event = None
|
|
||||||
if self._nixl_handshake_listener_t is not None:
|
|
||||||
# Generous timeout to allow the thread to exit
|
|
||||||
self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10)
|
|
||||||
assert not self._nixl_handshake_listener_t.is_alive()
|
|
||||||
self._nixl_handshake_listener_t = None
|
|
||||||
for handles in self._recving_transfers.values():
|
for handles in self._recving_transfers.values():
|
||||||
for handle, _ in handles:
|
for handle, _ in handles:
|
||||||
self.nixl_wrapper.release_xfer_handle(handle)
|
self.nixl_wrapper.release_xfer_handle(handle)
|
||||||
|
|||||||
@ -54,7 +54,13 @@ from vllm.config import (
|
|||||||
VllmConfig,
|
VllmConfig,
|
||||||
get_attr_docs,
|
get_attr_docs,
|
||||||
)
|
)
|
||||||
from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo
|
from vllm.config.cache import (
|
||||||
|
BlockSize,
|
||||||
|
CacheDType,
|
||||||
|
KVOffloadingBackend,
|
||||||
|
MambaDType,
|
||||||
|
PrefixCachingHashAlgo,
|
||||||
|
)
|
||||||
from vllm.config.device import Device
|
from vllm.config.device import Device
|
||||||
from vllm.config.model import (
|
from vllm.config.model import (
|
||||||
ConvertOption,
|
ConvertOption,
|
||||||
@ -553,6 +559,11 @@ class EngineArgs:
|
|||||||
|
|
||||||
kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
|
kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
|
||||||
|
|
||||||
|
kv_offloading_size: float | None = CacheConfig.kv_offloading_size
|
||||||
|
kv_offloading_backend: KVOffloadingBackend | None = (
|
||||||
|
CacheConfig.kv_offloading_backend
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# support `EngineArgs(compilation_config={...})`
|
# support `EngineArgs(compilation_config={...})`
|
||||||
# without having to manually construct a
|
# without having to manually construct a
|
||||||
@ -896,6 +907,12 @@ class EngineArgs:
|
|||||||
cache_group.add_argument(
|
cache_group.add_argument(
|
||||||
"--mamba-block-size", **cache_kwargs["mamba_block_size"]
|
"--mamba-block-size", **cache_kwargs["mamba_block_size"]
|
||||||
)
|
)
|
||||||
|
cache_group.add_argument(
|
||||||
|
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
|
||||||
|
)
|
||||||
|
cache_group.add_argument(
|
||||||
|
"--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
|
||||||
|
)
|
||||||
|
|
||||||
# Multimodal related configs
|
# Multimodal related configs
|
||||||
multimodal_kwargs = get_kwargs(MultiModalConfig)
|
multimodal_kwargs = get_kwargs(MultiModalConfig)
|
||||||
@ -1246,8 +1263,6 @@ class EngineArgs:
|
|||||||
self,
|
self,
|
||||||
target_model_config: ModelConfig,
|
target_model_config: ModelConfig,
|
||||||
target_parallel_config: ParallelConfig,
|
target_parallel_config: ParallelConfig,
|
||||||
enable_chunked_prefill: bool,
|
|
||||||
disable_log_stats: bool,
|
|
||||||
) -> SpeculativeConfig | None:
|
) -> SpeculativeConfig | None:
|
||||||
"""Initializes and returns a SpeculativeConfig object based on
|
"""Initializes and returns a SpeculativeConfig object based on
|
||||||
`speculative_config`.
|
`speculative_config`.
|
||||||
@ -1267,8 +1282,6 @@ class EngineArgs:
|
|||||||
{
|
{
|
||||||
"target_model_config": target_model_config,
|
"target_model_config": target_model_config,
|
||||||
"target_parallel_config": target_parallel_config,
|
"target_parallel_config": target_parallel_config,
|
||||||
"enable_chunked_prefill": enable_chunked_prefill,
|
|
||||||
"disable_log_stats": disable_log_stats,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return SpeculativeConfig(**self.speculative_config)
|
return SpeculativeConfig(**self.speculative_config)
|
||||||
@ -1391,6 +1404,8 @@ class EngineArgs:
|
|||||||
mamba_cache_dtype=self.mamba_cache_dtype,
|
mamba_cache_dtype=self.mamba_cache_dtype,
|
||||||
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
|
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
|
||||||
mamba_block_size=self.mamba_block_size,
|
mamba_block_size=self.mamba_block_size,
|
||||||
|
kv_offloading_size=self.kv_offloading_size,
|
||||||
|
kv_offloading_backend=self.kv_offloading_backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
ray_runtime_env = None
|
ray_runtime_env = None
|
||||||
@ -1561,8 +1576,6 @@ class EngineArgs:
|
|||||||
speculative_config = self.create_speculative_config(
|
speculative_config = self.create_speculative_config(
|
||||||
target_model_config=model_config,
|
target_model_config=model_config,
|
||||||
target_parallel_config=parallel_config,
|
target_parallel_config=parallel_config,
|
||||||
enable_chunked_prefill=self.enable_chunked_prefill,
|
|
||||||
disable_log_stats=self.disable_log_stats,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure num_lookahead_slots is set appropriately depending on
|
# make sure num_lookahead_slots is set appropriately depending on
|
||||||
@ -1813,7 +1826,7 @@ class EngineArgs:
|
|||||||
incremental_prefill_supported = (
|
incremental_prefill_supported = (
|
||||||
pooling_type is not None
|
pooling_type is not None
|
||||||
and pooling_type.lower() == "last"
|
and pooling_type.lower() == "last"
|
||||||
and is_causal
|
and bool(is_causal)
|
||||||
)
|
)
|
||||||
|
|
||||||
action = "Enabling" if incremental_prefill_supported else "Disabling"
|
action = "Enabling" if incremental_prefill_supported else "Disabling"
|
||||||
|
|||||||
@ -241,6 +241,7 @@ async def build_async_engine_client_from_engine_args(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Don't keep the dummy data in memory
|
# Don't keep the dummy data in memory
|
||||||
|
assert async_llm is not None
|
||||||
await async_llm.reset_mm_cache()
|
await async_llm.reset_mm_cache()
|
||||||
|
|
||||||
yield async_llm
|
yield async_llm
|
||||||
|
|||||||
@ -345,22 +345,7 @@ class OpenAIServing:
|
|||||||
|
|
||||||
if is_explicit_encoder_decoder_prompt(prompt):
|
if is_explicit_encoder_decoder_prompt(prompt):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
else:
|
|
||||||
processed_inputs = processor.input_preprocessor._prompt_to_llm_inputs(
|
|
||||||
prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
if processed_inputs["type"] == "embeds":
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
# This is a workaround to fix multimodal beam search; this is a
|
|
||||||
# bandaid fix for 2 small problems:
|
|
||||||
# 1. Multi_modal_data on the processed_inputs currently resolves to
|
|
||||||
# `None`.
|
|
||||||
# 2. preprocessing above expands the multimodal placeholders. However,
|
|
||||||
# this happens again in generation, so the double expansion causes
|
|
||||||
# a mismatch.
|
|
||||||
# TODO - would be ideal to handle this more gracefully.
|
|
||||||
prompt_text: str | None
|
prompt_text: str | None
|
||||||
prompt_token_ids: list[int]
|
prompt_token_ids: list[int]
|
||||||
multi_modal_data: MultiModalDataDict | None
|
multi_modal_data: MultiModalDataDict | None
|
||||||
@ -373,9 +358,16 @@ class OpenAIServing:
|
|||||||
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
|
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
|
||||||
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
|
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
|
||||||
|
|
||||||
mm_processor_kwargs: dict[str, Any] | None = processed_inputs.get(
|
mm_processor_kwargs: dict[str, Any] | None = None
|
||||||
"mm_processor_kwargs"
|
|
||||||
) # type: ignore
|
# This is a workaround to fix multimodal beam search; this is a
|
||||||
|
# bandaid fix for 2 small problems:
|
||||||
|
# 1. Multi_modal_data on the processed_inputs currently resolves to
|
||||||
|
# `None`.
|
||||||
|
# 2. preprocessing above expands the multimodal placeholders. However,
|
||||||
|
# this happens again in generation, so the double expansion causes
|
||||||
|
# a mismatch.
|
||||||
|
# TODO - would be ideal to handle this more gracefully.
|
||||||
|
|
||||||
tokenized_length = len(prompt_token_ids)
|
tokenized_length = len(prompt_token_ids)
|
||||||
|
|
||||||
|
|||||||
@ -2,11 +2,12 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.entrypoints.openai.protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
DeltaFunctionCall,
|
DeltaFunctionCall,
|
||||||
|
|||||||
@ -15,9 +15,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
from vllm.lora.layers.base import BaseLayerWithLoRA
|
from vllm.lora.layers.base import BaseLayerWithLoRA
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
|
||||||
_get_config_dtype_str,
|
_get_config_dtype_str,
|
||||||
mxfp4_w4a16_moe_quant_config,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||||
modular_marlin_fused_moe,
|
modular_marlin_fused_moe,
|
||||||
@ -26,13 +24,16 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|||||||
modular_triton_fused_moe,
|
modular_triton_fused_moe,
|
||||||
try_get_optimal_moe_config,
|
try_get_optimal_moe_config,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||||
def __init__(self, base_layer: FusedMoE) -> None:
|
def __init__(self, base_layer: FusedMoE) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.base_layer = base_layer
|
self.base_layer = base_layer
|
||||||
|
|
||||||
|
assert not self.base_layer.use_ep, (
|
||||||
|
"EP support for Fused MoE LoRA is not implemented yet."
|
||||||
|
)
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.device = base_layer.w2_weight.device
|
self.device = base_layer.w2_weight.device
|
||||||
@ -42,17 +43,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
moe_state_dict = {}
|
moe_state_dict = {}
|
||||||
top_k = self.base_layer.top_k
|
top_k = self.base_layer.top_k
|
||||||
|
|
||||||
if self.base_layer.quant_config is None:
|
self.base_layer.ensure_moe_quant_config_init()
|
||||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
quant_config = self.base_layer.quant_method.moe_quant_config
|
||||||
elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
|
|
||||||
quant_config = self.base_layer.quant_config
|
|
||||||
else:
|
|
||||||
quant_config = mxfp4_w4a16_moe_quant_config(
|
|
||||||
w1_bias=self.base_layer.w13_bias,
|
|
||||||
w2_bias=self.base_layer.w2_bias,
|
|
||||||
w1_scale=self.base_layer.w13_weight_scale,
|
|
||||||
w2_scale=self.base_layer.w2_weight_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
m_fused_moe_fn = (
|
m_fused_moe_fn = (
|
||||||
modular_triton_fused_moe(
|
modular_triton_fused_moe(
|
||||||
@ -69,7 +61,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
|
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
|
||||||
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
|
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
|
||||||
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
|
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
|
||||||
moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
|
|
||||||
moe_state_dict["expert_map"] = kwargs["expert_map"]
|
moe_state_dict["expert_map"] = kwargs["expert_map"]
|
||||||
moe_state_dict["apply_router_weight_on_input"] = kwargs[
|
moe_state_dict["apply_router_weight_on_input"] = kwargs[
|
||||||
"apply_router_weight_on_input"
|
"apply_router_weight_on_input"
|
||||||
@ -86,7 +77,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
hidden_states = moe_state_dict["hidden_states"]
|
hidden_states = moe_state_dict["hidden_states"]
|
||||||
topk_weights = moe_state_dict["topk_weights"]
|
topk_weights = moe_state_dict["topk_weights"]
|
||||||
curr_topk_ids = moe_state_dict["topk_ids"]
|
curr_topk_ids = moe_state_dict["topk_ids"]
|
||||||
global_num_experts = moe_state_dict["global_num_experts"]
|
|
||||||
expert_map = moe_state_dict["expert_map"]
|
expert_map = moe_state_dict["expert_map"]
|
||||||
|
|
||||||
config_dtype = _get_config_dtype_str(
|
config_dtype = _get_config_dtype_str(
|
||||||
@ -118,7 +109,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
curr_topk_ids,
|
curr_topk_ids,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
config["BLOCK_SIZE_M"],
|
config["BLOCK_SIZE_M"],
|
||||||
global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
max_loras,
|
max_loras,
|
||||||
expert_map,
|
expert_map,
|
||||||
)
|
)
|
||||||
@ -236,14 +227,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Initializes lora matrices."""
|
"""Initializes lora matrices."""
|
||||||
|
|
||||||
assert not self.base_layer.use_ep, (
|
|
||||||
"EP support for Fused MoE LoRA is not implemented yet."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.w1_lora_a_stacked = torch.zeros(
|
self.w1_lora_a_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
self.base_layer.hidden_size,
|
self.base_layer.hidden_size,
|
||||||
),
|
),
|
||||||
@ -253,7 +240,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.w1_lora_b_stacked = torch.zeros(
|
self.w1_lora_b_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
self.base_layer.intermediate_size_per_partition,
|
self.base_layer.intermediate_size_per_partition,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
),
|
),
|
||||||
@ -264,7 +251,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.w2_lora_a_stacked = torch.zeros(
|
self.w2_lora_a_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
self.base_layer.intermediate_size_per_partition,
|
self.base_layer.intermediate_size_per_partition,
|
||||||
),
|
),
|
||||||
@ -274,7 +261,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.w2_lora_b_stacked = torch.zeros(
|
self.w2_lora_b_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
self.base_layer.hidden_size,
|
self.base_layer.hidden_size,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
),
|
),
|
||||||
@ -285,7 +272,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.w3_lora_a_stacked = torch.zeros(
|
self.w3_lora_a_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
self.base_layer.hidden_size,
|
self.base_layer.hidden_size,
|
||||||
),
|
),
|
||||||
@ -295,7 +282,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.w3_lora_b_stacked = torch.zeros(
|
self.w3_lora_b_stacked = torch.zeros(
|
||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.global_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
self.base_layer.intermediate_size_per_partition,
|
self.base_layer.intermediate_size_per_partition,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
),
|
),
|
||||||
@ -308,7 +295,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.lora_a_stacked = []
|
self.lora_a_stacked = []
|
||||||
self.lora_b_stacked = []
|
self.lora_b_stacked = []
|
||||||
for lora_id in range(max_loras):
|
for lora_id in range(max_loras):
|
||||||
for experts_id in range(self.base_layer.global_num_experts):
|
for experts_id in range(self.base_layer.local_num_experts):
|
||||||
# gate_proj,down_proj,up_proj
|
# gate_proj,down_proj,up_proj
|
||||||
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
|
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
|
||||||
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
|
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
|
||||||
|
|||||||
@ -88,14 +88,17 @@ def _fused_moe_lora_kernel(
|
|||||||
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
|
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
|
||||||
|
|
||||||
# calculate pid_m,pid_n
|
# calculate pid_m,pid_n
|
||||||
|
pid_sk = pid % SPLIT_K
|
||||||
|
pid_m_n = pid // SPLIT_K
|
||||||
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
||||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||||
|
|
||||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||||
group_id = pid // num_pid_in_group
|
group_id = pid_m_n // num_pid_in_group
|
||||||
first_pid_m = group_id * GROUP_SIZE_M
|
first_pid_m = group_id * GROUP_SIZE_M
|
||||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
|
||||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
|
||||||
|
|
||||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
|
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
|
||||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||||
@ -113,7 +116,7 @@ def _fused_moe_lora_kernel(
|
|||||||
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
|
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
|
||||||
|
|
||||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
||||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||||
|
|
||||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||||
token_ind = stride_tl * lora_idx + offs_token_id
|
token_ind = stride_tl * lora_idx + offs_token_id
|
||||||
@ -131,7 +134,8 @@ def _fused_moe_lora_kernel(
|
|||||||
cur_b_ptr
|
cur_b_ptr
|
||||||
+ lora_idx * stride_bl
|
+ lora_idx * stride_bl
|
||||||
+ expert_id * stride_be
|
+ expert_id * stride_be
|
||||||
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
+ offs_k[:, None] * stride_bk
|
||||||
|
+ offs_bn[None, :] * stride_bn
|
||||||
)
|
)
|
||||||
|
|
||||||
# accumulator
|
# accumulator
|
||||||
|
|||||||
@ -56,6 +56,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
ep_size: int = 1,
|
ep_size: int = 1,
|
||||||
tp_rank: int = 0,
|
tp_rank: int = 0,
|
||||||
tp_size: int = 1,
|
tp_size: int = 1,
|
||||||
|
use_dp: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(quant_config)
|
super().__init__(quant_config)
|
||||||
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
|
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
|
||||||
@ -67,6 +68,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.out_dtype = out_dtype
|
self.out_dtype = out_dtype
|
||||||
|
self.use_dp = use_dp
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_formats(
|
def activation_formats(
|
||||||
@ -117,7 +119,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
"""
|
"""
|
||||||
workspace1 = (M, K)
|
workspace1 = (M, K)
|
||||||
workspace2 = (0,)
|
workspace2 = (0,)
|
||||||
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
|
# For TP, the quantization is fused with fused_moe call.
|
||||||
|
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
|
||||||
# The workspace is determined by `aq`, since it comes after any
|
# The workspace is determined by `aq`, since it comes after any
|
||||||
# potential communication op and is involved in the expert computation.
|
# potential communication op and is involved in the expert computation.
|
||||||
return (workspace1, workspace2, output_shape)
|
return (workspace1, workspace2, output_shape)
|
||||||
@ -214,6 +217,7 @@ def flashinfer_cutlass_moe_fp4(
|
|||||||
FlashInferExperts(
|
FlashInferExperts(
|
||||||
out_dtype=hidden_states.dtype,
|
out_dtype=hidden_states.dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
use_dp=False,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -170,6 +170,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
|||||||
self._apply_router_weight_on_input(
|
self._apply_router_weight_on_input(
|
||||||
a1, topk_weights, topk_ids, apply_router_weight_on_input
|
a1, topk_weights, topk_ids, apply_router_weight_on_input
|
||||||
)
|
)
|
||||||
|
if not self.use_dp:
|
||||||
|
return a1, None, None, topk_ids, topk_weights
|
||||||
|
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
a1,
|
a1,
|
||||||
@ -179,14 +181,13 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
|||||||
quant_config.block_shape,
|
quant_config.block_shape,
|
||||||
is_fp4_scale_swizzled=not self.use_dp,
|
is_fp4_scale_swizzled=not self.use_dp,
|
||||||
)
|
)
|
||||||
if self.use_dp:
|
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
||||||
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
[topk_weights, topk_ids, a1q, a1q_scale],
|
||||||
[topk_weights, topk_ids, a1q, a1q_scale],
|
dim=0,
|
||||||
dim=0,
|
sizes=get_local_sizes(),
|
||||||
sizes=get_local_sizes(),
|
)
|
||||||
)
|
if quant_config.quant_dtype == "nvfp4":
|
||||||
if quant_config.quant_dtype == "nvfp4":
|
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
|
||||||
|
|
||||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||||
|
|
||||||
|
|||||||
@ -672,8 +672,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
elif self.fused_experts is not None:
|
elif self.fused_experts is not None:
|
||||||
if self.moe.has_bias:
|
|
||||||
raise ValueError("FusedMoEModularKernel does not support bias.")
|
|
||||||
result = self.fused_experts(
|
result = self.fused_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
|
|||||||
@ -40,18 +40,36 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def kda_attention(
|
def kda_attention(
|
||||||
hidden_states: torch.Tensor,
|
q_proj_states: torch.Tensor,
|
||||||
output: torch.Tensor,
|
k_proj_states: torch.Tensor,
|
||||||
|
v_proj_states: torch.Tensor,
|
||||||
|
g1: torch.Tensor,
|
||||||
|
g2: torch.Tensor,
|
||||||
|
beta: torch.Tensor,
|
||||||
|
core_attn_out: torch.Tensor,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
self = forward_context.no_compile_layers[layer_name]
|
self = forward_context.no_compile_layers[layer_name]
|
||||||
self._forward(hidden_states=hidden_states, output=output)
|
self._forward(
|
||||||
|
q_proj_states=q_proj_states,
|
||||||
|
k_proj_states=k_proj_states,
|
||||||
|
v_proj_states=v_proj_states,
|
||||||
|
g1=g1,
|
||||||
|
g2=g2,
|
||||||
|
beta=beta,
|
||||||
|
core_attn_out=core_attn_out,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def kda_attention_fake(
|
def kda_attention_fake(
|
||||||
hidden_states: torch.Tensor,
|
q_proj_states: torch.Tensor,
|
||||||
output: torch.Tensor,
|
k_proj_states: torch.Tensor,
|
||||||
|
v_proj_states: torch.Tensor,
|
||||||
|
g1: torch.Tensor,
|
||||||
|
g2: torch.Tensor,
|
||||||
|
beta: torch.Tensor,
|
||||||
|
core_attn_out: torch.Tensor,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
return
|
return
|
||||||
@ -60,7 +78,7 @@ def kda_attention_fake(
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="kda_attention",
|
op_name="kda_attention",
|
||||||
op_func=kda_attention,
|
op_func=kda_attention,
|
||||||
mutates_args=["output"],
|
mutates_args=["core_attn_out"],
|
||||||
fake_impl=kda_attention_fake,
|
fake_impl=kda_attention_fake,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -242,36 +260,54 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
return torch.ops.vllm.kda_attention(
|
num_tokens = hidden_states.size(0)
|
||||||
hidden_states,
|
q = self.q_proj(hidden_states)[0]
|
||||||
output,
|
k = self.k_proj(hidden_states)[0]
|
||||||
|
v = self.v_proj(hidden_states)[0]
|
||||||
|
|
||||||
|
beta = self.b_proj(hidden_states)[0].float().sigmoid()
|
||||||
|
g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
|
||||||
|
g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
|
||||||
|
beta = beta.unsqueeze(0)
|
||||||
|
g1 = g1.unsqueeze(0)
|
||||||
|
|
||||||
|
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
|
||||||
|
g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
|
||||||
|
|
||||||
|
core_attn_out = torch.zeros(
|
||||||
|
(1, num_tokens, self.local_num_heads, self.head_dim),
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
torch.ops.vllm.kda_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
g1,
|
||||||
|
g2,
|
||||||
|
beta,
|
||||||
|
core_attn_out,
|
||||||
self.prefix,
|
self.prefix,
|
||||||
)
|
)
|
||||||
|
core_attn_out = self.o_norm(core_attn_out, g2)
|
||||||
|
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
|
||||||
|
output[:] = self.o_proj(core_attn_out)[0]
|
||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
q_proj_states: torch.Tensor,
|
||||||
output: torch.Tensor,
|
k_proj_states: torch.Tensor,
|
||||||
|
v_proj_states: torch.Tensor,
|
||||||
|
g1: torch.Tensor,
|
||||||
|
g2: torch.Tensor,
|
||||||
|
beta: torch.Tensor,
|
||||||
|
core_attn_out: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||||
|
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
# V1 profile run
|
# # V1 profile run
|
||||||
# Mimic the memory allocation in the real run
|
|
||||||
q = torch.empty_like(hidden_states)
|
|
||||||
k = torch.empty_like(hidden_states)
|
|
||||||
v = torch.empty_like(hidden_states)
|
|
||||||
g = hidden_states.new_empty(
|
|
||||||
hidden_states.size(0),
|
|
||||||
self.local_num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
beta = torch.empty(
|
|
||||||
hidden_states.size(0), self.local_num_heads, dtype=torch.float32
|
|
||||||
)
|
|
||||||
core_attn_out = torch.empty_like(hidden_states)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
assert isinstance(attn_metadata, dict)
|
assert isinstance(attn_metadata, dict)
|
||||||
@ -288,10 +324,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
|||||||
conv_state_k = conv_state_k.transpose(-1, -2)
|
conv_state_k = conv_state_k.transpose(-1, -2)
|
||||||
conv_state_v = conv_state_v.transpose(-1, -2)
|
conv_state_v = conv_state_v.transpose(-1, -2)
|
||||||
|
|
||||||
q_proj_states = self.q_proj(hidden_states)[0]
|
|
||||||
k_proj_states = self.k_proj(hidden_states)[0]
|
|
||||||
v_proj_states = self.v_proj(hidden_states)[0]
|
|
||||||
|
|
||||||
q_conv_weights = self.q_conv1d.weight.view(
|
q_conv_weights = self.q_conv1d.weight.view(
|
||||||
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
|
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
|
||||||
)
|
)
|
||||||
@ -374,14 +406,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
|||||||
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
|
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
|
||||||
)
|
)
|
||||||
|
|
||||||
beta = self.b_proj(hidden_states)[0].float().sigmoid()
|
|
||||||
|
|
||||||
g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
|
|
||||||
g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
|
|
||||||
|
|
||||||
beta = beta.unsqueeze(0)
|
|
||||||
g = g.unsqueeze(0)
|
|
||||||
|
|
||||||
if attn_metadata.num_prefills > 0:
|
if attn_metadata.num_prefills > 0:
|
||||||
zero_idx = non_spec_state_indices_tensor[~has_initial_state]
|
zero_idx = non_spec_state_indices_tensor[~has_initial_state]
|
||||||
recurrent_state[zero_idx] = 0
|
recurrent_state[zero_idx] = 0
|
||||||
@ -393,7 +417,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
|||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
g=g,
|
g=g1,
|
||||||
beta=beta,
|
beta=beta,
|
||||||
initial_state=initial_state,
|
initial_state=initial_state,
|
||||||
output_final_state=True,
|
output_final_state=True,
|
||||||
@ -410,17 +434,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
|||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
g=g,
|
g=g1,
|
||||||
beta=beta,
|
beta=beta,
|
||||||
initial_state=recurrent_state,
|
initial_state=recurrent_state,
|
||||||
use_qk_l2norm_in_kernel=True,
|
use_qk_l2norm_in_kernel=True,
|
||||||
cu_seqlens=non_spec_query_start_loc,
|
cu_seqlens=non_spec_query_start_loc,
|
||||||
ssm_state_indices=non_spec_state_indices_tensor,
|
ssm_state_indices=non_spec_state_indices_tensor,
|
||||||
)
|
)
|
||||||
|
assert core_attn_out_non_spec.shape == core_attn_out.shape
|
||||||
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
|
core_attn_out[:] = core_attn_out_non_spec
|
||||||
g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
|
|
||||||
core_attn_out = self.o_norm(core_attn_out_non_spec, g)
|
|
||||||
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
|
|
||||||
|
|
||||||
output[:] = self.o_proj(core_attn_out)[0]
|
|
||||||
|
|||||||
@ -1769,29 +1769,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
elif (
|
|
||||||
self.allow_flashinfer
|
|
||||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
|
||||||
):
|
|
||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
|
||||||
flashinfer_cutlass_moe_fp4,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert self.moe_quant_config is not None
|
|
||||||
|
|
||||||
return flashinfer_cutlass_moe_fp4(
|
|
||||||
hidden_states=x,
|
|
||||||
w1=layer.w13_weight,
|
|
||||||
w2=layer.w2_weight,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
quant_config=self.moe_quant_config,
|
|
||||||
inplace=False,
|
|
||||||
activation=activation,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=expert_map,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
|
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
|
||||||
# only (no EP).
|
# only (no EP).
|
||||||
|
|||||||
@ -79,6 +79,7 @@ def select_nvfp4_gemm_impl(
|
|||||||
ep_size=moe.moe_parallel_config.ep_size,
|
ep_size=moe.moe_parallel_config.ep_size,
|
||||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||||
tp_size=moe.moe_parallel_config.tp_size,
|
tp_size=moe.moe_parallel_config.tp_size,
|
||||||
|
use_dp=moe.moe_parallel_config.dp_size > 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# native cutlass experts currently don't support DP; TP case won't call this
|
# native cutlass experts currently don't support DP; TP case won't call this
|
||||||
|
|||||||
@ -26,6 +26,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
|
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
|
import itertools
|
||||||
import math
|
import math
|
||||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -36,7 +37,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from transformers import BatchFeature
|
from transformers import BatchFeature, PretrainedConfig
|
||||||
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
|
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
|
||||||
from transformers.models.glm4v.image_processing_glm4v import (
|
from transformers.models.glm4v.image_processing_glm4v import (
|
||||||
Glm4vImageProcessor,
|
Glm4vImageProcessor,
|
||||||
@ -89,6 +90,7 @@ from ..layers.activation import SiluAndMul
|
|||||||
from .interfaces import (
|
from .interfaces import (
|
||||||
MultiModalEmbeddings,
|
MultiModalEmbeddings,
|
||||||
SupportsLoRA,
|
SupportsLoRA,
|
||||||
|
SupportsMRoPE,
|
||||||
SupportsMultiModal,
|
SupportsMultiModal,
|
||||||
SupportsPP,
|
SupportsPP,
|
||||||
)
|
)
|
||||||
@ -1386,7 +1388,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
|
|||||||
dummy_inputs=Glm4vDummyInputsBuilder,
|
dummy_inputs=Glm4vDummyInputsBuilder,
|
||||||
)
|
)
|
||||||
class Glm4vForConditionalGeneration(
|
class Glm4vForConditionalGeneration(
|
||||||
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
|
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||||
):
|
):
|
||||||
merge_by_field_config = True
|
merge_by_field_config = True
|
||||||
|
|
||||||
@ -1613,6 +1615,149 @@ class Glm4vForConditionalGeneration(
|
|||||||
multimodal_embeddings += tuple(video_embeddings)
|
multimodal_embeddings += tuple(video_embeddings)
|
||||||
return multimodal_embeddings
|
return multimodal_embeddings
|
||||||
|
|
||||||
|
def get_mrope_input_positions(
|
||||||
|
self,
|
||||||
|
input_tokens: list[int],
|
||||||
|
hf_config: "PretrainedConfig",
|
||||||
|
image_grid_thw: list[list[int]] | torch.Tensor | None,
|
||||||
|
video_grid_thw: list[list[int]] | torch.Tensor | None,
|
||||||
|
second_per_grid_ts: list[float] | None = None,
|
||||||
|
context_len: int = 0,
|
||||||
|
seq_len: int | None = None,
|
||||||
|
audio_feature_lengths: torch.Tensor | None = None,
|
||||||
|
use_audio_in_video: bool = False,
|
||||||
|
) -> tuple[torch.Tensor, int]:
|
||||||
|
"""Get mrope input positions and delta value for GLM4V."""
|
||||||
|
|
||||||
|
image_token_id = hf_config.image_token_id
|
||||||
|
video_start_token_id = hf_config.video_start_token_id
|
||||||
|
video_end_token_id = hf_config.video_end_token_id
|
||||||
|
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||||
|
llm_pos_ids_list: list = []
|
||||||
|
|
||||||
|
if not (image_grid_thw is None and video_grid_thw is None):
|
||||||
|
if isinstance(image_grid_thw, torch.Tensor):
|
||||||
|
image_grid_thw = image_grid_thw.tolist()
|
||||||
|
|
||||||
|
input_token_type: list[str] = []
|
||||||
|
video_check_flg = False
|
||||||
|
for token in input_tokens:
|
||||||
|
if token == video_start_token_id:
|
||||||
|
video_check_flg = True
|
||||||
|
elif token == video_end_token_id:
|
||||||
|
video_check_flg = False
|
||||||
|
|
||||||
|
if (token == image_token_id) and (video_check_flg is False):
|
||||||
|
input_token_type.append("image")
|
||||||
|
elif (token == image_token_id) and (video_check_flg is True):
|
||||||
|
input_token_type.append("video")
|
||||||
|
else:
|
||||||
|
input_token_type.append("text")
|
||||||
|
|
||||||
|
input_type_group: list[tuple[str, int, int]] = []
|
||||||
|
for key, group_iter in itertools.groupby(
|
||||||
|
enumerate(input_token_type), lambda x: x[1]
|
||||||
|
):
|
||||||
|
group_list = list(group_iter)
|
||||||
|
start_index = group_list[0][0]
|
||||||
|
end_index = group_list[-1][0] + 1
|
||||||
|
input_type_group.append((key, start_index, end_index))
|
||||||
|
|
||||||
|
video_frame_num = 1
|
||||||
|
mm_data_idx = 0
|
||||||
|
for modality_type, start_idx, end_idx in input_type_group:
|
||||||
|
st_idx = (
|
||||||
|
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
|
)
|
||||||
|
if modality_type == "image":
|
||||||
|
t, h, w = (
|
||||||
|
image_grid_thw[mm_data_idx][0],
|
||||||
|
image_grid_thw[mm_data_idx][1],
|
||||||
|
image_grid_thw[mm_data_idx][2],
|
||||||
|
)
|
||||||
|
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||||
|
t,
|
||||||
|
h // spatial_merge_size,
|
||||||
|
w // spatial_merge_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
t_index = (
|
||||||
|
torch.arange(llm_grid_t)
|
||||||
|
.view(-1, 1)
|
||||||
|
.expand(-1, llm_grid_h * llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
h_index = (
|
||||||
|
torch.arange(llm_grid_h)
|
||||||
|
.view(1, -1, 1)
|
||||||
|
.expand(llm_grid_t, -1, llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
w_index = (
|
||||||
|
torch.arange(llm_grid_w)
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.expand(llm_grid_t, llm_grid_h, -1)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||||
|
)
|
||||||
|
mm_data_idx += 1
|
||||||
|
|
||||||
|
elif modality_type == "video":
|
||||||
|
t, h, w = (
|
||||||
|
video_frame_num,
|
||||||
|
image_grid_thw[mm_data_idx][1],
|
||||||
|
image_grid_thw[mm_data_idx][2],
|
||||||
|
)
|
||||||
|
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||||
|
t,
|
||||||
|
h // spatial_merge_size,
|
||||||
|
w // spatial_merge_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
for t_idx in range(llm_grid_t):
|
||||||
|
t_index = (
|
||||||
|
torch.tensor(t_idx)
|
||||||
|
.view(-1, 1)
|
||||||
|
.expand(-1, llm_grid_h * llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
h_index = (
|
||||||
|
torch.arange(llm_grid_h)
|
||||||
|
.view(1, -1, 1)
|
||||||
|
.expand(1, -1, llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
w_index = (
|
||||||
|
torch.arange(llm_grid_w)
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.expand(1, llm_grid_h, -1)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
mm_data_idx += 1
|
||||||
|
video_frame_num += 1
|
||||||
|
|
||||||
|
else:
|
||||||
|
text_len = end_idx - start_idx
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||||
|
)
|
||||||
|
video_frame_num = 1
|
||||||
|
|
||||||
|
else:
|
||||||
|
text_len = len(input_tokens)
|
||||||
|
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
|
||||||
|
|
||||||
|
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||||
|
llm_positions = llm_positions[:, context_len:seq_len]
|
||||||
|
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||||
|
return llm_positions, mrope_position_delta
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@ -17,7 +17,9 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
|
|||||||
from transformers.utils import torch_int
|
from transformers.utils import torch_int
|
||||||
|
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import (
|
||||||
|
maybe_get_vit_flash_attn_backend,
|
||||||
|
)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
@ -56,12 +58,14 @@ from vllm.multimodal.processing import (
|
|||||||
PromptUpdate,
|
PromptUpdate,
|
||||||
)
|
)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (
|
from .interfaces import (
|
||||||
MultiModalEmbeddings,
|
MultiModalEmbeddings,
|
||||||
SupportsLoRA,
|
SupportsLoRA,
|
||||||
|
SupportsMRoPE,
|
||||||
SupportsMultiModal,
|
SupportsMultiModal,
|
||||||
SupportsPP,
|
SupportsPP,
|
||||||
)
|
)
|
||||||
@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt(
|
|||||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||||
|
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
if current_platform.is_cuda():
|
||||||
|
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||||
|
elif current_platform.is_rocm():
|
||||||
|
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
|
||||||
|
|
||||||
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
|
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
|
||||||
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
|
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
|
||||||
@ -398,18 +405,28 @@ class KeyeSiglipAttention(nn.Module):
|
|||||||
attn_backend_override=attn_backend_override,
|
attn_backend_override=attn_backend_override,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.use_upstream_fa = False
|
self.attn_backend, self.flash_attn_varlen_func = (
|
||||||
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
maybe_get_vit_flash_attn_backend(
|
||||||
torch.get_default_dtype()
|
self.attn_backend,
|
||||||
):
|
use_upstream_fa=False,
|
||||||
self.attn_backend = _Backend.FLASH_ATTN
|
attn_backend_override=attn_backend_override,
|
||||||
self.use_upstream_fa = True
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
|
if self.attn_backend not in {
|
||||||
|
_Backend.FLASH_ATTN,
|
||||||
|
_Backend.XFORMERS,
|
||||||
|
_Backend.ROCM_AITER_FA,
|
||||||
|
}:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Keye-VL does not support {self.attn_backend} backend now."
|
f"Keye-VL does not support {self.attn_backend} backend now."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.is_flash_attn_backend = self.attn_backend in {
|
||||||
|
_Backend.FLASH_ATTN,
|
||||||
|
_Backend.ROCM_AITER_FA,
|
||||||
|
}
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@ -457,15 +474,10 @@ class KeyeSiglipAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
if self.is_flash_attn_backend:
|
||||||
if self.use_upstream_fa:
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
|
||||||
output = flash_attn_varlen_func(
|
output = self.flash_attn_varlen_func(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
@ -1542,7 +1554,7 @@ class BaseKeyeModule(nn.Module):
|
|||||||
dummy_inputs=KeyeDummyInputsBuilder,
|
dummy_inputs=KeyeDummyInputsBuilder,
|
||||||
)
|
)
|
||||||
class KeyeForConditionalGeneration(
|
class KeyeForConditionalGeneration(
|
||||||
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
|
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||||
):
|
):
|
||||||
def _build_projector(
|
def _build_projector(
|
||||||
self,
|
self,
|
||||||
@ -1611,3 +1623,142 @@ class KeyeForConditionalGeneration(
|
|||||||
return tuple(
|
return tuple(
|
||||||
self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
|
self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_mrope_input_positions(
|
||||||
|
self,
|
||||||
|
input_tokens: list[int],
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
|
image_grid_thw: list[list[int]] | torch.Tensor,
|
||||||
|
video_grid_thw: list[list[int]] | torch.Tensor,
|
||||||
|
context_len: int = 0,
|
||||||
|
seq_len: int | None = None,
|
||||||
|
second_per_grid_ts: list[float] | None = None,
|
||||||
|
audio_feature_lengths: torch.Tensor | None = None,
|
||||||
|
use_audio_in_video: bool = False,
|
||||||
|
) -> tuple[torch.Tensor, int]:
|
||||||
|
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
|
||||||
|
video_grid_thw = video_grid_thw[0]
|
||||||
|
"""Get mrope input positions and delta value (Keye series)."""
|
||||||
|
|
||||||
|
def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
|
||||||
|
"""
|
||||||
|
Split grid_thw along the t dimension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of [1, h, w] rows, repeated t times for each original row.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(grid_thw, list):
|
||||||
|
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
|
||||||
|
|
||||||
|
if grid_thw.numel() == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
|
||||||
|
ones = torch.ones_like(hw[:, :1]) # [N,1]
|
||||||
|
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
|
||||||
|
return out.tolist()
|
||||||
|
|
||||||
|
video_grid_thw = split_thw(video_grid_thw)
|
||||||
|
|
||||||
|
image_token_id = hf_config.image_token_id
|
||||||
|
video_token_id = hf_config.video_token_id
|
||||||
|
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||||
|
|
||||||
|
image_nums = len(image_grid_thw)
|
||||||
|
frame_nums = len(video_grid_thw)
|
||||||
|
llm_pos_ids_list: list = []
|
||||||
|
|
||||||
|
st = 0
|
||||||
|
remain_images, remain_frames = image_nums, frame_nums
|
||||||
|
|
||||||
|
image_index, video_index = 0, 0
|
||||||
|
for _ in range(image_nums + frame_nums):
|
||||||
|
if remain_images > 0:
|
||||||
|
try:
|
||||||
|
ed_image = input_tokens.index(image_token_id, st)
|
||||||
|
except ValueError:
|
||||||
|
ed_image = len(input_tokens) + 1
|
||||||
|
else:
|
||||||
|
ed_image = len(input_tokens) + 1
|
||||||
|
if remain_frames > 0:
|
||||||
|
try:
|
||||||
|
ed_video = input_tokens.index(video_token_id, st)
|
||||||
|
except ValueError:
|
||||||
|
ed_video = len(input_tokens) + 1
|
||||||
|
else:
|
||||||
|
ed_video = len(input_tokens) + 1
|
||||||
|
|
||||||
|
if ed_image < ed_video:
|
||||||
|
t, h, w = (
|
||||||
|
image_grid_thw[image_index][0],
|
||||||
|
image_grid_thw[image_index][1],
|
||||||
|
image_grid_thw[image_index][2],
|
||||||
|
)
|
||||||
|
image_index += 1
|
||||||
|
remain_images -= 1
|
||||||
|
ed = ed_image
|
||||||
|
else:
|
||||||
|
t, h, w = (
|
||||||
|
video_grid_thw[video_index][0],
|
||||||
|
video_grid_thw[video_index][1],
|
||||||
|
video_grid_thw[video_index][2],
|
||||||
|
)
|
||||||
|
video_index += 1
|
||||||
|
remain_frames -= 1
|
||||||
|
ed = ed_video
|
||||||
|
|
||||||
|
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||||
|
t,
|
||||||
|
h // spatial_merge_size,
|
||||||
|
w // spatial_merge_size,
|
||||||
|
)
|
||||||
|
text_len = ed - st
|
||||||
|
|
||||||
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
t_index = (
|
||||||
|
(
|
||||||
|
torch.arange(llm_grid_t)
|
||||||
|
.view(-1, 1)
|
||||||
|
.expand(-1, llm_grid_h * llm_grid_w)
|
||||||
|
)
|
||||||
|
.long()
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
|
||||||
|
h_index = (
|
||||||
|
torch.arange(llm_grid_h)
|
||||||
|
.view(1, -1, 1)
|
||||||
|
.expand(llm_grid_t, -1, llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
w_index = (
|
||||||
|
torch.arange(llm_grid_w)
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.expand(llm_grid_t, llm_grid_h, -1)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
||||||
|
)
|
||||||
|
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||||
|
|
||||||
|
if st < len(input_tokens):
|
||||||
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
|
text_len = len(input_tokens) - st
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||||
|
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||||
|
llm_positions = llm_positions[:, context_len:seq_len]
|
||||||
|
|
||||||
|
return llm_positions, mrope_position_delta
|
||||||
|
|||||||
@ -22,7 +22,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
|||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
ColumnParallelLinear,
|
ColumnParallelLinear,
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
@ -61,7 +60,7 @@ class KimiMLP(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
hidden_act: str,
|
hidden_act: str,
|
||||||
quant_config: QKVParallelLinear | None = None,
|
quant_config: QuantizationConfig | None = None,
|
||||||
reduce_results: bool = True,
|
reduce_results: bool = True,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -155,6 +154,7 @@ class KimiMoE(nn.Module):
|
|||||||
hidden_act=config.hidden_act,
|
hidden_act=config.hidden_act,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
|
prefix=f"{prefix}.shared_experts",
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
@ -340,7 +340,7 @@ class KimiDecoderLayer(nn.Module):
|
|||||||
self.block_sparse_moe = KimiMoE(
|
self.block_sparse_moe = KimiMoE(
|
||||||
config=config,
|
config=config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.mlp",
|
prefix=f"{prefix}.block_sparse_moe",
|
||||||
)
|
)
|
||||||
self.mlp = self.block_sparse_moe
|
self.mlp = self.block_sparse_moe
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -49,7 +49,7 @@ from functools import cached_property
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers.activations import ACT2FN, PytorchGELUTanh
|
from transformers.activations import ACT2FN
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from transformers.utils import is_flash_attn_2_available
|
from transformers.utils import is_flash_attn_2_available
|
||||||
|
|
||||||
@ -651,7 +651,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
|
|||||||
"num_heads": config.num_attention_heads,
|
"num_heads": config.num_attention_heads,
|
||||||
"hidden_dim": config.hidden_size,
|
"hidden_dim": config.hidden_size,
|
||||||
"mlp_dim": config.intermediate_size,
|
"mlp_dim": config.intermediate_size,
|
||||||
"activation": PytorchGELUTanh(),
|
"activation": ACT2FN["gelu_pytorch_tanh"],
|
||||||
"attn_bias": True,
|
"attn_bias": True,
|
||||||
"attn_implementation": config._attn_implementation,
|
"attn_implementation": config._attn_implementation,
|
||||||
},
|
},
|
||||||
|
|||||||
@ -364,6 +364,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
|
|
||||||
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
|
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
|
||||||
self.use_upstream_fa = True
|
self.use_upstream_fa = True
|
||||||
|
if current_platform.is_xpu():
|
||||||
|
self.use_upstream_fa = False
|
||||||
self.is_flash_attn_backend = self.attn_backend in {
|
self.is_flash_attn_backend = self.attn_backend in {
|
||||||
_Backend.FLASH_ATTN,
|
_Backend.FLASH_ATTN,
|
||||||
_Backend.ROCM_AITER_FA,
|
_Backend.ROCM_AITER_FA,
|
||||||
@ -856,10 +858,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
||||||
seqlens = torch.zeros(1, device=cu_seqlens.device)
|
seqlens = torch.zeros(1, device=cu_seqlens.device)
|
||||||
if (
|
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
|
||||||
self.attn_backend == _Backend.FLASH_ATTN
|
|
||||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
|
||||||
):
|
|
||||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||||
elif self.attn_backend == _Backend.XFORMERS:
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||||
|
|||||||
@ -34,7 +34,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from transformers import AutoConfig, BatchFeature, PretrainedConfig
|
from transformers import BatchFeature, PretrainedConfig
|
||||||
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
|
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
|
||||||
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
|
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
|
||||||
Qwen2VLConfig,
|
Qwen2VLConfig,
|
||||||
@ -789,10 +789,7 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
self, cu_seqlens: torch.Tensor
|
self, cu_seqlens: torch.Tensor
|
||||||
) -> tuple[int | None, list[int] | None]:
|
) -> tuple[int | None, list[int] | None]:
|
||||||
max_seqlen, seqlens = None, None
|
max_seqlen, seqlens = None, None
|
||||||
if (
|
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
|
||||||
self.attn_backend == _Backend.FLASH_ATTN
|
|
||||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
|
||||||
):
|
|
||||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||||
elif self.attn_backend == _Backend.XFORMERS:
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
@ -1654,9 +1651,7 @@ class Tarsier2Processor(Qwen2VLProcessor):
|
|||||||
class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
|
class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
|
||||||
def get_hf_config(self) -> Qwen2VLConfig:
|
def get_hf_config(self) -> Qwen2VLConfig:
|
||||||
model_path = self.ctx.model_config.model
|
model_path = self.ctx.model_config.model
|
||||||
original_config = AutoConfig.from_pretrained(model_path)
|
correct_config = Qwen2VLConfig.from_pretrained(model_path)
|
||||||
config_dict = original_config.to_dict()
|
|
||||||
correct_config = Qwen2VLConfig.from_dict(config_dict)
|
|
||||||
|
|
||||||
return correct_config
|
return correct_config
|
||||||
|
|
||||||
|
|||||||
@ -115,6 +115,12 @@ class XPUPlatform(Platform):
|
|||||||
device_props = torch.xpu.get_device_properties(device_id)
|
device_props = torch.xpu.get_device_properties(device_id)
|
||||||
return device_props.total_memory
|
return device_props.total_memory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
|
|
||||||
|
return _Backend.FLASH_ATTN
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def inference_mode(cls):
|
def inference_mode(cls):
|
||||||
return torch.no_grad()
|
return torch.no_grad()
|
||||||
|
|||||||
@ -896,6 +896,8 @@ def get_kernel_options(
|
|||||||
return kernel_options
|
return kernel_options
|
||||||
else:
|
else:
|
||||||
preferred_block = 32 if query.dtype == torch.float32 else 64
|
preferred_block = 32 if query.dtype == torch.float32 else 64
|
||||||
|
block_lower_bound = 16
|
||||||
|
|
||||||
block_m_candidate = ensure_divisible(preferred_block, block_m)
|
block_m_candidate = ensure_divisible(preferred_block, block_m)
|
||||||
block_n_candidate = ensure_divisible(preferred_block, block_n)
|
block_n_candidate = ensure_divisible(preferred_block, block_n)
|
||||||
|
|
||||||
@ -910,6 +912,9 @@ def get_kernel_options(
|
|||||||
max(1, block_n_candidate // 2), block_n
|
max(1, block_n_candidate // 2), block_n
|
||||||
)
|
)
|
||||||
|
|
||||||
|
block_m_candidate = max(block_m_candidate, block_lower_bound)
|
||||||
|
block_n_candidate = max(block_n_candidate, block_lower_bound)
|
||||||
|
|
||||||
kernel_options["BLOCK_M"] = block_m_candidate
|
kernel_options["BLOCK_M"] = block_m_candidate
|
||||||
kernel_options["BLOCK_N"] = block_n_candidate
|
kernel_options["BLOCK_N"] = block_n_candidate
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from typing import ClassVar
|
|||||||
import torch
|
import torch
|
||||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.attention.backends.mla.common import (
|
from vllm.v1.attention.backends.mla.common import (
|
||||||
MLACommonBackend,
|
MLACommonBackend,
|
||||||
@ -40,6 +40,10 @@ class FlashInferMLABackend(MLACommonBackend):
|
|||||||
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
|
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
|
||||||
return FlashInferMLAMetadataBuilder
|
return FlashInferMLAMetadataBuilder
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
|
||||||
|
return [32, 64]
|
||||||
|
|
||||||
|
|
||||||
g_fi_workspace = torch.zeros(
|
g_fi_workspace = torch.zeros(
|
||||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
|
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
|
||||||
|
|||||||
@ -15,8 +15,12 @@ class AsyncScheduler(Scheduler):
|
|||||||
scheduler_output: SchedulerOutput,
|
scheduler_output: SchedulerOutput,
|
||||||
) -> None:
|
) -> None:
|
||||||
super()._update_after_schedule(scheduler_output)
|
super()._update_after_schedule(scheduler_output)
|
||||||
|
pending_structured_output_tokens = False
|
||||||
for req_id in scheduler_output.num_scheduled_tokens:
|
for req_id in scheduler_output.num_scheduled_tokens:
|
||||||
request = self.requests[req_id]
|
request = self.requests[req_id]
|
||||||
|
pending_structured_output_tokens |= (
|
||||||
|
request.use_structured_output and request.num_output_placeholders > 0
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
request.num_computed_tokens
|
request.num_computed_tokens
|
||||||
== request.num_tokens + request.num_output_placeholders
|
== request.num_tokens + request.num_output_placeholders
|
||||||
@ -25,6 +29,10 @@ class AsyncScheduler(Scheduler):
|
|||||||
# TODO(woosuk): Support speculative decoding.
|
# TODO(woosuk): Support speculative decoding.
|
||||||
request.num_output_placeholders += 1
|
request.num_output_placeholders += 1
|
||||||
|
|
||||||
|
scheduler_output.pending_structured_output_tokens = (
|
||||||
|
pending_structured_output_tokens
|
||||||
|
)
|
||||||
|
|
||||||
def _update_request_with_output(
|
def _update_request_with_output(
|
||||||
self,
|
self,
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
from vllm.v1.engine import EngineCoreOutputs
|
from vllm.v1.engine import EngineCoreOutputs
|
||||||
from vllm.v1.metrics.stats import SchedulerStats
|
from vllm.v1.metrics.stats import SchedulerStats
|
||||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||||
@ -40,6 +40,12 @@ class SchedulerInterface(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_grammar_bitmask(
|
||||||
|
self, scheduler_output: "SchedulerOutput"
|
||||||
|
) -> "GrammarOutput | None":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update_from_output(
|
def update_from_output(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -181,12 +181,17 @@ class SchedulerOutput:
|
|||||||
# freed from the encoder cache.
|
# freed from the encoder cache.
|
||||||
free_encoder_mm_hashes: list[str]
|
free_encoder_mm_hashes: list[str]
|
||||||
|
|
||||||
# ids of structured outputs requests included in the bitmask, in the
|
# Whether the scheduled requests have all the output tokens they
|
||||||
# same order as the corresponding stacked rows of the bitmask.
|
# need to perform grammar bitmask computation.
|
||||||
# There may be more than one row per request in the case of speculative decoding.
|
pending_structured_output_tokens: bool = False
|
||||||
structured_output_request_ids: list[str]
|
|
||||||
# the bitmask for the whole batch
|
|
||||||
grammar_bitmask: "npt.NDArray[np.int32] | None"
|
|
||||||
|
|
||||||
# KV Cache Connector metadata.
|
# KV Cache Connector metadata.
|
||||||
kv_connector_metadata: KVConnectorMetadata | None = None
|
kv_connector_metadata: KVConnectorMetadata | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GrammarOutput:
|
||||||
|
# ids of structured output requests.
|
||||||
|
structured_output_request_ids: list[str]
|
||||||
|
# Bitmask ordered as structured_output_request_ids.
|
||||||
|
grammar_bitmask: "npt.NDArray[np.int32]"
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import itertools
|
|||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import Any
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
||||||
@ -24,7 +24,12 @@ from vllm.v1.core.encoder_cache_manager import (
|
|||||||
)
|
)
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
||||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||||
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
|
from vllm.v1.core.sched.output import (
|
||||||
|
CachedRequestData,
|
||||||
|
GrammarOutput,
|
||||||
|
NewRequestData,
|
||||||
|
SchedulerOutput,
|
||||||
|
)
|
||||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
||||||
@ -35,10 +40,6 @@ from vllm.v1.request import Request, RequestStatus
|
|||||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||||
from vllm.v1.structured_output import StructuredOutputManager
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
import numpy as np
|
|
||||||
import numpy.typing as npt
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -619,9 +620,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
scheduled_spec_decode_tokens,
|
scheduled_spec_decode_tokens,
|
||||||
req_to_new_blocks,
|
req_to_new_blocks,
|
||||||
)
|
)
|
||||||
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
|
|
||||||
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
# Record the request ids that were scheduled in this step.
|
# Record the request ids that were scheduled in this step.
|
||||||
self.prev_step_scheduled_req_ids.clear()
|
self.prev_step_scheduled_req_ids.clear()
|
||||||
@ -641,8 +639,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
# the previous and the current steps.
|
# the previous and the current steps.
|
||||||
finished_req_ids=self.finished_req_ids,
|
finished_req_ids=self.finished_req_ids,
|
||||||
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||||
structured_output_request_ids=structured_output_request_ids,
|
|
||||||
grammar_bitmask=grammar_bitmask,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||||
@ -872,9 +868,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
def get_grammar_bitmask(
|
def get_grammar_bitmask(
|
||||||
self,
|
self,
|
||||||
scheduled_request_ids: Iterable[str],
|
scheduler_output: SchedulerOutput,
|
||||||
scheduled_spec_decode_tokens: dict[str, list[int]],
|
) -> GrammarOutput | None:
|
||||||
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
|
|
||||||
# Collect list of scheduled request ids that use structured output.
|
# Collect list of scheduled request ids that use structured output.
|
||||||
# The corresponding rows of the bitmask will be in this order.
|
# The corresponding rows of the bitmask will be in this order.
|
||||||
# PERF: in case of chunked prefill,
|
# PERF: in case of chunked prefill,
|
||||||
@ -883,18 +878,18 @@ class Scheduler(SchedulerInterface):
|
|||||||
# cycle to fill in the bitmask, which could be a big no-op.
|
# cycle to fill in the bitmask, which could be a big no-op.
|
||||||
structured_output_request_ids = [
|
structured_output_request_ids = [
|
||||||
req_id
|
req_id
|
||||||
for req_id in scheduled_request_ids
|
for req_id in scheduler_output.num_scheduled_tokens
|
||||||
if (req := self.requests.get(req_id)) and req.use_structured_output
|
if (req := self.requests.get(req_id)) and req.use_structured_output
|
||||||
]
|
]
|
||||||
if not structured_output_request_ids:
|
if not structured_output_request_ids:
|
||||||
return structured_output_request_ids, None
|
return None
|
||||||
|
|
||||||
bitmask = self.structured_output_manager.grammar_bitmask(
|
bitmask = self.structured_output_manager.grammar_bitmask(
|
||||||
self.requests,
|
self.requests,
|
||||||
structured_output_request_ids,
|
structured_output_request_ids,
|
||||||
scheduled_spec_decode_tokens,
|
scheduler_output.scheduled_spec_decode_tokens,
|
||||||
)
|
)
|
||||||
return structured_output_request_ids, bitmask
|
return GrammarOutput(structured_output_request_ids, bitmask)
|
||||||
|
|
||||||
def update_from_output(
|
def update_from_output(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from concurrent.futures import Future
|
|||||||
from contextlib import ExitStack, contextmanager
|
from contextlib import ExitStack, contextmanager
|
||||||
from inspect import isclass, signature
|
from inspect import isclass, signature
|
||||||
from logging import DEBUG
|
from logging import DEBUG
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar, cast
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
import zmq
|
import zmq
|
||||||
@ -163,6 +163,27 @@ class EngineCore:
|
|||||||
vllm_config, mm_registry
|
vllm_config, mm_registry
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If a KV connector is initialized for scheduler, we want to collect
|
||||||
|
# handshake metadata from all workers so the connector in the scheduler
|
||||||
|
# will have the full context
|
||||||
|
kv_connector = self.scheduler.get_kv_connector()
|
||||||
|
if kv_connector is not None:
|
||||||
|
# Collect and store KV connector xfer metadata from workers
|
||||||
|
# (after KV cache registration)
|
||||||
|
xfer_handshake_metadata = (
|
||||||
|
self.model_executor.get_kv_connector_handshake_metadata()
|
||||||
|
)
|
||||||
|
|
||||||
|
if xfer_handshake_metadata:
|
||||||
|
# xfer_handshake_metadata is list of dicts from workers
|
||||||
|
# Each dict already has structure {tp_rank: metadata}
|
||||||
|
# Merge all worker dicts into a single dict
|
||||||
|
content: dict[int, Any] = {}
|
||||||
|
for worker_dict in xfer_handshake_metadata:
|
||||||
|
if worker_dict is not None:
|
||||||
|
content.update(worker_dict)
|
||||||
|
kv_connector.set_xfer_handshake_metadata(content)
|
||||||
|
|
||||||
# Setup batch queue for pipeline parallelism.
|
# Setup batch queue for pipeline parallelism.
|
||||||
# Batch queue for scheduled batches. This enables us to asynchronously
|
# Batch queue for scheduled batches. This enables us to asynchronously
|
||||||
# schedule and execute batches, and is required by pipeline parallelism
|
# schedule and execute batches, and is required by pipeline parallelism
|
||||||
@ -178,7 +199,7 @@ class EngineCore:
|
|||||||
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
|
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
|
||||||
if (
|
if (
|
||||||
self.vllm_config.cache_config.enable_prefix_caching
|
self.vllm_config.cache_config.enable_prefix_caching
|
||||||
or self.scheduler.get_kv_connector() is not None
|
or kv_connector is not None
|
||||||
):
|
):
|
||||||
caching_hash_fn = get_hash_fn_by_name(
|
caching_hash_fn = get_hash_fn_by_name(
|
||||||
vllm_config.cache_config.prefix_caching_hash_algo
|
vllm_config.cache_config.prefix_caching_hash_algo
|
||||||
@ -313,9 +334,12 @@ class EngineCore:
|
|||||||
if not self.scheduler.has_requests():
|
if not self.scheduler.has_requests():
|
||||||
return {}, False
|
return {}, False
|
||||||
scheduler_output = self.scheduler.schedule()
|
scheduler_output = self.scheduler.schedule()
|
||||||
|
future = self.model_executor.execute_model(scheduler_output, non_block=True)
|
||||||
|
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
|
||||||
with self.log_error_detail(scheduler_output):
|
with self.log_error_detail(scheduler_output):
|
||||||
model_output = self.model_executor.execute_model(scheduler_output)
|
model_output = future.result()
|
||||||
|
if model_output is None:
|
||||||
|
model_output = self.model_executor.sample_tokens(grammar_output)
|
||||||
|
|
||||||
engine_core_outputs = self.scheduler.update_from_output(
|
engine_core_outputs = self.scheduler.update_from_output(
|
||||||
scheduler_output, model_output
|
scheduler_output, model_output
|
||||||
@ -355,20 +379,47 @@ class EngineCore:
|
|||||||
assert len(batch_queue) < self.batch_queue_size
|
assert len(batch_queue) < self.batch_queue_size
|
||||||
|
|
||||||
model_executed = False
|
model_executed = False
|
||||||
|
deferred_scheduler_output = None
|
||||||
if self.scheduler.has_requests():
|
if self.scheduler.has_requests():
|
||||||
scheduler_output = self.scheduler.schedule()
|
scheduler_output = self.scheduler.schedule()
|
||||||
future = self.model_executor.execute_model(scheduler_output, non_block=True)
|
exec_future = self.model_executor.execute_model(
|
||||||
batch_queue.appendleft((future, scheduler_output))
|
scheduler_output, non_block=True
|
||||||
|
)
|
||||||
model_executed = scheduler_output.total_num_scheduled_tokens > 0
|
model_executed = scheduler_output.total_num_scheduled_tokens > 0
|
||||||
if (
|
|
||||||
model_executed
|
if scheduler_output.pending_structured_output_tokens:
|
||||||
and len(batch_queue) < self.batch_queue_size
|
# We need to defer sampling until we have processed the model output
|
||||||
and not batch_queue[-1][0].done()
|
# from the prior step.
|
||||||
):
|
deferred_scheduler_output = scheduler_output
|
||||||
# Don't block on next worker response unless the queue is full
|
# Block-wait for execute to return (continues running async on the GPU).
|
||||||
# or there are no more requests to schedule.
|
with self.log_error_detail(scheduler_output):
|
||||||
return None, True
|
exec_result = exec_future.result()
|
||||||
|
assert exec_result is None
|
||||||
|
else:
|
||||||
|
# We aren't waiting for any tokens, get any grammar output immediately.
|
||||||
|
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
|
||||||
|
# Block-wait for execute to return (continues running async on the GPU).
|
||||||
|
with self.log_error_detail(scheduler_output):
|
||||||
|
exec_result = exec_future.result()
|
||||||
|
|
||||||
|
if exec_result is None:
|
||||||
|
# Call sample tokens.
|
||||||
|
future = self.model_executor.sample_tokens(
|
||||||
|
grammar_output, non_block=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No sampling required (e.g. all requests finished).
|
||||||
|
future = cast(Future[ModelRunnerOutput], exec_future)
|
||||||
|
# Add this step's future to the queue.
|
||||||
|
batch_queue.appendleft((future, scheduler_output))
|
||||||
|
if (
|
||||||
|
model_executed
|
||||||
|
and len(batch_queue) < self.batch_queue_size
|
||||||
|
and not batch_queue[-1][0].done()
|
||||||
|
):
|
||||||
|
# Don't block on next worker response unless the queue is full
|
||||||
|
# or there are no more requests to schedule.
|
||||||
|
return None, True
|
||||||
|
|
||||||
elif not batch_queue:
|
elif not batch_queue:
|
||||||
# Queue is empty. We should not reach here since this method should
|
# Queue is empty. We should not reach here since this method should
|
||||||
@ -384,6 +435,19 @@ class EngineCore:
|
|||||||
engine_core_outputs = self.scheduler.update_from_output(
|
engine_core_outputs = self.scheduler.update_from_output(
|
||||||
scheduler_output, model_output
|
scheduler_output, model_output
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# NOTE(nick): We can either handle the deferred tasks here or save
|
||||||
|
# in a field and do it immediately once step_with_batch_queue is
|
||||||
|
# re-called. The latter slightly favors TTFT over TPOT/throughput.
|
||||||
|
if deferred_scheduler_output:
|
||||||
|
# We now have the tokens needed to compute the bitmask for the
|
||||||
|
# deferred request. Get the bitmask and call sample tokens.
|
||||||
|
grammar_output = self.scheduler.get_grammar_bitmask(
|
||||||
|
deferred_scheduler_output
|
||||||
|
)
|
||||||
|
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
|
||||||
|
batch_queue.appendleft((future, deferred_scheduler_output))
|
||||||
|
|
||||||
return engine_core_outputs, model_executed
|
return engine_core_outputs, model_executed
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
|
|||||||
@ -9,11 +9,14 @@ from typing import TYPE_CHECKING, Literal, TypeVar, overload
|
|||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
|
KVConnectorHandshakeMetadata,
|
||||||
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
from vllm.v1.engine import ReconfigureDistributedRequest
|
from vllm.v1.engine import ReconfigureDistributedRequest
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||||
@ -177,30 +180,51 @@ class Executor(ABC):
|
|||||||
):
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_kv_connector_handshake_metadata(
|
||||||
|
self,
|
||||||
|
) -> list[dict[int, KVConnectorHandshakeMetadata]]:
|
||||||
|
return self.collective_rpc("get_kv_connector_handshake_metadata")
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False
|
||||||
scheduler_output: SchedulerOutput,
|
) -> ModelRunnerOutput | None:
|
||||||
non_block: Literal[False] = False,
|
|
||||||
) -> ModelRunnerOutput:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True
|
||||||
scheduler_output: SchedulerOutput,
|
) -> Future[ModelRunnerOutput | None]:
|
||||||
non_block: Literal[True] = True,
|
|
||||||
) -> Future[ModelRunnerOutput]:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self, scheduler_output: SchedulerOutput, non_block: bool = False
|
self, scheduler_output: SchedulerOutput, non_block: bool = False
|
||||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||||
output = self.collective_rpc( # type: ignore[call-overload]
|
output = self.collective_rpc( # type: ignore[call-overload]
|
||||||
"execute_model", args=(scheduler_output,), non_block=non_block
|
"execute_model", args=(scheduler_output,), non_block=non_block
|
||||||
)
|
)
|
||||||
return output[0]
|
return output[0]
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def sample_tokens(
|
||||||
|
self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False
|
||||||
|
) -> ModelRunnerOutput:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def sample_tokens(
|
||||||
|
self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True
|
||||||
|
) -> Future[ModelRunnerOutput]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sample_tokens(
|
||||||
|
self, grammar_output: GrammarOutput | None, non_block: bool = False
|
||||||
|
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||||
|
output = self.collective_rpc( # type: ignore[call-overload]
|
||||||
|
"sample_tokens", args=(grammar_output,), non_block=non_block
|
||||||
|
)
|
||||||
|
return output[0]
|
||||||
|
|
||||||
def execute_dummy_batch(self) -> None:
|
def execute_dummy_batch(self) -> None:
|
||||||
self.collective_rpc("execute_dummy_batch")
|
self.collective_rpc("execute_dummy_batch")
|
||||||
|
|
||||||
|
|||||||
@ -46,7 +46,7 @@ from vllm.utils.system_utils import (
|
|||||||
get_mp_context,
|
get_mp_context,
|
||||||
set_process_title,
|
set_process_title,
|
||||||
)
|
)
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
from vllm.v1.executor.abstract import Executor, FailureCallback
|
from vllm.v1.executor.abstract import Executor, FailureCallback
|
||||||
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
|
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
|
||||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||||
@ -132,15 +132,12 @@ class MultiprocExecutor(Executor):
|
|||||||
uw.death_writer.close()
|
uw.death_writer.close()
|
||||||
self._ensure_worker_termination([uw.proc for uw in unready_workers])
|
self._ensure_worker_termination([uw.proc for uw in unready_workers])
|
||||||
|
|
||||||
# For pipeline parallel, we use a thread pool for asynchronous
|
# Note: must use only 1 IO thread to keep dequeue sequence
|
||||||
# execute_model.
|
# from the response queue.
|
||||||
if self.max_concurrent_batches > 1:
|
# _async_aggregate_workers_output also assumes a single IO thread.
|
||||||
# Note: must use only 1 IO thread to keep dequeue sequence
|
self.io_thread_pool = ThreadPoolExecutor(
|
||||||
# from the response queue
|
max_workers=1, thread_name_prefix="mp_exec_io"
|
||||||
# _async_aggregate_workers_output also assumes a single IO thread
|
)
|
||||||
self.io_thread_pool = ThreadPoolExecutor(
|
|
||||||
max_workers=1, thread_name_prefix="mp_exec_io"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.output_rank = self._get_output_rank()
|
self.output_rank = self._get_output_rank()
|
||||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||||
@ -180,15 +177,27 @@ class MultiprocExecutor(Executor):
|
|||||||
self.failure_callback = callback
|
self.failure_callback = callback
|
||||||
|
|
||||||
def execute_model( # type: ignore[override]
|
def execute_model( # type: ignore[override]
|
||||||
self,
|
self, scheduler_output: SchedulerOutput, non_block: bool = False
|
||||||
scheduler_output: SchedulerOutput,
|
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||||
non_block: bool = False,
|
return self._execute_with_aggregation(
|
||||||
|
"execute_model", scheduler_output, non_block=non_block
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_tokens( # type: ignore[override]
|
||||||
|
self, grammar_output: GrammarOutput | None, non_block: bool = False
|
||||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||||
|
return self._execute_with_aggregation( # type: ignore[return-value]
|
||||||
|
"sample_tokens", grammar_output, non_block=non_block
|
||||||
|
)
|
||||||
|
|
||||||
|
def _execute_with_aggregation(
|
||||||
|
self, method: str, *args, non_block: bool = False
|
||||||
|
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||||
if not self.has_connector:
|
if not self.has_connector:
|
||||||
# get output only from a single worker (output_rank)
|
# get output only from a single worker (output_rank)
|
||||||
(output,) = self.collective_rpc(
|
(output,) = self.collective_rpc(
|
||||||
"execute_model",
|
method,
|
||||||
args=(scheduler_output,),
|
args=args,
|
||||||
unique_reply_rank=self.output_rank,
|
unique_reply_rank=self.output_rank,
|
||||||
non_block=non_block,
|
non_block=non_block,
|
||||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||||
@ -197,8 +206,8 @@ class MultiprocExecutor(Executor):
|
|||||||
|
|
||||||
# get output from all workers
|
# get output from all workers
|
||||||
outputs = self.collective_rpc(
|
outputs = self.collective_rpc(
|
||||||
"execute_model",
|
method,
|
||||||
args=(scheduler_output,),
|
args=args,
|
||||||
non_block=non_block,
|
non_block=non_block,
|
||||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from vllm.utils.network_utils import (
|
|||||||
get_ip,
|
get_ip,
|
||||||
get_open_port,
|
get_open_port,
|
||||||
)
|
)
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
from vllm.v1.executor.ray_utils import (
|
from vllm.v1.executor.ray_utils import (
|
||||||
@ -41,6 +41,9 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
|
||||||
|
COMPLETED_NONE_FUTURE.set_result(None)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RayWorkerMetaData:
|
class RayWorkerMetaData:
|
||||||
@ -96,6 +99,8 @@ class RayDistributedExecutor(Executor):
|
|||||||
# KV connector setup
|
# KV connector setup
|
||||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||||
|
|
||||||
|
self.scheduler_output: SchedulerOutput | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_concurrent_batches(self) -> int:
|
def max_concurrent_batches(self) -> int:
|
||||||
"""Ray distributed executor supports pipeline parallelism,
|
"""Ray distributed executor supports pipeline parallelism,
|
||||||
@ -381,22 +386,46 @@ class RayDistributedExecutor(Executor):
|
|||||||
self.shutdown()
|
self.shutdown()
|
||||||
|
|
||||||
def execute_model( # type: ignore[override]
|
def execute_model( # type: ignore[override]
|
||||||
self, scheduler_output: SchedulerOutput, non_block: bool = False
|
self,
|
||||||
|
scheduler_output: SchedulerOutput,
|
||||||
|
non_block: bool = False,
|
||||||
|
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||||
|
if self.scheduler_output is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"State error: sample_tokens() must be called "
|
||||||
|
"after execute_model() returns None."
|
||||||
|
)
|
||||||
|
self.scheduler_output = scheduler_output
|
||||||
|
return COMPLETED_NONE_FUTURE if non_block else None
|
||||||
|
|
||||||
|
def sample_tokens( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
grammar_output: "GrammarOutput | None",
|
||||||
|
non_block: bool = False,
|
||||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||||
"""Execute the model on the Ray workers.
|
"""Execute the model on the Ray workers.
|
||||||
|
|
||||||
|
The scheduler output to use should have been provided in
|
||||||
|
a prior call to execute_model().
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scheduler_output: The scheduler output to execute.
|
grammar_output: The structured outputs grammar bitmask, if applicable.
|
||||||
non_block: If True, the method will return a Future.
|
non_block: If True, the method will return a Future.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The model runner output.
|
The model runner output.
|
||||||
"""
|
"""
|
||||||
|
scheduler_output = self.scheduler_output
|
||||||
|
if scheduler_output is None:
|
||||||
|
return None # noqa
|
||||||
|
|
||||||
|
self.scheduler_output = None
|
||||||
|
|
||||||
# Build the compiled DAG for the first time.
|
# Build the compiled DAG for the first time.
|
||||||
if self.forward_dag is None: # type: ignore
|
if self.forward_dag is None: # type: ignore
|
||||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||||
|
|
||||||
refs = self.forward_dag.execute(scheduler_output) # type: ignore
|
refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
|
||||||
|
|
||||||
if not self.has_connector:
|
if not self.has_connector:
|
||||||
# Get output only from a single worker (output_rank)
|
# Get output only from a single worker (output_rank)
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from vllm.v1.outputs import AsyncModelRunnerOutput
|
|||||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -82,36 +82,41 @@ try:
|
|||||||
|
|
||||||
def execute_model_ray(
|
def execute_model_ray(
|
||||||
self,
|
self,
|
||||||
scheduler_output: Union[
|
execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
|
||||||
"SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
|
| tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
|
||||||
],
|
|
||||||
) -> Union[
|
) -> Union[
|
||||||
"ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
|
"ModelRunnerOutput",
|
||||||
|
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
|
||||||
]:
|
]:
|
||||||
# This method is used by Ray Compiled Graph to execute the model,
|
# This method is used by Ray Compiled Graph to execute the model,
|
||||||
# and it needs a special logic of self.setup_device_if_necessary()
|
# and it needs a special logic of self.setup_device_if_necessary()
|
||||||
self.setup_device_if_necessary()
|
self.setup_device_if_necessary()
|
||||||
assert self.worker is not None, "Worker is not initialized"
|
assert self.worker is not None, "Worker is not initialized"
|
||||||
if isinstance(scheduler_output, tuple):
|
if len(execute_model_input) == 3:
|
||||||
scheduler_output, intermediate_tensors = scheduler_output
|
scheduler_output, grammar_output, intermediate_tensors = (
|
||||||
|
execute_model_input
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
scheduler_output, intermediate_tensors = scheduler_output, None
|
scheduler_output, grammar_output = execute_model_input
|
||||||
|
intermediate_tensors = None
|
||||||
assert self.worker.model_runner is not None
|
assert self.worker.model_runner is not None
|
||||||
output = self.worker.model_runner.execute_model(
|
output = self.worker.model_runner.execute_model(
|
||||||
scheduler_output, intermediate_tensors
|
scheduler_output, intermediate_tensors
|
||||||
)
|
)
|
||||||
if isinstance(output, IntermediateTensors):
|
if isinstance(output, IntermediateTensors):
|
||||||
output = scheduler_output, output
|
output = scheduler_output, grammar_output, output
|
||||||
elif not get_pp_group().is_last_rank:
|
elif not get_pp_group().is_last_rank:
|
||||||
# Case where there are no scheduled requests
|
# Case where there are no scheduled requests
|
||||||
# but may still be finished requests.
|
# but may still be finished requests.
|
||||||
assert not output or not output.req_ids
|
assert not output or not output.req_ids
|
||||||
output = scheduler_output, None
|
output = scheduler_output, grammar_output, None
|
||||||
# Ensure outputs crossing Ray compiled DAG are serializable.
|
elif output is None:
|
||||||
# AsyncModelRunnerOutput holds CUDA events and cannot be
|
output = self.worker.model_runner.sample_tokens(grammar_output)
|
||||||
# pickled.
|
# Ensure outputs crossing Ray compiled DAG are serializable.
|
||||||
if isinstance(output, AsyncModelRunnerOutput):
|
# AsyncModelRunnerOutput holds CUDA events and cannot be
|
||||||
output = output.get_output()
|
# pickled.
|
||||||
|
if isinstance(output, AsyncModelRunnerOutput):
|
||||||
|
output = output.get_output()
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def override_env_vars(self, vars: dict[str, str]):
|
def override_env_vars(self, vars: dict[str, str]):
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from diskcache import Cache
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.import_utils import LazyLoader
|
from vllm.utils.import_utils import LazyLoader
|
||||||
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import outlines_core as oc
|
import outlines_core as oc
|
||||||
@ -24,7 +25,6 @@ if TYPE_CHECKING:
|
|||||||
import xgrammar as xgr
|
import xgrammar as xgr
|
||||||
|
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
|
||||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||||
else:
|
else:
|
||||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||||
@ -47,6 +47,7 @@ CACHE = None
|
|||||||
|
|
||||||
def apply_grammar_bitmask(
|
def apply_grammar_bitmask(
|
||||||
scheduler_output: SchedulerOutput,
|
scheduler_output: SchedulerOutput,
|
||||||
|
grammar_output: GrammarOutput,
|
||||||
input_batch: InputBatch,
|
input_batch: InputBatch,
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -58,9 +59,9 @@ def apply_grammar_bitmask(
|
|||||||
input_batch (InputBatch): The input of model runner.
|
input_batch (InputBatch): The input of model runner.
|
||||||
logits (torch.Tensor): The output logits of model forward.
|
logits (torch.Tensor): The output logits of model forward.
|
||||||
"""
|
"""
|
||||||
grammar_bitmask = scheduler_output.grammar_bitmask
|
# Serialization of np.ndarray is much more efficient than a tensor,
|
||||||
if grammar_bitmask is None:
|
# so we receive it in that format.
|
||||||
return
|
grammar_bitmask = grammar_output.grammar_bitmask
|
||||||
|
|
||||||
# We receive the structured output bitmask from the scheduler,
|
# We receive the structured output bitmask from the scheduler,
|
||||||
# compacted to contain bitmasks only for structured output requests.
|
# compacted to contain bitmasks only for structured output requests.
|
||||||
@ -79,7 +80,7 @@ def apply_grammar_bitmask(
|
|||||||
cumulative_offset += len(
|
cumulative_offset += len(
|
||||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
|
||||||
)
|
)
|
||||||
if req_id in scheduler_output.structured_output_request_ids:
|
if req_id in grammar_output.structured_output_request_ids:
|
||||||
struct_out_req_batch_indices[req_id] = logit_index
|
struct_out_req_batch_indices[req_id] = logit_index
|
||||||
|
|
||||||
out_indices = []
|
out_indices = []
|
||||||
@ -91,7 +92,7 @@ def apply_grammar_bitmask(
|
|||||||
dtype=grammar_bitmask.dtype,
|
dtype=grammar_bitmask.dtype,
|
||||||
)
|
)
|
||||||
cumulative_index = 0
|
cumulative_index = 0
|
||||||
for req_id in scheduler_output.structured_output_request_ids:
|
for req_id in grammar_output.structured_output_request_ids:
|
||||||
num_spec_tokens = len(
|
num_spec_tokens = len(
|
||||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
|
||||||
)
|
)
|
||||||
@ -101,22 +102,28 @@ def apply_grammar_bitmask(
|
|||||||
sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
|
sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
|
||||||
out_indices.append(logit_index + i)
|
out_indices.append(logit_index + i)
|
||||||
cumulative_index += 1 + num_spec_tokens
|
cumulative_index += 1 + num_spec_tokens
|
||||||
grammar_bitmask = sorted_bitmask
|
|
||||||
|
# Copy async to device as tensor.
|
||||||
|
grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
|
||||||
|
logits.device, non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
# If the length of out indices and the logits have the same shape
|
# If the length of out indices and the logits have the same shape
|
||||||
# we don't need to pass indices to the kernel,
|
# we don't need to pass indices to the kernel,
|
||||||
# since the bitmask is already aligned with the logits.
|
# since the bitmask is already aligned with the logits.
|
||||||
skip_out_indices = len(out_indices) == logits.shape[0]
|
skip_out_indices = len(out_indices) == logits.shape[0]
|
||||||
|
|
||||||
# Serialization of np.ndarray is much more efficient than a tensor,
|
index_tensor = None
|
||||||
# so we receive it in that format.
|
if not skip_out_indices:
|
||||||
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
|
# xgrammar expects a python list of indices but it will actually work with
|
||||||
|
# a tensor. If we copy the tensor ourselves here we can do it in a non_blocking
|
||||||
|
# manner and there should be no cpu sync within xgrammar.
|
||||||
|
index_tensor = torch.tensor(
|
||||||
|
out_indices, dtype=torch.int32, device="cpu", pin_memory=True
|
||||||
|
)
|
||||||
|
index_tensor = index_tensor.to(logits.device, non_blocking=True)
|
||||||
|
|
||||||
xgr.apply_token_bitmask_inplace(
|
xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
|
||||||
logits,
|
|
||||||
grammar_bitmask.to(logits.device, non_blocking=True),
|
|
||||||
indices=out_indices if not skip_out_indices else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OutlinesVocabulary:
|
class OutlinesVocabulary:
|
||||||
|
|||||||
@ -204,7 +204,7 @@ class InputBatch:
|
|||||||
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
|
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
|
||||||
|
|
||||||
# lora related
|
# lora related
|
||||||
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32)
|
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
|
||||||
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
||||||
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
||||||
|
|
||||||
|
|||||||
@ -109,6 +109,7 @@ from vllm.v1.outputs import (
|
|||||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||||
AsyncModelRunnerOutput,
|
AsyncModelRunnerOutput,
|
||||||
DraftTokenIds,
|
DraftTokenIds,
|
||||||
|
KVConnectorOutput,
|
||||||
LogprobsLists,
|
LogprobsLists,
|
||||||
LogprobsTensors,
|
LogprobsTensors,
|
||||||
ModelRunnerOutput,
|
ModelRunnerOutput,
|
||||||
@ -150,7 +151,7 @@ from .utils import (
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -218,6 +219,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class ExecuteModelState(NamedTuple):
|
||||||
|
"""Ephemeral cached state transferred between execute_model() and
|
||||||
|
sample_tokens(), after execute_model() returns None."""
|
||||||
|
|
||||||
|
scheduler_output: "SchedulerOutput"
|
||||||
|
logits: torch.Tensor
|
||||||
|
spec_decode_metadata: SpecDecodeMetadata | None
|
||||||
|
spec_decode_common_attn_metadata: CommonAttentionMetadata | None
|
||||||
|
hidden_states: torch.Tensor
|
||||||
|
sample_hidden_states: torch.Tensor
|
||||||
|
aux_hidden_states: list[torch.Tensor] | None
|
||||||
|
kv_connector_output: KVConnectorOutput | None
|
||||||
|
|
||||||
|
|
||||||
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -509,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Ephemeral state transferred between execute_model() and sample_tokens().
|
||||||
|
self.execute_model_state: ExecuteModelState | None = None
|
||||||
|
|
||||||
def reset_mm_cache(self) -> None:
|
def reset_mm_cache(self) -> None:
|
||||||
if self.mm_budget:
|
if self.mm_budget:
|
||||||
self.mm_budget.reset_cache()
|
self.mm_budget.reset_cache()
|
||||||
@ -2113,7 +2131,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
num_input_tokens: int, # Padded
|
num_input_tokens: int, # Padded
|
||||||
intermediate_tensors: IntermediateTensors | None = None,
|
intermediate_tensors: IntermediateTensors | None = None,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
int,
|
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
@ -2207,7 +2224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
model_kwargs.update(encoder_inputs)
|
model_kwargs.update(encoder_inputs)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
num_scheduled_tokens,
|
|
||||||
input_ids,
|
input_ids,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
positions,
|
positions,
|
||||||
@ -2425,13 +2441,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
intermediate_tensors: IntermediateTensors | None = None,
|
intermediate_tensors: IntermediateTensors | None = None,
|
||||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
|
) -> ModelRunnerOutput | IntermediateTensors | None:
|
||||||
|
if self.execute_model_state is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"State error: sample_tokens() must be called "
|
||||||
|
"after execute_model() returns None."
|
||||||
|
)
|
||||||
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
with record_function_or_nullcontext("Preprocess"):
|
with record_function_or_nullcontext("Preprocess"):
|
||||||
with self.synchronize_input_prep():
|
with self.synchronize_input_prep():
|
||||||
# Update persistent batch states.
|
# Update persistent batch states.
|
||||||
self._update_states(scheduler_output)
|
self._update_states(scheduler_output)
|
||||||
|
|
||||||
if not scheduler_output.total_num_scheduled_tokens:
|
if not num_scheduled_tokens:
|
||||||
if not has_kv_transfer_group():
|
if not has_kv_transfer_group():
|
||||||
# Return empty ModelRunnerOutput if no work to do.
|
# Return empty ModelRunnerOutput if no work to do.
|
||||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||||
@ -2471,7 +2493,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
num_scheduled_tokens,
|
|
||||||
input_ids,
|
input_ids,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
positions,
|
positions,
|
||||||
@ -2559,6 +2580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Rare case.
|
# Rare case.
|
||||||
assert not self.is_pooling_model
|
assert not self.is_pooling_model
|
||||||
|
|
||||||
|
sample_hidden_states = hidden_states[logits_indices]
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
all_gather_tensors = {
|
all_gather_tensors = {
|
||||||
"residual": not is_residual_scattered_for_sp(
|
"residual": not is_residual_scattered_for_sp(
|
||||||
@ -2572,7 +2594,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
logits = None
|
logits = None
|
||||||
else:
|
else:
|
||||||
sample_hidden_states = hidden_states[logits_indices]
|
|
||||||
logits = self.model.compute_logits(sample_hidden_states)
|
logits = self.model.compute_logits(sample_hidden_states)
|
||||||
|
|
||||||
model_output_broadcast_data = {}
|
model_output_broadcast_data = {}
|
||||||
@ -2585,9 +2606,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
assert model_output_broadcast_data is not None
|
assert model_output_broadcast_data is not None
|
||||||
logits = model_output_broadcast_data["logits"]
|
logits = model_output_broadcast_data["logits"]
|
||||||
|
|
||||||
# Apply structured output bitmasks if present
|
self.execute_model_state = ExecuteModelState(
|
||||||
if scheduler_output.structured_output_request_ids:
|
scheduler_output,
|
||||||
apply_grammar_bitmask(scheduler_output, self.input_batch, logits)
|
logits,
|
||||||
|
spec_decode_metadata,
|
||||||
|
spec_decode_common_attn_metadata,
|
||||||
|
hidden_states,
|
||||||
|
sample_hidden_states,
|
||||||
|
aux_hidden_states,
|
||||||
|
kv_connector_output,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@torch.inference_mode
|
||||||
|
def sample_tokens(
|
||||||
|
self, grammar_output: "GrammarOutput | None"
|
||||||
|
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
|
||||||
|
if self.execute_model_state is None:
|
||||||
|
# Nothing to do (PP non-final rank case), output isn't used.
|
||||||
|
return None # noqa
|
||||||
|
|
||||||
|
# Unpack ephemeral state.
|
||||||
|
(
|
||||||
|
scheduler_output,
|
||||||
|
logits,
|
||||||
|
spec_decode_metadata,
|
||||||
|
spec_decode_common_attn_metadata,
|
||||||
|
hidden_states,
|
||||||
|
sample_hidden_states,
|
||||||
|
aux_hidden_states,
|
||||||
|
kv_connector_output,
|
||||||
|
) = self.execute_model_state
|
||||||
|
# Clear ephemeral state.
|
||||||
|
self.execute_model_state = None
|
||||||
|
|
||||||
|
# Apply structured output bitmasks if present.
|
||||||
|
if grammar_output is not None:
|
||||||
|
apply_grammar_bitmask(
|
||||||
|
scheduler_output, grammar_output, self.input_batch, logits
|
||||||
|
)
|
||||||
|
|
||||||
with record_function_or_nullcontext("Sample"):
|
with record_function_or_nullcontext("Sample"):
|
||||||
sampler_output = self._sample(logits, spec_decode_metadata)
|
sampler_output = self._sample(logits, spec_decode_metadata)
|
||||||
@ -2646,7 +2703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
sampler_output,
|
sampler_output,
|
||||||
logits,
|
logits,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
num_scheduled_tokens,
|
scheduler_output.total_num_scheduled_tokens,
|
||||||
spec_decode_metadata,
|
spec_decode_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -3978,6 +4035,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
def create_attn_groups(
|
def create_attn_groups(
|
||||||
attn_backends_map: dict[AttentionGroupKey, list[str]],
|
attn_backends_map: dict[AttentionGroupKey, list[str]],
|
||||||
|
kv_cache_group_id: int,
|
||||||
) -> list[AttentionGroup]:
|
) -> list[AttentionGroup]:
|
||||||
attn_groups: list[AttentionGroup] = []
|
attn_groups: list[AttentionGroup] = []
|
||||||
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
|
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
|
||||||
@ -3987,6 +4045,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
kv_cache_spec,
|
kv_cache_spec,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
self.device,
|
self.device,
|
||||||
|
kv_cache_group_id,
|
||||||
num_metadata_builders=1
|
num_metadata_builders=1
|
||||||
if not self.parallel_config.enable_dbo
|
if not self.parallel_config.enable_dbo
|
||||||
else 2,
|
else 2,
|
||||||
@ -4005,8 +4064,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Resolve cudagraph_mode before actually initialize metadata_builders
|
# Resolve cudagraph_mode before actually initialize metadata_builders
|
||||||
self._check_and_update_cudagraph_mode(attention_backend_set)
|
self._check_and_update_cudagraph_mode(attention_backend_set)
|
||||||
|
|
||||||
for attn_backends_map in attention_backend_maps:
|
for i, attn_backend_map in enumerate(attention_backend_maps):
|
||||||
self.attn_groups.append(create_attn_groups(attn_backends_map))
|
self.attn_groups.append(create_attn_groups(attn_backend_map, i))
|
||||||
|
|
||||||
# Calculate reorder batch threshold (if needed)
|
# Calculate reorder batch threshold (if needed)
|
||||||
self.calculate_reorder_batch_threshold()
|
self.calculate_reorder_batch_threshold()
|
||||||
@ -4149,89 +4208,88 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
group.get_metadata_builder().reorder_batch_threshold
|
group.get_metadata_builder().reorder_batch_threshold
|
||||||
for group in self._attn_group_iterator()
|
for group in self._attn_group_iterator()
|
||||||
]
|
]
|
||||||
|
# If there are no attention groups (attention-free model) or no backend
|
||||||
|
# reports a threshold, leave reordering disabled.
|
||||||
|
if len(reorder_batch_thresholds) == 0:
|
||||||
|
self.reorder_batch_threshold = None
|
||||||
|
return
|
||||||
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
|
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
|
||||||
|
|
||||||
def _find_compatible_block_sizes(
|
@staticmethod
|
||||||
self,
|
def select_common_block_size(
|
||||||
kv_manager_block_size: int,
|
kv_manager_block_size: int, attn_groups: list[AttentionGroup]
|
||||||
backend_cls: type[AttentionBackend],
|
|
||||||
return_all: bool = False,
|
|
||||||
) -> list[int]:
|
|
||||||
"""
|
|
||||||
Find compatible block sizes for a backend.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
kv_manager_block_size: Physical block size of KV cache
|
|
||||||
backend_cls: Attention backend class
|
|
||||||
return_all: Return all compatible sizes if True, max size if False
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Compatible block size(s) based on return_all parameter
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If no compatible block size found
|
|
||||||
"""
|
|
||||||
supported_block_size = backend_cls.get_supported_kernel_block_size()
|
|
||||||
compatible_sizes = []
|
|
||||||
|
|
||||||
for block_size in supported_block_size:
|
|
||||||
if isinstance(block_size, int):
|
|
||||||
if kv_manager_block_size % block_size == 0:
|
|
||||||
compatible_sizes.append(block_size)
|
|
||||||
elif (
|
|
||||||
isinstance(block_size, MultipleOf)
|
|
||||||
and kv_manager_block_size % block_size.base == 0
|
|
||||||
):
|
|
||||||
compatible_sizes.append(kv_manager_block_size)
|
|
||||||
|
|
||||||
if not compatible_sizes:
|
|
||||||
raise ValueError(f"No compatible block size for {kv_manager_block_size}")
|
|
||||||
|
|
||||||
return compatible_sizes if return_all else [max(compatible_sizes)]
|
|
||||||
|
|
||||||
def _select_common_block_size(
|
|
||||||
self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
|
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Select common block size for all backends.
|
Select a block size that is supported by all backends and is a factor of
|
||||||
|
kv_manager_block_size.
|
||||||
|
|
||||||
|
If kv_manager_block_size is supported by all backends, return it directly.
|
||||||
|
Otherwise, return the max supported size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
kv_manager_block_size: Block size of KV cache
|
kv_manager_block_size: Block size of KV cache
|
||||||
attn_groups: List of attention groups
|
attn_groups: List of attention groups
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Block size supported by all backends,
|
The selected block size
|
||||||
prioritizing cache_config.block_size
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If no common block size found
|
ValueError: If no valid block size found
|
||||||
"""
|
"""
|
||||||
all_backend_supports = []
|
|
||||||
|
|
||||||
for attn_group in attn_groups:
|
def block_size_is_supported(
|
||||||
compatible_sizes = self._find_compatible_block_sizes(
|
backends: list[type[AttentionBackend]], block_size: int
|
||||||
kv_manager_block_size, attn_group.backend, return_all=True
|
) -> bool:
|
||||||
)
|
"""
|
||||||
supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
|
Check if the block size is supported by all backends.
|
||||||
all_backend_supports.append(set(supported_sizes))
|
"""
|
||||||
|
for backend in backends:
|
||||||
|
is_supported = False
|
||||||
|
for supported_size in backend.get_supported_kernel_block_size():
|
||||||
|
if isinstance(supported_size, int):
|
||||||
|
if block_size == supported_size:
|
||||||
|
is_supported = True
|
||||||
|
elif isinstance(supported_size, MultipleOf):
|
||||||
|
if block_size % supported_size.base == 0:
|
||||||
|
is_supported = True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown supported size: {supported_size}")
|
||||||
|
if not is_supported:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
common_supported_sizes = set.intersection(*all_backend_supports)
|
backends = [group.backend for group in attn_groups]
|
||||||
|
|
||||||
if not common_supported_sizes:
|
# Case 1: if the block_size of kv cache manager is supported by all backends,
|
||||||
error_msg = f"No common block size for {kv_manager_block_size}. "
|
# return it directly
|
||||||
for i, attn_group in enumerate(attn_groups):
|
if block_size_is_supported(backends, kv_manager_block_size):
|
||||||
supported = all_backend_supports[i]
|
return kv_manager_block_size
|
||||||
error_msg += (
|
|
||||||
f"Backend {attn_group.backend} supports: {sorted(supported)}. "
|
|
||||||
)
|
|
||||||
raise ValueError(error_msg)
|
|
||||||
|
|
||||||
if self.cache_config.block_size in common_supported_sizes:
|
# Case 2: otherwise, the block_size must be an `int`-format supported size of
|
||||||
return self.cache_config.block_size
|
# at least one backend. Iterate over all `int`-format supported sizes in
|
||||||
|
# descending order and return the first one that is supported by all backends.
|
||||||
|
# Simple proof:
|
||||||
|
# If the supported size b is in MultipleOf(x_i) format for all attention
|
||||||
|
# backends i, and b a factor of kv_manager_block_size, then
|
||||||
|
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
|
||||||
|
# return kv_manager_block_size in case 1.
|
||||||
|
all_int_supported_sizes = set(
|
||||||
|
supported_size
|
||||||
|
for backend in backends
|
||||||
|
for supported_size in backend.get_supported_kernel_block_size()
|
||||||
|
if isinstance(supported_size, int)
|
||||||
|
)
|
||||||
|
|
||||||
return max(common_supported_sizes)
|
for supported_size in sorted(all_int_supported_sizes, reverse=True):
|
||||||
|
if kv_manager_block_size % supported_size != 0:
|
||||||
|
continue
|
||||||
|
if block_size_is_supported(backends, supported_size):
|
||||||
|
return supported_size
|
||||||
|
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
|
||||||
|
|
||||||
def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
|
def may_reinitialize_input_batch(
|
||||||
|
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Re-initialize the input batch if the block sizes are different from
|
Re-initialize the input batch if the block sizes are different from
|
||||||
`[self.cache_config.block_size]`. This usually happens when there
|
`[self.cache_config.block_size]`. This usually happens when there
|
||||||
@ -4239,6 +4297,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
kv_cache_config: The KV cache configuration.
|
kv_cache_config: The KV cache configuration.
|
||||||
|
kernel_block_sizes: The kernel block sizes for each KV cache group.
|
||||||
"""
|
"""
|
||||||
block_sizes = [
|
block_sizes = [
|
||||||
kv_cache_group.kv_cache_spec.block_size
|
kv_cache_group.kv_cache_spec.block_size
|
||||||
@ -4246,9 +4305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
|
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Generate kernel_block_sizes that matches each block_size
|
|
||||||
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
|
|
||||||
|
|
||||||
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
|
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
|
||||||
self.cache_config.block_size
|
self.cache_config.block_size
|
||||||
]:
|
]:
|
||||||
@ -4349,7 +4405,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# all backends in the group.
|
# all backends in the group.
|
||||||
attn_groups = self.attn_groups[kv_cache_group_id]
|
attn_groups = self.attn_groups[kv_cache_group_id]
|
||||||
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
|
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
|
||||||
selected_kernel_size = self._select_common_block_size(
|
selected_kernel_size = self.select_common_block_size(
|
||||||
kv_manager_block_size, attn_groups
|
kv_manager_block_size, attn_groups
|
||||||
)
|
)
|
||||||
kernel_block_sizes.append(selected_kernel_size)
|
kernel_block_sizes.append(selected_kernel_size)
|
||||||
@ -4367,6 +4423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self,
|
self,
|
||||||
kv_cache_config: KVCacheConfig,
|
kv_cache_config: KVCacheConfig,
|
||||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||||
|
kernel_block_sizes: list[int],
|
||||||
) -> dict[str, torch.Tensor]:
|
) -> dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Reshape the KV cache tensors to the desired shape and dtype.
|
Reshape the KV cache tensors to the desired shape and dtype.
|
||||||
@ -4375,6 +4432,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
kv_cache_config: The KV cache config
|
kv_cache_config: The KV cache config
|
||||||
kv_cache_raw_tensors: The KV cache buffer of each layer, with
|
kv_cache_raw_tensors: The KV cache buffer of each layer, with
|
||||||
correct size but uninitialized shape.
|
correct size but uninitialized shape.
|
||||||
|
kernel_block_sizes: The kernel block sizes for each KV cache group.
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, torch.Tensor]: A map between layer names to their
|
Dict[str, torch.Tensor]: A map between layer names to their
|
||||||
corresponding memory buffer for KV cache.
|
corresponding memory buffer for KV cache.
|
||||||
@ -4384,6 +4442,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
for group in self._kv_cache_spec_attn_group_iterator():
|
for group in self._kv_cache_spec_attn_group_iterator():
|
||||||
kv_cache_spec = group.kv_cache_spec
|
kv_cache_spec = group.kv_cache_spec
|
||||||
attn_backend = group.backend
|
attn_backend = group.backend
|
||||||
|
if group.kv_cache_group_id == len(kernel_block_sizes):
|
||||||
|
# There may be a last group for layers without kv cache.
|
||||||
|
continue
|
||||||
|
kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
|
||||||
for layer_name in group.layer_names:
|
for layer_name in group.layer_names:
|
||||||
if layer_name in self.runner_only_attn_layers:
|
if layer_name in self.runner_only_attn_layers:
|
||||||
continue
|
continue
|
||||||
@ -4392,24 +4454,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||||
if isinstance(kv_cache_spec, AttentionSpec):
|
if isinstance(kv_cache_spec, AttentionSpec):
|
||||||
has_attn = True
|
has_attn = True
|
||||||
kv_manager_block_size = kv_cache_spec.block_size
|
num_blocks_per_kv_block = (
|
||||||
kernel_size_list = self._find_compatible_block_sizes(
|
kv_cache_spec.block_size // kernel_block_size
|
||||||
kv_manager_block_size, attn_backend, return_all=False
|
|
||||||
)
|
)
|
||||||
kernel_size = kernel_size_list[0]
|
|
||||||
num_blocks_per_kv_block = kv_manager_block_size // kernel_size
|
|
||||||
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
||||||
|
|
||||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||||
kernel_num_blocks,
|
kernel_num_blocks,
|
||||||
kernel_size,
|
kernel_block_size,
|
||||||
kv_cache_spec.num_kv_heads,
|
kv_cache_spec.num_kv_heads,
|
||||||
kv_cache_spec.head_size,
|
kv_cache_spec.head_size,
|
||||||
cache_dtype_str=self.cache_config.cache_dtype,
|
cache_dtype_str=self.cache_config.cache_dtype,
|
||||||
)
|
)
|
||||||
dtype = kv_cache_spec.dtype
|
dtype = kv_cache_spec.dtype
|
||||||
try:
|
try:
|
||||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501
|
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
||||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||||
except (AttributeError, NotImplementedError):
|
except (AttributeError, NotImplementedError):
|
||||||
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
||||||
@ -4492,13 +4551,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def initialize_kv_cache_tensors(
|
def initialize_kv_cache_tensors(
|
||||||
self, kv_cache_config: KVCacheConfig
|
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
|
||||||
) -> dict[str, torch.Tensor]:
|
) -> dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Initialize the memory buffer for KV cache.
|
Initialize the memory buffer for KV cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
kv_cache_config: The KV cache config
|
kv_cache_config: The KV cache config
|
||||||
|
kernel_block_sizes: The kernel block sizes for each KV cache group.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, torch.Tensor]: A map between layer names to their
|
Dict[str, torch.Tensor]: A map between layer names to their
|
||||||
corresponding memory buffer for KV cache.
|
corresponding memory buffer for KV cache.
|
||||||
@ -4507,7 +4568,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
|
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
|
||||||
# Change the memory buffer to the desired shape
|
# Change the memory buffer to the desired shape
|
||||||
kv_caches = self._reshape_kv_cache_tensors(
|
kv_caches = self._reshape_kv_cache_tensors(
|
||||||
kv_cache_config, kv_cache_raw_tensors
|
kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set up cross-layer KV cache sharing
|
# Set up cross-layer KV cache sharing
|
||||||
@ -4566,9 +4627,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.may_add_encoder_only_layers_to_kv_cache_config()
|
self.may_add_encoder_only_layers_to_kv_cache_config()
|
||||||
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
|
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
|
||||||
self.initialize_attn_backend(kv_cache_config)
|
self.initialize_attn_backend(kv_cache_config)
|
||||||
|
# The kernel block size for all KV cache groups. For example, if
|
||||||
|
# kv_cache_manager uses block_size 256 for a given group, but the attention
|
||||||
|
# backends for that group only supports block_size 64, we will return
|
||||||
|
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
|
||||||
|
# tokens each.
|
||||||
|
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
|
||||||
# Reinitialize need to after initialize_attn_backend
|
# Reinitialize need to after initialize_attn_backend
|
||||||
self.may_reinitialize_input_batch(kv_cache_config)
|
self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
|
||||||
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
kv_caches = self.initialize_kv_cache_tensors(
|
||||||
|
kv_cache_config, kernel_block_sizes
|
||||||
|
)
|
||||||
|
|
||||||
if self.speculative_config and self.speculative_config.use_eagle():
|
if self.speculative_config and self.speculative_config.use_eagle():
|
||||||
assert isinstance(self.drafter, EagleProposer)
|
assert isinstance(self.drafter, EagleProposer)
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import copy
|
|||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
from contextlib import AbstractContextManager, nullcontext
|
from contextlib import AbstractContextManager, nullcontext
|
||||||
|
from types import NoneType
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -19,7 +20,11 @@ from vllm.distributed import (
|
|||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
set_custom_all_reduce,
|
set_custom_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
from vllm.distributed.kv_transfer import (
|
||||||
|
ensure_kv_transfer_initialized,
|
||||||
|
get_kv_transfer_group,
|
||||||
|
has_kv_transfer_group,
|
||||||
|
)
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
@ -33,6 +38,7 @@ from vllm.sequence import IntermediateTensors
|
|||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.utils.mem_constants import GiB_bytes
|
from vllm.utils.mem_constants import GiB_bytes
|
||||||
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
|
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
|
||||||
|
from vllm.v1.core.sched.output import GrammarOutput
|
||||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||||
from vllm.v1.outputs import (
|
from vllm.v1.outputs import (
|
||||||
@ -348,6 +354,21 @@ class Worker(WorkerBase):
|
|||||||
|
|
||||||
return int(self.available_kv_cache_memory_bytes)
|
return int(self.available_kv_cache_memory_bytes)
|
||||||
|
|
||||||
|
def get_kv_connector_handshake_metadata(self) -> dict | None:
|
||||||
|
"""Get KV connector metadata from this worker if available."""
|
||||||
|
|
||||||
|
if not has_kv_transfer_group():
|
||||||
|
return None
|
||||||
|
|
||||||
|
connector = get_kv_transfer_group()
|
||||||
|
# Return None for connectors that don't need to exchange handshake
|
||||||
|
# metadata across workers.
|
||||||
|
if (metadata := connector.get_handshake_metadata()) is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tp_rank = get_tp_group().rank_in_group
|
||||||
|
return {tp_rank: metadata}
|
||||||
|
|
||||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||||
return self.model_runner.get_kv_cache_spec()
|
return self.model_runner.get_kv_cache_spec()
|
||||||
|
|
||||||
@ -489,11 +510,16 @@ class Worker(WorkerBase):
|
|||||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||||
return self.model_runner.get_supported_tasks()
|
return self.model_runner.get_supported_tasks()
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def sample_tokens(
|
||||||
|
self, grammar_output: "GrammarOutput"
|
||||||
|
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
|
||||||
|
return self.model_runner.sample_tokens(grammar_output)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self, scheduler_output: "SchedulerOutput"
|
||||||
scheduler_output: "SchedulerOutput",
|
) -> ModelRunnerOutput | None:
|
||||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
|
|
||||||
intermediate_tensors = None
|
intermediate_tensors = None
|
||||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
@ -512,13 +538,13 @@ class Worker(WorkerBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
|
output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
|
||||||
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
|
if isinstance(output, (ModelRunnerOutput, NoneType)):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
assert isinstance(output, IntermediateTensors)
|
assert isinstance(output, IntermediateTensors)
|
||||||
parallel_config = self.vllm_config.parallel_config
|
parallel_config = self.vllm_config.parallel_config
|
||||||
assert (
|
assert (
|
||||||
parallel_config.distributed_executor_backend != ("external_launcher")
|
parallel_config.distributed_executor_backend != "external_launcher"
|
||||||
and not get_pp_group().is_last_rank
|
and not get_pp_group().is_last_rank
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -139,7 +139,7 @@ class InputBatch:
|
|||||||
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
||||||
|
|
||||||
# lora related
|
# lora related
|
||||||
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32)
|
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
|
||||||
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
||||||
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
||||||
|
|
||||||
|
|||||||
@ -92,7 +92,7 @@ from .utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -372,6 +372,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
self.sample_from_logits_func = self.sample_from_logits
|
self.sample_from_logits_func = self.sample_from_logits
|
||||||
|
|
||||||
|
# For passing scheduler_output between successive
|
||||||
|
# execute_model() and sample_tokens() calls.
|
||||||
|
self.scheduler_output: SchedulerOutput | None = None
|
||||||
|
self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
|
||||||
|
|
||||||
def reset_mm_cache(self) -> None:
|
def reset_mm_cache(self) -> None:
|
||||||
if self.mm_budget:
|
if self.mm_budget:
|
||||||
self.mm_budget.reset_cache()
|
self.mm_budget.reset_cache()
|
||||||
@ -1078,7 +1083,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
intermediate_tensors: IntermediateTensors | None = None,
|
intermediate_tensors: IntermediateTensors | None = None,
|
||||||
) -> ModelRunnerOutput:
|
) -> ModelRunnerOutput | None:
|
||||||
|
if self.scheduler_output is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"State error: sample_tokens() must be called "
|
||||||
|
"after execute_model() returns None."
|
||||||
|
)
|
||||||
# Update cached state
|
# Update cached state
|
||||||
self._update_states(scheduler_output)
|
self._update_states(scheduler_output)
|
||||||
if not scheduler_output.total_num_scheduled_tokens:
|
if not scheduler_output.total_num_scheduled_tokens:
|
||||||
@ -1088,14 +1098,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
|
return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
|
||||||
|
|
||||||
|
mm_embed_inputs = None
|
||||||
if self.supports_mm_inputs:
|
if self.supports_mm_inputs:
|
||||||
# Run the multimodal encoder if any.
|
# Run the multimodal encoder if any.
|
||||||
self._execute_mm_encoder(scheduler_output)
|
self._execute_mm_encoder(scheduler_output)
|
||||||
mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
|
mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
|
||||||
else:
|
|
||||||
mm_embed_inputs = None
|
|
||||||
|
|
||||||
torch_xla.sync(wait=False)
|
torch_xla.sync(wait=False)
|
||||||
|
|
||||||
|
self.scheduler_output = scheduler_output
|
||||||
|
self.mm_embed_inputs = mm_embed_inputs
|
||||||
|
return None
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_tokens(
|
||||||
|
self, grammar_output: "GrammarOutput | None"
|
||||||
|
) -> ModelRunnerOutput:
|
||||||
|
if self.scheduler_output is None:
|
||||||
|
# Nothing to do (PP non-final rank case), output isn't used.
|
||||||
|
return None # noqa
|
||||||
|
scheduler_output = self.scheduler_output
|
||||||
|
mm_embed_inputs = self.mm_embed_inputs
|
||||||
|
self.scheduler_output = None
|
||||||
|
self.mm_embed_inputs = None
|
||||||
|
|
||||||
# Prepare inputs, the requests might be split into multiple
|
# Prepare inputs, the requests might be split into multiple
|
||||||
# executions, combine the result of each execution.
|
# executions, combine the result of each execution.
|
||||||
start_index = 0
|
start_index = 0
|
||||||
@ -1131,9 +1157,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
|
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
|
||||||
self.input_batch, padded_num_reqs, self.device
|
self.input_batch, padded_num_reqs, self.device
|
||||||
)
|
)
|
||||||
if scheduler_output.grammar_bitmask is not None:
|
if grammar_output is not None:
|
||||||
require_struct_decoding, grammar_bitmask_padded, arange = (
|
require_struct_decoding, grammar_bitmask_padded, arange = (
|
||||||
self.prepare_structured_decoding_input(logits, scheduler_output)
|
self.prepare_structured_decoding_input(logits, grammar_output)
|
||||||
)
|
)
|
||||||
logits = self.structured_decode(
|
logits = self.structured_decode(
|
||||||
require_struct_decoding, grammar_bitmask_padded, logits, arange
|
require_struct_decoding, grammar_bitmask_padded, logits, arange
|
||||||
@ -1954,10 +1980,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
return self.model.get_input_embeddings(*args, **kwargs)
|
return self.model.get_input_embeddings(*args, **kwargs)
|
||||||
|
|
||||||
def prepare_structured_decoding_input(
|
def prepare_structured_decoding_input(
|
||||||
self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
|
self, logits: torch.Tensor, grammar_output: "GrammarOutput"
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
grammar_bitmask = scheduler_output.grammar_bitmask
|
grammar_bitmask = grammar_output.grammar_bitmask
|
||||||
assert grammar_bitmask is not None
|
|
||||||
num_reqs, _ = logits.shape
|
num_reqs, _ = logits.shape
|
||||||
|
|
||||||
# Reset pre-allocated tensors
|
# Reset pre-allocated tensors
|
||||||
@ -1965,7 +1990,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.require_structured_out_cpu.zero_()
|
self.require_structured_out_cpu.zero_()
|
||||||
|
|
||||||
cumulative_mask_idx = 0
|
cumulative_mask_idx = 0
|
||||||
for req_id in scheduler_output.structured_output_request_ids:
|
for req_id in grammar_output.structured_output_request_ids:
|
||||||
if req_id not in self.input_batch.req_id_to_index:
|
if req_id not in self.input_batch.req_id_to_index:
|
||||||
continue
|
continue
|
||||||
batch_index = self.input_batch.req_id_to_index[req_id]
|
batch_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
|||||||
@ -17,7 +17,6 @@ from vllm.distributed import (
|
|||||||
)
|
)
|
||||||
from vllm.distributed.kv_transfer import (
|
from vllm.distributed.kv_transfer import (
|
||||||
ensure_kv_transfer_initialized,
|
ensure_kv_transfer_initialized,
|
||||||
has_kv_transfer_group,
|
|
||||||
)
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -27,7 +26,7 @@ from vllm.platforms.tpu import USE_TPU_INFERENCE
|
|||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.v1.utils import report_usage_stats
|
from vllm.v1.utils import report_usage_stats
|
||||||
@ -255,13 +254,13 @@ class TPUWorker:
|
|||||||
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
|
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
|
||||||
return int(tpu_kv_cache_bytes)
|
return int(tpu_kv_cache_bytes)
|
||||||
|
|
||||||
|
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
|
||||||
|
return self.model_runner.sample_tokens(grammar_output)
|
||||||
|
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self, scheduler_output: "SchedulerOutput"
|
||||||
scheduler_output: "SchedulerOutput",
|
|
||||||
) -> ModelRunnerOutput | None:
|
) -> ModelRunnerOutput | None:
|
||||||
output = self.model_runner.execute_model(scheduler_output)
|
return self.model_runner.execute_model(scheduler_output)
|
||||||
# every worker's output is needed when kv_transfer_group is set up
|
|
||||||
return output if self.is_driver_worker or has_kv_transfer_group() else None
|
|
||||||
|
|
||||||
def profile(self, is_start: bool = True):
|
def profile(self, is_start: bool = True):
|
||||||
if self.rank < 1:
|
if self.rank < 1:
|
||||||
|
|||||||
@ -140,6 +140,7 @@ class AttentionGroup:
|
|||||||
metadata_builders: list[AttentionMetadataBuilder]
|
metadata_builders: list[AttentionMetadataBuilder]
|
||||||
layer_names: list[str]
|
layer_names: list[str]
|
||||||
kv_cache_spec: KVCacheSpec
|
kv_cache_spec: KVCacheSpec
|
||||||
|
kv_cache_group_id: int
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_with_metadata_builders(
|
def create_with_metadata_builders(
|
||||||
@ -148,13 +149,16 @@ class AttentionGroup:
|
|||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
kv_cache_group_id: int,
|
||||||
num_metadata_builders: int = 1,
|
num_metadata_builders: int = 1,
|
||||||
) -> "AttentionGroup":
|
) -> "AttentionGroup":
|
||||||
metadata_builders = [
|
metadata_builders = [
|
||||||
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
|
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
for _ in range(num_metadata_builders)
|
for _ in range(num_metadata_builders)
|
||||||
]
|
]
|
||||||
return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec)
|
return AttentionGroup(
|
||||||
|
backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
|
||||||
|
)
|
||||||
|
|
||||||
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
|
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
|
||||||
assert len(self.metadata_builders) > ubatch_id
|
assert len(self.metadata_builders) > ubatch_id
|
||||||
|
|||||||
@ -20,10 +20,12 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
|
|||||||
from vllm.v1.serial_utils import run_method
|
from vllm.v1.serial_utils import run_method
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput
|
||||||
else:
|
else:
|
||||||
SchedulerOutput = object
|
SchedulerOutput = object
|
||||||
|
GrammarOutput = object
|
||||||
|
AsyncModelRunnerOutput = object
|
||||||
ModelRunnerOutput = object
|
ModelRunnerOutput = object
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -122,7 +124,21 @@ class WorkerBase:
|
|||||||
"""Load model onto target device."""
|
"""Load model onto target device."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
|
def execute_model(
|
||||||
|
self, scheduler_output: SchedulerOutput
|
||||||
|
) -> ModelRunnerOutput | None:
|
||||||
|
"""If this method returns None, sample_tokens should be called immediately after
|
||||||
|
to obtain the ModelRunnerOutput.
|
||||||
|
|
||||||
|
Note that this design may be changed in future if/when structured outputs
|
||||||
|
parallelism is re-architected.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def sample_tokens(
|
||||||
|
self, grammar_output: GrammarOutput
|
||||||
|
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
|
||||||
|
"""Should be called immediately after execute_model iff it returned None."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_cache_block_size_bytes(self) -> int:
|
def get_cache_block_size_bytes(self) -> int:
|
||||||
@ -344,7 +360,7 @@ class WorkerWrapperBase:
|
|||||||
scheduler_output: SchedulerOutput,
|
scheduler_output: SchedulerOutput,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ModelRunnerOutput:
|
) -> ModelRunnerOutput | None:
|
||||||
self._apply_mm_cache(scheduler_output)
|
self._apply_mm_cache(scheduler_output)
|
||||||
|
|
||||||
assert self.worker is not None
|
assert self.worker is not None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user