Merge branch 'main' into woosuk/model-runner-v2

This commit is contained in:
Woosuk Kwon 2025-09-23 09:22:58 -07:00
commit 17c2c106b1
192 changed files with 7964 additions and 4624 deletions

View File

@ -62,7 +62,7 @@ echo "--- Installing Python dependencies ---"
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
&& python3 -m pip install --progress-bar off hf-transfer && python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
echo "--- Python dependencies installed ---" echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1 export VLLM_USE_V1=1
export VLLM_XLA_CHECK_RECOMPILATION=1 export VLLM_XLA_CHECK_RECOMPILATION=1

View File

@ -62,7 +62,7 @@ echo "--- Installing Python dependencies ---"
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
&& python3 -m pip install --progress-bar off hf-transfer && python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
echo "--- Python dependencies installed ---" echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1 export VLLM_USE_V1=1
export VLLM_XLA_CHECK_RECOMPILATION=1 export VLLM_XLA_CHECK_RECOMPILATION=1

View File

@ -165,10 +165,18 @@ steps:
- tests/v1/test_hybrid_lb_dp.py - tests/v1/test_hybrid_lb_dp.py
- tests/v1/engine/test_engine_core_client.py - tests/v1/engine/test_engine_core_client.py
commands: commands:
# test with tp=2 and external_dp=2 # test with torchrun tp=2 and external_dp=2
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with tp=2 and pp=2 # test with torchrun tp=2 and pp=2
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with torchrun tp=4 and dp=1
- TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
# test with torchrun tp=2, pp=2 and dp=1
- PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
# test with torchrun tp=1 and dp=4 with ep
- DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
# test with torchrun tp=2 and dp=2 with ep
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
# test with internal dp # test with internal dp
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager - python3 ../examples/offline_inference/data_parallel.py --enforce-eager
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py

1
.github/CODEOWNERS vendored
View File

@ -72,6 +72,7 @@ mkdocs.yaml @hmellor
# Linting # Linting
.markdownlint.yaml @hmellor .markdownlint.yaml @hmellor
.pre-commit-config.yaml @hmellor .pre-commit-config.yaml @hmellor
/tools/pre_commit @hmellor
# CPU # CPU
/vllm/v1/worker/cpu* @bigPYJ1151 /vllm/v1/worker/cpu* @bigPYJ1151

View File

@ -43,10 +43,6 @@ body:
Any other things you would like to mention. Any other things you would like to mention.
validations: validations:
required: false required: false
- type: markdown
attributes:
value: >
Thanks for contributing 🎉! The vLLM core team hosts a biweekly RFC review session at 9:30AM Pacific Time, while most RFCs can be discussed online, you can optionally sign up for a slot to discuss your RFC online [here](https://docs.google.com/document/d/1CiLVBZeIVfR7_PNAKVSusxpceywkoOOB78qoWqHvSZc/edit).
- type: checkboxes - type: checkboxes
id: askllm id: askllm
attributes: attributes:

View File

@ -60,38 +60,32 @@ repos:
files: ^requirements/test\.(in|txt)$ files: ^requirements/test\.(in|txt)$
- id: mypy-local - id: mypy-local
name: Run mypy for local Python installation name: Run mypy for local Python installation
entry: tools/mypy.sh 0 "local" entry: python tools/pre_commit/mypy.py 0 "local"
language: python
types: [python]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic]
stages: [pre-commit] # Don't run in CI stages: [pre-commit] # Don't run in CI
<<: &mypy_common
language: python
types_or: [python, pyi]
require_serial: true
additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.9 name: Run mypy for Python 3.9
entry: tools/mypy.sh 1 "3.9" entry: python tools/pre_commit/mypy.py 1 "3.9"
language: python <<: *mypy_common
types: [python]
additional_dependencies: *mypy_deps
stages: [manual] # Only run in CI stages: [manual] # Only run in CI
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.10 name: Run mypy for Python 3.10
entry: tools/mypy.sh 1 "3.10" entry: python tools/pre_commit/mypy.py 1 "3.10"
language: python <<: *mypy_common
types: [python]
additional_dependencies: *mypy_deps
stages: [manual] # Only run in CI stages: [manual] # Only run in CI
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.11 name: Run mypy for Python 3.11
entry: tools/mypy.sh 1 "3.11" entry: python tools/pre_commit/mypy.py 1 "3.11"
language: python <<: *mypy_common
types: [python]
additional_dependencies: *mypy_deps
stages: [manual] # Only run in CI stages: [manual] # Only run in CI
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.12 name: Run mypy for Python 3.12
entry: tools/mypy.sh 1 "3.12" entry: python tools/pre_commit/mypy.py 1 "3.12"
language: python <<: *mypy_common
types: [python]
additional_dependencies: *mypy_deps
stages: [manual] # Only run in CI stages: [manual] # Only run in CI
- id: shellcheck - id: shellcheck
name: Lint shell scripts name: Lint shell scripts
@ -155,11 +149,10 @@ repos:
additional_dependencies: [regex] additional_dependencies: [regex]
- id: check-pickle-imports - id: check-pickle-imports
name: Prevent new pickle/cloudpickle imports name: Prevent new pickle/cloudpickle imports
entry: python tools/check_pickle_imports.py entry: python tools/pre_commit/check_pickle_imports.py
language: python language: python
types: [python] types: [python]
pass_filenames: false additional_dependencies: [regex]
additional_dependencies: [pathspec, regex]
- id: validate-config - id: validate-config
name: Validate configuration has default values and that each field has a docstring name: Validate configuration has default values and that each field has a docstring
entry: python tools/validate_config.py entry: python tools/validate_config.py

View File

@ -680,7 +680,7 @@ vllm bench serve \
--save-result \ --save-result \
--result-dir ~/vllm_benchmark_results \ --result-dir ~/vllm_benchmark_results \
--save-detailed \ --save-detailed \
--endpoint /v1/chat/completion --endpoint /v1/chat/completions
``` ```
##### Videos (ShareGPT4Video) ##### Videos (ShareGPT4Video)
@ -707,7 +707,7 @@ vllm bench serve \
--save-result \ --save-result \
--result-dir ~/vllm_benchmark_results \ --result-dir ~/vllm_benchmark_results \
--save-detailed \ --save-detailed \
--endpoint /v1/chat/completion --endpoint /v1/chat/completions
``` ```
##### Synthetic Random Images (random-mm) ##### Synthetic Random Images (random-mm)

View File

@ -23,7 +23,7 @@ Now supports 5 types of connectors:
- **SharedStorageConnector**: refer to <gh-file:examples/offline_inference/disaggregated-prefill-v1/run.sh> for the example usage of SharedStorageConnector disaggregated prefilling. - **SharedStorageConnector**: refer to <gh-file:examples/offline_inference/disaggregated-prefill-v1/run.sh> for the example usage of SharedStorageConnector disaggregated prefilling.
- **LMCacheConnectorV1**: refer to <gh-file:examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh> for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission. - **LMCacheConnectorV1**: refer to <gh-file:examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh> for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
- **NixlConnector**: refer to <gh-file:tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh> for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. - **NixlConnector**: refer to <gh-file:tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh> for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md).
- **P2pNcclConnector**: refer to <gh-file:examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh> for the example usage of P2pNcclConnector disaggregated prefilling. - **P2pNcclConnector**: refer to <gh-file:examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh> for the example usage of P2pNcclConnector disaggregated prefilling.
- **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as: - **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as:
@ -31,6 +31,18 @@ Now supports 5 types of connectors:
--kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}' --kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
``` ```
For NixlConnector, you may also specify one or multiple NIXL_Backend. Such as:
```bash
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'
```
- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):
```bash
--kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 64, "num_cpu_blocks": 1000}}'
```
## Benchmarks ## Benchmarks
Please refer to <gh-file:benchmarks/disagg_benchmarks> for disaggregated prefilling benchmarks. Please refer to <gh-file:benchmarks/disagg_benchmarks> for disaggregated prefilling benchmarks.

View File

@ -0,0 +1,159 @@
# NixlConnector Usage Guide
NixlConnector is a high-performance KV cache transfer connector for vLLM's disaggregated prefilling feature. It provides fully asynchronous send/receive operations using the NIXL library for efficient cross-process KV cache transfer.
## Prerequisites
### Installation
Install the NIXL library: `uv pip install nixl`, as a quick start.
- Refer to [NIXL official repository](https://github.com/ai-dynamo/nixl) for more installation instructions
- The specified required NIXL version can be found in [requirements/kv_connectors.txt](../../requirements/kv_connectors.txt) and other relevant config files
### Transport Configuration
NixlConnector uses NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables:
```bash
# Example UCX configuration, adjust according to your enviroment
export UCX_TLS=all # or specify specific transports like "rc,ud,sm,^cuda_ipc" ..etc
export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1"
```
!!! tip
When using UCX as the transport backend, NCCL environment variables (like `NCCL_IB_HCA`, `NCCL_SOCKET_IFNAME`) are not applicable to NixlConnector, so configure UCX-specific environment variables instead of NCCL variables.
## Basic Usage (on the same host)
### Producer (Prefiller) Configuration
Start a prefiller instance that produces KV caches
```bash
# 1st GPU as prefiller
CUDA_VISIBLE_DEVICES=0 \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
vllm serve Qwen/Qwen3-0.6B \
--port 8100 \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
```
### Consumer (Decoder) Configuration
Start a decoder instance that consumes KV caches:
```bash
# 2nd GPU as decoder
CUDA_VISIBLE_DEVICES=1 \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \
vllm serve Qwen/Qwen3-0.6B \
--port 8200 \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
```
### Proxy Server
Use a proxy server to route requests between prefiller and decoder:
```bash
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
--port 8192 \
--prefiller-hosts localhost \
--prefiller-ports 8100 \
--decoder-hosts localhost \
--decoder-ports 8200
```
## Environment Variables
- `VLLM_NIXL_SIDE_CHANNEL_PORT`: Port for NIXL handshake communication
- Default: 5600
- **Required for both prefiller and decoder instances**
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
- Used for the initial NIXL handshake between the prefiller and the decoder
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
- Default: "localhost"
- Set when prefiller and decoder are on different machines
- Connection info is passed via KVTransferParams from prefiller to decoder for handshake
- `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefillers KV cache for a particular request. (Optional)
- Default: 120
- If a request is aborted and the decoder has not yet read the KV-cache blocks through the nixl channel, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
## Multi-Instance Setup
### Multiple Prefiller Instances on Different Machines
```bash
# Prefiller 1 on Machine A (example IP: ${IP1})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP1} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
--tensor-parallel-size 8 \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}'
# Prefiller 2 on Machine B (example IP: ${IP2})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP2} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
--tensor-parallel-size 8 \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}'
```
### Multiple Decoder Instances on Different Machines
```bash
# Decoder 1 on Machine C (example IP: ${IP3})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP3} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
--tensor-parallel-size 8 \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}'
# Decoder 2 on Machine D (example IP: ${IP4})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP4} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
--tensor-parallel-size 8 \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}'
```
### Proxy for Multiple Instances
```bash
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
--port 8192 \
--prefiller-hosts ${IP1} ${IP2} \
--prefiller-ports 8000 8000 \
--decoder-hosts ${IP3} ${IP4} \
--decoder-ports 8000 8000
```
### KV Role Options
- **kv_producer**: For prefiller instances that generate KV caches
- **kv_consumer**: For decoder instances that consume KV caches from prefiller
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
!!! tip
NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`).
Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
## Example Scripts/Code
Refer to these example scripts in the vLLM repository:
- [run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh)
- [toy_proxy_server.py](../../tests/v1/kv_connector/nixl_integration/toy_proxy_server.py)
- [test_accuracy.py](../../tests/v1/kv_connector/nixl_integration/test_accuracy.py)

View File

@ -319,6 +319,15 @@ Supported models:
Flags: `--tool-call-parser glm45` Flags: `--tool-call-parser glm45`
### Qwen3-Coder Models (`qwen3_xml`)
Supported models:
* `Qwen/Qwen3-480B-A35B-Instruct`
* `Qwen/Qwen3-Coder-30B-A3B-Instruct`
Flags: `--tool-call-parser qwen3_xml`
### Models with Pythonic Tool Calls (`pythonic`) ### Models with Pythonic Tool Calls (`pythonic`)
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.

View File

@ -352,6 +352,7 @@ th {
| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ | | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ |
| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | | ✅︎ | ✅︎ |
| `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ | | `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ |
| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |

View File

@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok
1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip. 1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}` 2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`. Noted, you may also specify one or multiple NIXL_Backend. Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'`
3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions. 3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.

View File

@ -101,6 +101,13 @@ def parse_args():
"--quantization", "--quantization",
type=str, type=str,
) )
parser.add_argument(
"--disable-expert-parallel",
dest="enable_expert_parallel",
action="store_false",
help="Disable expert parallel (default: enabled).",
)
parser.set_defaults(enable_expert_parallel=True)
return parser.parse_args() return parser.parse_args()
@ -113,6 +120,7 @@ def main(
dp_master_port, dp_master_port,
GPUs_per_dp_rank, GPUs_per_dp_rank,
enforce_eager, enforce_eager,
enable_expert_parallel,
trust_remote_code, trust_remote_code,
max_num_seqs, max_num_seqs,
max_model_len, max_model_len,
@ -168,7 +176,7 @@ def main(
model=model, model=model,
tensor_parallel_size=GPUs_per_dp_rank, tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
enable_expert_parallel=True, enable_expert_parallel=enable_expert_parallel,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
max_model_len=max_model_len, max_model_len=max_model_len,
@ -229,6 +237,7 @@ if __name__ == "__main__":
dp_master_port, dp_master_port,
tp_size, tp_size,
args.enforce_eager, args.enforce_eager,
args.enable_expert_parallel,
args.trust_remote_code, args.trust_remote_code,
args.max_num_seqs, args.max_num_seqs,
args.max_model_len, args.max_model_len,

View File

@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
experimental support for data-parallel inference with torchrun
Note the data load balancing and distribution is done out of the vllm engine,
no internal lb supported in external_launcher mode.
"""
from vllm import LLM, SamplingParams
# Create prompts, the same across all ranks
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
] * 50
# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
# it is important to set an explicit seed to make sure that
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm = LLM(
model="microsoft/Phi-mini-MoE-instruct",
tensor_parallel_size=1,
data_parallel_size=2,
pipeline_parallel_size=1,
enable_expert_parallel=False,
distributed_executor_backend="external_launcher",
max_model_len=4096,
gpu_memory_utilization=0.6,
seed=1,
)
dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size
prompts = [
f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank
]
outputs = llm.generate(prompts, sampling_params)
# all ranks will have the same outputs
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
print("-" * 50)
"""
Further tips:
1. to communicate control messages across all ranks, use the cpu group,
a PyTorch ProcessGroup with GLOO backend.
```python
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
# do something for rank 0, e.g. saving the results to disk.
```
2. to communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with NCCL backend.
```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```
3. to access the model directly in every rank, use the following code:
```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""

View File

@ -126,6 +126,23 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
) )
# Dots-OCR
def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions]
engine_args = EngineArgs(
model="rednote-hilab/dots.ocr",
limit_mm_per_prompt={modality: 1},
trust_remote_code=True,
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData: def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
@ -1676,6 +1693,7 @@ model_example_map = {
"aya_vision": run_aya_vision, "aya_vision": run_aya_vision,
"blip-2": run_blip2, "blip-2": run_blip2,
"chameleon": run_chameleon, "chameleon": run_chameleon,
"dots_ocr": run_dots_ocr,
"command_a_vision": run_command_a_vision, "command_a_vision": run_command_a_vision,
"deepseek_vl_v2": run_deepseek_vl2, "deepseek_vl_v2": run_deepseek_vl2,
"ernie45_vl": run_ernie45_vl, "ernie45_vl": run_ernie45_vl,

View File

@ -110,27 +110,6 @@ ignore_missing_imports = true
check_untyped_defs = true check_untyped_defs = true
follow_imports = "silent" follow_imports = "silent"
# After fixing type errors resulting from follow_imports: "skip" -> "silent",
# move the directory here and remove it from tools/mypy.sh
files = [
"vllm/*.py",
"vllm/assets",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging_utils",
"vllm/multimodal",
"vllm/platforms",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
# Ignore triton kernels in ops.
'vllm/attention/ops/.*\.py$'
]
[tool.isort] [tool.isort]
skip_glob = [ skip_glob = [
".buildkite/*", ".buildkite/*",

View File

@ -14,14 +14,4 @@ nixl==0.3.0
tpu_info==0.4.0 tpu_info==0.4.0
# Install torch_xla # Install torch_xla
--pre torch_xla[tpu, pallas]==2.8.0
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.9.0.dev20250730
torchvision==0.24.0.dev20250730
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp312-cp312-linux_x86_64.whl ; python_version == "3.12"

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import weakref
from collections.abc import Sequence from collections.abc import Sequence
from copy import deepcopy from copy import deepcopy
from typing import Callable, Union from typing import Callable, Union
@ -10,7 +11,26 @@ from torch._ops import OpOverload
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config
class LazyInitPass(InductorPass):
"""
If there's a pass that we want to initialize lazily in a test,
we can wrap it in LazyInitPass, which will initialize the pass when invoked
and then immediately invoke it.
"""
def __init__(self, pass_cls: type[VllmInductorPass],
vllm_config: VllmConfig):
self.pass_cls = pass_cls
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
def __call__(self, graph: fx.Graph) -> None:
self.pass_ = self.pass_cls(self.vllm_config)
self.pass_(graph)
class TestBackend: class TestBackend:
@ -40,10 +60,16 @@ class TestBackend:
example_inputs, example_inputs,
config_patches=self.inductor_config) config_patches=self.inductor_config)
@with_pattern_match_debug
def post_pass(self, graph: fx.Graph): def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph) self.graph_pre_pass = deepcopy(graph)
VllmInductorPass.dump_prefix = 0
for pass_ in self.custom_passes: for pass_ in self.custom_passes:
pass_(graph) pass_(graph)
VllmInductorPass.dump_prefix += 1
VllmInductorPass.dump_prefix = None
self.graph_post_pass = deepcopy(graph) self.graph_post_pass = deepcopy(graph)
# assign by reference, will reflect the final state of the graph # assign by reference, will reflect the final state of the graph

View File

@ -46,7 +46,10 @@ backend_configs = {
# FA3 on Hopper # FA3 on Hopper
"FA3": "FA3":
BackendConfig(name="FA3", BackendConfig(name="FA3",
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, env_vars={
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL", "cudagraph_mode": "FULL",
}, },
@ -66,6 +69,7 @@ backend_configs = {
BackendConfig(name="FlashAttentionMLA", BackendConfig(name="FlashAttentionMLA",
env_vars={ env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_mode": "FULL_DECODE_ONLY",
@ -89,7 +93,10 @@ backend_configs = {
# FA2 # FA2
"FA2": "FA2":
BackendConfig(name="FA2", BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, env_vars={
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL", "cudagraph_mode": "FULL",
}), }),

View File

@ -294,6 +294,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
compiled_model = torch.compile(model, backend=backend) compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states) compiled_model(hidden_states)
assert async_tp_pass.matched_count == 1
# In pre-nodes, all gather or reduce scatter should exist, # In pre-nodes, all gather or reduce scatter should exist,
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)

View File

@ -4,7 +4,7 @@ import pytest
import vllm import vllm
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig from vllm.config import CompilationConfig, VllmConfig
from vllm.utils import _is_torch_equal_or_newer from vllm.utils import _is_torch_equal_or_newer
@ -26,6 +26,14 @@ def test_use_cudagraphs_dynamic(monkeypatch):
assert not vllm_config.compilation_config.use_cudagraph assert not vllm_config.compilation_config.use_cudagraph
def test_custom_op():
# proper syntax
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
with pytest.raises(ValueError, match="Invalid syntax '"):
_ = CompilationConfig(custom_ops=["quant_fp8"])
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked @pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends # NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends

View File

@ -8,9 +8,10 @@ import vllm.envs as envs
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FUSED_OPS, FusionPass from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
@ -58,11 +59,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
vllm_config.compilation_config = CompilationConfig( vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config) fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = [noop_pass, fusion_pass, act_quant_fusion_pass passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
] if do_fusion else [noop_pass] ] if do_fusion else [noop_pass, cleanup_pass]
func_pass = FixFunctionalizationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass) backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes) backend_no_func = TestBackend(*passes)

View File

@ -4,11 +4,11 @@
import pytest import pytest
import torch import torch
import vllm.envs as envs
import vllm.plugins import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass) RMSNormQuantFusionPass)
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig) VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -79,15 +79,15 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) @pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False]) @pytest.mark.parametrize("static", [True, False])
# cuda_force_torch used to test torch code path on platforms that # cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True. # cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch", @pytest.mark.parametrize("cuda_force_torch",
[True, False] if cutlass_fp8_supported() else [True]) [True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], @pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test on CUDA and ROCm") reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cuda_force_torch): cuda_force_torch):
@ -104,9 +104,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
with vllm.config.set_current_vllm_config(vllm_config): with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work # Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config) fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch) model = TestModel(hidden_size, eps, static, cuda_force_torch)
# First dimension dynamic # First dimension dynamic
@ -128,6 +129,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 2
# In pre-nodes, fp8 quant should be there and fused kernels should not # In pre-nodes, fp8 quant should be there and fused kernels should not
backend.check_before_ops(model.ops_in_model_before()) backend.check_before_ops(model.ops_in_model_before())

View File

@ -9,6 +9,7 @@ import vllm.envs as envs
from vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig) ModelConfig, PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed import tensor_model_parallel_all_reduce
@ -215,8 +216,10 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass) backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
cleanup_pass)
token_num = batch_size * seq_len token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num) model = test_model_cls(hidden_size, token_num)
@ -227,6 +230,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
compiled_model = torch.compile(model, backend=backend) compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual) compiled_model(hidden_states, residual)
assert all_reduce_fusion_pass.matched_count == 1
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after()) backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass del all_reduce_fusion_pass

View File

@ -6,18 +6,19 @@ from typing import Optional
import pytest import pytest
import torch._dynamo import torch._dynamo
from tests.compile.backend import TestBackend from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend, from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata) create_common_attn_metadata)
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
ModelConfig, PassConfig, SchedulerConfig, VllmConfig, ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config) set_current_vllm_config)
@ -104,7 +105,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
# AttnFusionPass needs attention layers to be registered in config upon init # AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation. # so we initialize it during compilation.
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw) attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass) backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(model, llm2 = LLM(model,
enforce_eager=True, enforce_eager=True,
@ -197,7 +198,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
device=self.device, device=self.device,
) )
def build_attn_metadata(self, batch_size: int, use_hnd: bool): def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
-> AttentionMetadata:
"""Initialize attention metadata.""" """Initialize attention metadata."""
# Create common attn metadata # Create common attn metadata
@ -447,9 +449,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
# Create test backend with fusion passes enabled # Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
) cleanup_pass = PostCleanupPass(vllm_config)
test_backend = TestBackend(noop_pass, attn_pass)
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
# Compile model with fusion enabled # Compile model with fusion enabled
model_compiled = torch.compile(model_fused, model_compiled = torch.compile(model_fused,
@ -485,6 +488,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
test_backend.check_before_ops([QUANT_OPS[quant_key]], test_backend.check_before_ops([QUANT_OPS[quant_key]],
fully_replaced=True) fully_replaced=True)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
# Check attention ops in the graph before and after fusion # Check attention ops in the graph before and after fusion
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP, attn_nodes_post = list(find_op_nodes(ATTN_OP,

View File

@ -6,10 +6,12 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FusionPass from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
PassConfig, VllmConfig) PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed import tensor_model_parallel_all_reduce
@ -104,7 +106,7 @@ class TestQuantModel(torch.nn.Module):
# Initialize weights # Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.nn.init.normal_(self.gate_proj, std=0.02)
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False) self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.scale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm, # Create a weight that is compatible with torch._scaled_mm,
@ -137,8 +139,7 @@ class TestQuantModel(torch.nn.Module):
# layer normalization # layer normalization
norm_output, residual_output = self.norm(all_reduce, residual) norm_output, residual_output = self.norm(all_reduce, residual)
# for static input quantization # scaled_mm with static input quantization
# self.fp8_linear is initialized with use_per_token_if_dynamic=False
fp8_linear_result = self.fp8_linear.apply(norm_output, fp8_linear_result = self.fp8_linear.apply(norm_output,
self.w, self.w,
self.wscale, self.wscale,
@ -253,16 +254,20 @@ def sequence_parallelism_pass_on_test_model(
dtype=dtype, dtype=dtype,
seed=42) seed=42)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
passes_for_backend = [noop_pass, sequence_parallelism_pass] passes_for_backend: list[VllmInductorPass] = \
[noop_pass, sequence_parallelism_pass]
if enable_fusion: if enable_fusion:
fusion_pass = FusionPass.instance(vllm_config) fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass) passes_for_backend.append(fusion_pass)
passes_for_backend.append(cleanup_pass)
backend_no_func = TestBackend(*passes_for_backend) backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass) backend_func = TestBackend(*passes_for_backend, func_pass)
@ -279,6 +284,8 @@ def sequence_parallelism_pass_on_test_model(
compiled_model_func = torch.compile(model, backend=backend_func) compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual) compiled_model_func(hidden_states, residual)
assert sequence_parallelism_pass.matched_count == 1
# In pre-nodes, all reduce should be there, # In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not # reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before()) backend_no_func.check_before_ops(model.ops_in_model_before())

View File

@ -15,6 +15,7 @@ from vllm.compilation.activation_quant_fusion import (
# yapf: enable # yapf: enable
from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
@ -69,6 +70,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs): def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
super().__init__() super().__init__()
from vllm.compilation.activation_quant_fusion import (
silu_and_mul_nvfp4_quant_supported)
assert silu_and_mul_nvfp4_quant_supported
self.silu_and_mul = SiluAndMul() self.silu_and_mul = SiluAndMul()
# create nvfp4 weight # create nvfp4 weight
@ -127,7 +132,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
pass_config=PassConfig(enable_fusion=True, enable_noop=True)) pass_config=PassConfig(enable_fusion=True, enable_noop=True))
fusion_pass = ActivationQuantFusionPass(config) fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(NoOpEliminationPass(config), fusion_pass) passes = [
NoOpEliminationPass(config), fusion_pass,
PostCleanupPass(config)
]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size, model = model_class(hidden_size=hidden_size,
cuda_force_torch=cuda_force_torch, cuda_force_torch=cuda_force_torch,
x=x) x=x)
@ -151,6 +160,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
atol=atol, atol=atol,
rtol=rtol) rtol=rtol)
assert fusion_pass.matched_count == 1
# In pre-nodes, quant op should be present and fused kernels should not # In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before()) backend.check_before_ops(model.ops_in_model_before())

View File

@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# unit test for `examples/offline_inference/torchrun_example.py`
import os
import random
import torch.distributed as dist
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_tp_group, get_world_group
dist.init_process_group(backend="gloo")
# Create prompts
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
] * 10
dp_size = int(os.getenv("DP_SIZE", "1"))
dp_rank = int(os.getenv("DP_RANK", "0"))
if dp_size > 1:
# distribute the prompts across the data parallel ranks
prompts = [
prompt for idx, prompt in enumerate(prompts)
if idx % dp_size == dp_rank
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
# to test if all ranks agree on the same kv cache configuration.
llm = LLM(model="microsoft/Phi-mini-MoE-instruct",
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
distributed_executor_backend="external_launcher",
gpu_memory_utilization=random.uniform(0.7, 0.9),
swap_space=random.randint(1, 4),
seed=0)
outputs = llm.generate(prompts, sampling_params)
group = get_world_group() if dp_size == 1 else get_tp_group()
cpu_group = group.cpu_group
group_rank = dist.get_rank(group=cpu_group)
def test_consistent_across_ranks(obj):
if group_rank == 0:
dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group)
else:
container = [None]
dist.broadcast_object_list(container,
src=group.ranks[0],
group=cpu_group)
assert container[0] == obj
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
# make sure we can access the model parameters from the calling process
# of the `LLM` instance.
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
model.parameters())
test_consistent_across_ranks(len(params))
# all ranks should have the same outputs
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
test_consistent_across_ranks(prompt)
test_consistent_across_ranks(generated_text)
print(f"Rank {group_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")

View File

@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import pytest_asyncio
from openai import OpenAI
from ...utils import RemoteOpenAIServer
MODEL_NAME = "openai/gpt-oss-20b"
@pytest.fixture(scope="module")
def monkeypatch_module():
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="module")
def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_module.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="function")
def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_module.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS",
"code_interpreter,container")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def mcp_disabled_client(mcp_disabled_server):
async with mcp_disabled_server.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture
async def mcp_enabled_client(mcp_enabled_server):
async with mcp_enabled_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI,
model_name: str):
response = await mcp_enabled_client.responses.create(
model=model_name,
# TODO: Ideally should be able to set max tool calls
# to prevent multi-turn, but it is not currently supported
# would speed up the test
input=("What's the first 4 digits after the decimal point of "
"cube root of `19910212 * 20250910`? "
"Show only the digits. The python interpreter is not stateful "
"and you must print to see the output."),
tools=[{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888"
}],
)
assert response is not None
assert response.status == "completed"
assert response.usage.output_tokens_details.tool_output_tokens > 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI,
model_name: str):
response = await mcp_disabled_client.responses.create(
model=model_name,
# TODO: Ideally should be able to set max tool calls
# to prevent multi-turn, but it is not currently supported
# would speed up the test
input=("What's the first 4 digits after the decimal point of "
"cube root of `19910212 * 20250910`? "
"Show only the digits. The python interpreter is not stateful "
"and you must print to see the output."),
tools=[{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888"
}],
)
assert response is not None
assert response.status == "completed"
assert response.usage.output_tokens_details.tool_output_tokens == 0

View File

@ -454,7 +454,13 @@ async def test_web_search(client: OpenAI, model_name: str):
async def test_code_interpreter(client: OpenAI, model_name: str): async def test_code_interpreter(client: OpenAI, model_name: str):
response = await client.responses.create( response = await client.responses.create(
model=model_name, model=model_name,
input="Multiply 64548*15151 using builtin python interpreter.", # TODO: Ideally should be able to set max tool calls
# to prevent multi-turn, but it is not currently supported
# would speed up the test
input=("What's the first 4 digits after the decimal point of "
"cube root of `19910212 * 20250910`? "
"Show only the digits. The python interpreter is not stateful "
"and you must print to see the output."),
tools=[{ tools=[{
"type": "code_interpreter", "type": "code_interpreter",
"container": { "container": {
@ -464,6 +470,7 @@ async def test_code_interpreter(client: OpenAI, model_name: str):
) )
assert response is not None assert response is not None
assert response.status == "completed" assert response.status == "completed"
assert response.usage.output_tokens_details.tool_output_tokens > 0
def get_weather(latitude, longitude): def get_weather(latitude, longitude):

View File

@ -5,6 +5,11 @@ import json
import pytest import pytest
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
Hermes2ProToolParser)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@ -37,7 +42,7 @@ TOOLS = [{
}, },
"unit": { "unit": {
"type": "string", "type": "string",
"enum": ["celsius", "fahrenheit"] "enum": ["celsius", "fahrenheit"],
}, },
}, },
"required": ["location"], "required": ["location"],
@ -45,8 +50,39 @@ TOOLS = [{
}, },
}] }]
PRODUCT_TOOLS = [{
"type": "function",
"function": {
"name": "get_product_info",
"description": "Get detailed information of a product based on its "
"product ID.",
"parameters": {
"type": "object",
"properties": {
"inserted": {
"type": "boolean",
"description": "inserted.",
},
"product_id": {
"type": "integer",
"description": "The product ID of the product.",
},
},
"required": ["product_id", "inserted"],
},
},
}]
MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}] MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
PRODUCT_MESSAGES = [{
"role":
"user",
"content":
"Hi! Do you have any detailed information about the product id "
"7355608 and inserted true?",
}]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_non_streaming_tool_call(): async def test_non_streaming_tool_call():
@ -113,8 +149,8 @@ async def test_streaming_tool_call():
if tool_chunk.function.name: if tool_chunk.function.name:
tool_call_chunks[index]["name"] += tool_chunk.function.name tool_call_chunks[index]["name"] += tool_chunk.function.name
if tool_chunk.function.arguments: if tool_chunk.function.arguments:
tool_call_chunks[index][ tool_call_chunks[index]["arguments"] += (
"arguments"] += tool_chunk.function.arguments tool_chunk.function.arguments)
assert len(tool_call_chunks) == 1 assert len(tool_call_chunks) == 1
reconstructed_tool_call = tool_call_chunks[0] reconstructed_tool_call = tool_call_chunks[0]
@ -127,3 +163,295 @@ async def test_streaming_tool_call():
print("\n[Streaming Test Passed]") print("\n[Streaming Test Passed]")
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
print(f"Reconstructed Arguments: {arguments}") print(f"Reconstructed Arguments: {arguments}")
@pytest.mark.asyncio
async def test_non_streaming_product_tool_call():
"""Test tool call integer and boolean parameters in non-streaming mode."""
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
client = server.get_async_client()
response = await client.chat.completions.create(
model=LORA_MODEL,
messages=PRODUCT_MESSAGES,
tools=PRODUCT_TOOLS,
tool_choice="auto",
temperature=0.66,
)
assert response.choices
choice = response.choices[0]
message = choice.message
assert choice.finish_reason == "tool_calls"
assert message.tool_calls is not None
tool_call = message.tool_calls[0]
assert tool_call.type == "function"
assert tool_call.function.name == "get_product_info"
arguments = json.loads(tool_call.function.arguments)
assert "product_id" in arguments
assert "inserted" in arguments
product_id = arguments.get("product_id")
inserted = arguments.get("inserted")
assert isinstance(product_id, int)
assert product_id == 7355608
assert isinstance(inserted, bool)
assert inserted is True
print("\n[Non-Streaming Product Test Passed]")
print(f"Tool Call: {tool_call.function.name}")
print(f"Arguments: {arguments}")
@pytest.mark.asyncio
async def test_streaming_product_tool_call():
"""Test tool call integer and boolean parameters in streaming mode."""
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
client = server.get_async_client()
stream = await client.chat.completions.create(
model=LORA_MODEL,
messages=PRODUCT_MESSAGES,
tools=PRODUCT_TOOLS,
tool_choice="auto",
temperature=0.66,
stream=True,
)
tool_call_chunks = {}
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if not delta or not delta.tool_calls:
continue
for tool_chunk in delta.tool_calls:
index = tool_chunk.index
if index not in tool_call_chunks:
tool_call_chunks[index] = {"name": "", "arguments": ""}
if tool_chunk.function.name:
tool_call_chunks[index]["name"] += tool_chunk.function.name
if tool_chunk.function.arguments:
tool_call_chunks[index]["arguments"] += (
tool_chunk.function.arguments)
assert len(tool_call_chunks) == 1
reconstructed_tool_call = tool_call_chunks[0]
assert reconstructed_tool_call["name"] == "get_product_info"
arguments = json.loads(reconstructed_tool_call["arguments"])
assert "product_id" in arguments
assert "inserted" in arguments
# Handle type coercion for streaming test as well
product_id = arguments.get("product_id")
inserted = arguments.get("inserted")
assert isinstance(product_id, int)
assert product_id == 7355608
assert isinstance(inserted, bool)
assert inserted is True
print("\n[Streaming Product Test Passed]")
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
print(f"Reconstructed Arguments: {arguments}")
@pytest.fixture
def qwen_tokenizer() -> AnyTokenizer:
from vllm.transformers_utils.tokenizer import get_tokenizer
return get_tokenizer("Qwen/Qwen3-32B")
@pytest.fixture
def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser:
return Hermes2ProToolParser(qwen_tokenizer)
@pytest.fixture
def any_chat_request() -> ChatCompletionRequest:
return ChatCompletionRequest(
seed=42,
model="Qwen/Qwen3-32B",
messages=[],
)
def test_hermes_parser_streaming_just_forward_text(
qwen_tokenizer: AnyTokenizer,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = (
"""This is some prior text that has nothing to do with tool calling."""
)
tokens = qwen_tokenizer.encode(text)
previous_text = ""
delta_messages = []
for token in tokens:
delta_text = qwen_tokenizer.decode([token])
current_text = previous_text + delta_text
delta = hermes_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=any_chat_request,
)
previous_text = current_text
delta_messages.append(delta)
for delta in delta_messages:
assert delta is not None
assert not delta.tool_calls
print(delta_messages)
assert "".join([delta.content for delta in delta_messages]) == text
def test_hermes_parser_streaming_failure_case_bug_19056(
qwen_tokenizer: AnyTokenizer,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}
</tool_call>"""
tokens = qwen_tokenizer.encode(text)
previous_text = ""
delta_messages = []
for token in tokens:
text = qwen_tokenizer.decode([token])
current_text = previous_text + text
delta = hermes_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=text,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=any_chat_request,
)
previous_text = current_text
if delta is not None:
delta_messages.append(delta)
assert delta_messages[0].tool_calls[0].function.name == "final_answer"
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
for delta in delta_messages)
assert tool_call_args == '{"trigger": true}'
def test_hermes_parser_streaming(
qwen_tokenizer: AnyTokenizer,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = '<tool_call>\
{"name": "get_current_temperature",\
"arguments": {"location":\
"San Francisco, California, United States", "unit": "celsius"}}\
</tool_call>'
tokens = qwen_tokenizer.encode(text)
previous_text = ""
delta_messages = []
for token in tokens:
text = qwen_tokenizer.decode([token])
current_text = previous_text + text
delta = hermes_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=text,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=any_chat_request,
)
previous_text = current_text
if delta is not None:
delta_messages.append(delta)
print(delta_messages)
assert (delta_messages[0].tool_calls[0].function.name ==
"get_current_temperature")
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
for delta in delta_messages)
assert tool_call_args == (
'{"location":"San Francisco, California, United States", '
'"unit": "celsius"}')
def test_hermes_parser_non_streaming_no_tool_call(
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """This is not a tool call."""
tool_call = hermes_parser.extract_tool_calls(
model_output=text,
request=any_chat_request,
)
assert tool_call is not None
assert not tool_call.tools_called
def test_hermes_parser_non_streaming_tool_call_between_tags(
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}
</tool_call>"""
tool_call = hermes_parser.extract_tool_calls(
model_output=text,
request=any_chat_request,
)
assert tool_call is not None
assert tool_call.tools_called
assert tool_call.tool_calls[0].function.name == "final_answer"
assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
def test_hermes_parser_non_streaming_tool_call_until_eos(
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}"""
tool_call = hermes_parser.extract_tool_calls(
model_output=text,
request=any_chat_request,
)
assert tool_call is not None
assert tool_call.tools_called
assert tool_call.tool_calls[0].function.name == "final_answer"
assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
def test_hermes_parser_non_streaming_tool_call_invalid_json(
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
# Missing closing brace to trigger exception
text = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}"""
tool_call = hermes_parser.extract_tool_calls(
model_output=text,
request=any_chat_request,
)
assert tool_call is not None
assert not tool_call.tools_called

View File

@ -19,7 +19,7 @@ pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \
vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000 vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000
# Run evaluation # Run evaluation
python tests/gsm8k/gsm8k_eval.py --port 8000 python tests/evals/gsm8k/gsm8k_eval.py --port 8000
``` ```
## Configuration Format ## Configuration Format

View File

@ -67,7 +67,6 @@ def generate_params():
return params return params
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
@pytest.mark.parametrize("device, name, use_mla, block_size", @pytest.mark.parametrize("device, name, use_mla, block_size",
generate_params()) generate_params())
def test_env( def test_env(
@ -189,7 +188,7 @@ def test_env(
# FlashMLA only supports block_size == 64 # FlashMLA only supports block_size == 64
pytest.skip("FlashMLA only supports block_size 64") pytest.skip("FlashMLA only supports block_size 64")
else: else:
from vllm.attention.backends.flashmla import ( from vllm.v1.attention.backends.mla.flashmla import ( # noqa: E501
is_flashmla_supported) is_flashmla_supported)
is_supported, _ = is_flashmla_supported() is_supported, _ = is_flashmla_supported()
if not is_supported: if not is_supported:

View File

@ -959,7 +959,6 @@ def make_test_metadata(
return attn_backend_obj.make_metadata( return attn_backend_obj.make_metadata(
num_prefills=num_prefills, num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True, enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
@ -1009,7 +1008,6 @@ def make_test_metadata(
return attn_backend_obj.make_metadata( return attn_backend_obj.make_metadata(
num_prefills=num_prefills, num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping, slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True, enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,

View File

@ -164,8 +164,8 @@ def populate_loras(
weight=layer_weights, weight=layer_weights,
generate_embeddings_tensor=generate_embeddings_tensor, generate_embeddings_tensor=generate_embeddings_tensor,
) )
sublora.lora_b = sublora.lora_b[:, (sublora_len * sublora.lora_b = sublora.lora_b[(sublora_len *
i):(sublora_len * (i + 1))] i):(sublora_len * (i + 1)), :]
sublora.optimize() sublora.optimize()
subloras.append(sublora) subloras.append(sublora)
@ -304,9 +304,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
result = embedding(input_) result = embedding(input_)
after_a = F.embedding( after_a = F.embedding(
input_, input_,
lora.lora_a, lora.lora_a.T,
) )
result += (after_a @ lora.lora_b) result += (after_a @ lora.lora_b.T)
expected_results.append(result) expected_results.append(result)
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
@ -445,9 +445,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
result = expanded_embedding(input_) result = expanded_embedding(input_)
after_a = F.embedding( after_a = F.embedding(
original_input_, original_input_,
lora.lora_a, lora.lora_a.T,
) )
result += (after_a @ lora.lora_b) result += (after_a @ lora.lora_b.T)
expected_results.append(result) expected_results.append(result)
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
@ -575,7 +575,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
lm_head=linear, lm_head=linear,
embedding_bias=None) embedding_bias=None)
result[:, vocab_size + embeddings_tensor_len:] = float("-inf") result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result) expected_results.append(result)
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
logits_processor.org_vocab_size = vocab_size logits_processor.org_vocab_size = vocab_size
@ -692,9 +692,10 @@ def test_linear_replicated(
expected_results: list[torch.Tensor] = [] expected_results: list[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping): for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id] lora = lora_dict[lora_id]
result = linear(input_)[0] result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result) expected_results.append(result)
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
@ -817,7 +818,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
for input_, lora_id in zip(inputs, prompt_mapping): for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id] lora = lora_dict[lora_id]
result = linear(input_)[0] result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result) expected_results.append(result)
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
@ -965,9 +966,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
result = linear(input_)[0] result = linear(input_)[0]
subloras = sublora_dict[lora_id] subloras = sublora_dict[lora_id]
for i, sublora in enumerate(subloras): for i, sublora in enumerate(subloras):
result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
(i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * (i + 1)] += (
sublora.scaling) input_ @ sublora.lora_a.T @ sublora.lora_b.T *
sublora.scaling)
expected_results.append(result) expected_results.append(result)
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)

View File

@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device):
assert lora.lora_b is not None assert lora.lora_b is not None
assert lora.lora_a.device == torch.device(device) assert lora.lora_a.device == torch.device(device)
assert lora.lora_b.device == torch.device(device) assert lora.lora_b.device == torch.device(device)
assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] assert (lora.lora_a.shape[0] == lora.lora_b.shape[1]
), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
assert lora.lora_a.shape[1] == 8 assert lora.lora_a.shape[0] == 8
embeddings_module = next( embeddings_module = next(
(k for k in EMBEDDING_MODULES if k in module_name), None) (k for k in EMBEDDING_MODULES if k in module_name), None)
if embeddings_module: if embeddings_module:
@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
name, name,
8, 8,
16, 16,
torch.rand([w.shape[1], 8], device=device), torch.rand([8, w.shape[1]], device=device),
torch.rand([8, w.shape[0]], device=device), torch.rand([w.shape[0], 8], device=device),
) )
return LoRAModel(lora_id, 8, loras) return LoRAModel(lora_id, 8, loras)
@ -109,8 +109,8 @@ def create_packed_lora(
replaced_module_name, replaced_module_name,
8, 8,
16, 16,
torch.rand([w.shape[1], 8], device=device), torch.rand([8, w.shape[1]], device=device),
torch.rand([8, w.shape[0] // len(replaced_module_names)], torch.rand([w.shape[0] // len(replaced_module_names), 8],
device=device), device=device),
) )
return LoRAModel(lora_id, 8, loras) return LoRAModel(lora_id, 8, loras)

View File

@ -36,10 +36,10 @@ class DummyLoRAManager:
module_name, module_name,
rank=rank, rank=rank,
lora_alpha=1, lora_alpha=1,
lora_a=torch.rand([weight.shape[1], rank], lora_a=torch.rand([rank, weight.shape[1]],
dtype=weight.dtype, dtype=weight.dtype,
device=self._device), device=self._device),
lora_b=torch.rand([rank, weight.shape[0]], lora_b=torch.rand([weight.shape[0], rank],
dtype=weight.dtype, dtype=weight.dtype,
device=self._device), device=self._device),
) )
@ -67,8 +67,8 @@ class DummyLoRAManager:
module_name, module_name,
rank=rank, rank=rank,
lora_alpha=1, lora_alpha=1,
lora_a=torch.rand([input_dim, rank], device="cuda"), lora_a=torch.rand([rank, input_dim], device="cuda"),
lora_b=torch.rand([rank, output_dim], device="cuda"), lora_b=torch.rand([output_dim, input_dim], device="cuda"),
embeddings_tensor=embeddings_tensor, embeddings_tensor=embeddings_tensor,
) )
self.set_module_lora(module_name, lora) self.set_module_lora(module_name, lora)

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest import pytest
import torch import torch
@ -34,15 +35,15 @@ class Relu3(ReLUSquaredActivation):
[ [
# Default values based on compile level # Default values based on compile level
# - All by default (no Inductor compilation) # - All by default (no Inductor compilation)
("", 0, False, [True] * 4, True), (None, 0, False, [True] * 4, True),
("", 1, True, [True] * 4, True), (None, 1, True, [True] * 4, True),
("", 2, False, [True] * 4, True), (None, 2, False, [True] * 4, True),
# - None by default (with Inductor) # - None by default (with Inductor)
("", 3, True, [False] * 4, False), (None, 3, True, [False] * 4, False),
("", 4, True, [False] * 4, False), (None, 4, True, [False] * 4, False),
# - All by default (without Inductor) # - All by default (without Inductor)
("", 3, False, [True] * 4, True), (None, 3, False, [True] * 4, True),
("", 4, False, [True] * 4, True), (None, 4, False, [True] * 4, True),
# Explicitly enabling/disabling # Explicitly enabling/disabling
# #
# Default: all # Default: all
@ -54,7 +55,7 @@ class Relu3(ReLUSquaredActivation):
# All but SiluAndMul # All but SiluAndMul
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True), ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
# All but ReLU3 (even if ReLU2 is on) # All but ReLU3 (even if ReLU2 is on)
("-relu3,relu2", 3, False, [1, 1, 1, 0], True), ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
# RMSNorm and SiluAndMul # RMSNorm and SiluAndMul
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False), ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
# All but RMSNorm # All but RMSNorm
@ -67,12 +68,13 @@ class Relu3(ReLUSquaredActivation):
# All but RMSNorm # All but RMSNorm
("all,-rms_norm", 4, True, [0, 1, 1, 1], True), ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
]) ])
def test_enabled_ops(env: str, torch_level: int, use_inductor: bool, def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool,
ops_enabled: list[int], default_on: bool): ops_enabled: list[int], default_on: bool):
custom_ops = env.split(',') if env else []
vllm_config = VllmConfig( vllm_config = VllmConfig(
compilation_config=CompilationConfig(use_inductor=bool(use_inductor), compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
level=torch_level, level=torch_level,
custom_ops=env.split(","))) custom_ops=custom_ops))
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on assert CustomOp.default_on() == default_on

View File

@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model
SSM_MODELS = [ SSM_MODELS = [
"state-spaces/mamba-130m-hf", "state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev", "tiiuae/falcon-mamba-tiny-dev",
"yujiepan/mamba2-codestral-v0.1-tiny-random", # mamba2-codestral in transformers is broken pending:
# https://github.com/huggingface/transformers/pull/40861
#"yujiepan/mamba2-codestral-v0.1-tiny-random",
] ]
HYBRID_MODELS = [ HYBRID_MODELS = [
@ -31,18 +33,7 @@ HYBRID_MODELS = [
"ibm-granite/granite-4.0-tiny-preview", "ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base", "tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B", "LiquidAI/LFM2-1.2B",
] "tiny-random/qwen3-next-moe",
V1_SUPPORTED_MODELS = [
"state-spaces/mamba-130m-hf",
"ai21labs/Jamba-tiny-dev",
"pfnet/plamo-2-1b",
"yujiepan/mamba2-codestral-v0.1-tiny-random",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B",
] ]
FULL_CUDA_GRAPH_MODELS = [ FULL_CUDA_GRAPH_MODELS = [
@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [
"Zyphra/Zamba2-1.2B-instruct", "Zyphra/Zamba2-1.2B-instruct",
] ]
V0_UNSUPPORTED_MODELS = [
"LiquidAI/LFM2-1.2B",
]
FP32_STATE_MODELS = [ FP32_STATE_MODELS = [
"state-spaces/mamba-130m-hf", "state-spaces/mamba-130m-hf",
"Zyphra/Zamba2-1.2B-instruct", "Zyphra/Zamba2-1.2B-instruct",
@ -88,20 +75,16 @@ def test_models(
hf_outputs = hf_model.generate_greedy_logprobs_limit( hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
if model in V1_SUPPORTED_MODELS: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs(
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs)
example_prompts, max_tokens, num_logprobs)
else:
vllm_v1_outputs = None
if model in V1_SUPPORTED_MODELS: check_logprobs_close(
check_logprobs_close( outputs_0_lst=hf_outputs,
outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs,
outputs_1_lst=vllm_v1_outputs, name_0="hf",
name_0="hf", name_1="vllm",
name_1="vllm-v1", )
)
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@ -299,14 +282,14 @@ def test_full_cuda_graph(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
check_logprobs_close( check_logprobs_close(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v1_outputs, outputs_1_lst=vllm_outputs,
name_0="hf", name_0="hf",
name_1="vllm-v1", name_1="vllm",
) )
@ -340,12 +323,12 @@ def test_fp32_cache_state(
with vllm_runner(model, with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS, max_num_seqs=MAX_NUM_SEQS,
**{cache_dtype_param: "float32"}) as vllm_model: **{cache_dtype_param: "float32"}) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
check_logprobs_close( check_logprobs_close(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v1_outputs, outputs_1_lst=vllm_outputs,
name_0="hf", name_0="hf",
name_1="vllm-v1", name_1="vllm",
) )

View File

@ -209,7 +209,6 @@ def batch_make_video_embeddings(
return visual(pixel_values_on_device, return visual(pixel_values_on_device,
grid_thw=video_grid_thw_on_device).cpu() grid_thw=video_grid_thw_on_device).cpu()
# V1 Test: this calls a V0 internal.
video_embeds = torch.concat(llm.apply_model(get_image_embeds)) video_embeds = torch.concat(llm.apply_model(get_image_embeds))
# split into original batches # split into original batches

View File

@ -312,14 +312,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
trust_remote_code=True,
v0_only=True,
max_model_len=10240),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True), trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
trust_remote_code=True), max_transformers_version="4.55.4",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
max_transformers_version="4.53", max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"), extras={"tiny-random": "tiny-random/qwen3-next-moe"}, # noqa: E501
min_transformers_version="4.56.3"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
trust_remote_code=True, trust_remote_code=True,
@ -448,6 +447,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_transformers_version="4.48", # noqa: E501 max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible.", # noqa: E501 transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"DotsOCRForCausalLM": _HfExamplesInfo("rednote-hilab/dots.ocr",
trust_remote_code=True),
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501 "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501
trust_remote_code=True), trust_remote_code=True),
@ -560,10 +561,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
"Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501 "Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501
max_model_len=4096, max_model_len=4096,
min_transformers_version="4.57"), # noqa: E501 min_transformers_version="4.57",
is_available_online=False),
"Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501 "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501
max_model_len=4096, max_model_len=4096,
min_transformers_version="4.57"), min_transformers_version="4.57",
is_available_online=False),
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
trust_remote_code=True), trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
@ -640,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL"), speculative_model="XiaomiMiMo/MiMo-7B-RL"),
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"), min_transformers_version="4.56.3"),
} }
_TRANSFORMERS_BACKEND_MODELS = { _TRANSFORMERS_BACKEND_MODELS = {

View File

@ -1,10 +1,20 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
from vllm.model_executor.models.vision import resolve_visual_encoder_outputs from tests.utils import multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.models.vision import (
get_load_balance_assignment, resolve_visual_encoder_outputs,
run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -33,3 +43,415 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers,
post_layer_norm=None, post_layer_norm=None,
max_possible_layers=max_possible_layers) max_possible_layers=max_possible_layers)
assert torch.equal(torch.tensor(expected_features), output_tensor) assert torch.equal(torch.tensor(expected_features), output_tensor)
class SimpleLinearModel(torch.nn.Module):
"""A simple linear vision model for testing."""
def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
super().__init__()
self.flatten = torch.nn.Flatten()
self.linear = torch.nn.Linear(input_dim, output_dim)
def forward(self, x: torch.Tensor):
# Flatten the input and apply linear transformation
x = self.flatten(x)
return self.linear(x)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"batch_size",
[
1, # Single image
4, # Small batch
5, # Odd batch size (for testing padding)
],
)
def test_run_dp_sharded_vision_model(batch_size: int):
world_size = 2
# Launch processes
mp.spawn(
run_dp_sharded_vision_model_vs_direct,
args=(
world_size,
batch_size,
get_open_port(),
),
nprocs=world_size,
)
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
batch_size: int, master_port: int):
"""
Test that run_dp_sharded_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform.seed_everything(0)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create a test input tensor
image_input = torch.randn(batch_size, 3, 224, 224)
# Create a simple linear model
vision_model = SimpleLinearModel()
# Run the model directly on the full input
with torch.inference_mode():
direct_output = vision_model(image_input)
# Run the model through the sharded function
with torch.inference_mode():
sharded_output = run_dp_sharded_vision_model(image_input, vision_model)
# Check that the world size is set up correctly
assert get_tensor_model_parallel_world_size() == world_size
# Check that the outputs have the same shape
assert direct_output.shape == sharded_output.shape
# Check that the outputs are close (they should be identical)
assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
@pytest.mark.parametrize(
"sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
"expected_grouped_sizes_per_gpu,test_description",
[
# Empty input
([], 2, [], [0, 0], [0, 0], "empty input"),
# Fewer samples than GPUs
([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0
], "fewer samples than GPUs"),
# Single GPU
([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
# Balanced assignment
([100, 100, 100, 100
], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"),
# Unbalanced sizes - this one is trickier since the algorithm is greedy
([1000, 100, 200, 50], 2, [0, 2, 1, 3
], [1, 3], [1000, 350], "unbalanced sizes"),
],
)
def test_get_load_balance_assignment_cases(sizes, num_gpus,
expected_shuffle_indices,
expected_gpu_sample_counts,
expected_grouped_sizes_per_gpu,
test_description):
"""Test get_load_balance_assignment with various input cases."""
result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
(shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
# Common assertions for all cases
assert len(shuffle_indices) == len(sizes)
assert len(gpu_sample_counts) == num_gpus
assert len(grouped_sizes_per_gpu) == num_gpus
assert sum(gpu_sample_counts) == len(sizes)
assert shuffle_indices == expected_shuffle_indices
assert gpu_sample_counts == expected_gpu_sample_counts
assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
class SimpleMRopeVisionModel(torch.nn.Module):
"""A simple vision model for testing mrope functionality."""
def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
super().__init__()
self.spatial_merge_size = spatial_merge_size
self.out_hidden_size = out_hidden_size
self.linear = torch.nn.Linear(768, out_hidden_size)
def forward(self, pixel_values: torch.Tensor,
grid_thw_list: list[list[int]]):
"""Simple forward pass that simulates spatial merging."""
# Apply linear transformation
embeddings = self.linear(pixel_values)
# Simulate spatial merging by reducing the number of patches
merge_factor = self.spatial_merge_size * self.spatial_merge_size
# Group patches and merge spatially
merged_embeddings = []
start_idx = 0
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
end_idx = start_idx + num_patches
# Get patches for this image
image_patches = embeddings[start_idx:end_idx]
# Simulate spatial merging by averaging groups of patches
merged_patches = num_patches // merge_factor
if merged_patches > 0:
# Reshape and average to simulate merging
reshaped = image_patches[:merged_patches * merge_factor].view(
merged_patches, merge_factor, -1)
merged = reshaped.mean(dim=1)
merged_embeddings.append(merged)
start_idx = end_idx
if merged_embeddings:
return torch.cat(merged_embeddings, dim=0)
else:
return torch.empty((0, self.out_hidden_size),
device=pixel_values.device,
dtype=pixel_values.dtype)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"batch_size",
[
1, # Single image
3, # Small batch
5, # Odd batch size (for testing padding)
],
)
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
world_size = 2
# Launch processes
mp.spawn(
run_dp_sharded_mrope_vision_model_vs_direct,
args=(
world_size,
batch_size,
get_open_port(),
),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
world_size: int,
batch_size: int,
master_port: int):
"""
Test that run_dp_sharded_mrope_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform.seed_everything(0)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create test data
grid_thw_list = []
pixel_values_list = []
for i in range(batch_size):
# Varying image sizes for better testing
t, h, w = 1, 4 + i, 4 + i
grid_thw_list.append([t, h, w])
num_patches = t * h * w
# Create random pixel values for this image
image_pixels = torch.randn(num_patches, 768)
pixel_values_list.append(image_pixels)
# Concatenate all pixel values
pixel_values = torch.cat(pixel_values_list, dim=0)
# Create a simple mrope vision model
vision_model = SimpleMRopeVisionModel()
# Run the model directly on the full input (only on rank 0)
if local_rank == 0:
with torch.inference_mode():
direct_output = vision_model(pixel_values, grid_thw_list)
# Run the model through the sharded function
with torch.inference_mode():
sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
sharded_output = torch.cat(sharded_output, dim=0)
# Check that the world size is set up correctly
assert get_tensor_model_parallel_world_size() == world_size
# Compare outputs (only on rank 0)
if local_rank == 0:
# Check that the outputs have the same shape
assert direct_output.shape == sharded_output.shape
# Check that the outputs are close (they should be identical)
assert torch.allclose(direct_output,
sharded_output,
rtol=1e-5,
atol=1e-5)
@multi_gpu_test(num_gpus=2)
def test_run_dp_sharded_mrope_vision_model_empty_input():
world_size = 2
mp.spawn(
run_dp_sharded_mrope_vision_model_empty_input_worker,
args=(world_size, get_open_port()),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_empty_input_worker(
local_rank: int, world_size: int, master_port: int):
"""Test run_dp_sharded_mrope_vision_model with empty input."""
# Set up distributed environment
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create empty inputs
pixel_values = torch.empty((0, 768))
grid_thw_list: list[list[int]] = []
vision_model = SimpleMRopeVisionModel()
# Should handle empty input gracefully
with torch.inference_mode():
output = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
assert len(output) == 0
@multi_gpu_test(num_gpus=4)
def test_run_dp_sharded_mrope_vision_model_uneven_load():
world_size = 4
mp.spawn(
run_dp_sharded_mrope_vision_model_uneven_load_worker,
args=(world_size, get_open_port()),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_uneven_load_worker(
local_rank: int, world_size: int, master_port: int):
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
# Set up distributed environment
current_platform.seed_everything(123)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create images with very different sizes
grid_thw_list = [
[1, 2, 2], # Small: 4 patches
[1, 8, 8], # Large: 64 patches
[1, 3, 3], # Medium: 9 patches
]
pixel_values_list = []
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
image_pixels = torch.randn(num_patches, 768)
pixel_values_list.append(image_pixels)
pixel_values = torch.cat(pixel_values_list, dim=0)
vision_model = SimpleMRopeVisionModel()
# Should handle uneven distribution without errors
with torch.inference_mode():
output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
# Verify output shape is reasonable
merge_factor = vision_model.spatial_merge_size**2
expected_output_patches = list(
math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
for i, output in enumerate(output_tuple):
assert output.shape[0] == expected_output_patches[i]
assert output.shape[1] == vision_model.out_hidden_size
@pytest.mark.parametrize("spatial_merge_size", [2, 4])
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
"""Test SimpleMRopeVisionModel with different spatial merge sizes."""
device = current_platform.device_type
grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images
pixel_values_list = []
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
image_pixels = torch.randn(num_patches, 768, device=device)
pixel_values_list.append(image_pixels)
pixel_values = torch.cat(pixel_values_list, dim=0)
vision_model = SimpleMRopeVisionModel(
spatial_merge_size=spatial_merge_size).to(device)
with torch.inference_mode():
output = vision_model(pixel_values, grid_thw_list)
# Verify output dimensions based on spatial merging
total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
merge_factor = spatial_merge_size**2
expected_output_patches = total_patches // merge_factor
assert output.shape[0] == expected_output_patches
assert output.shape[1] == vision_model.out_hidden_size

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64 import base64
import math
import mimetypes import mimetypes
import os import os
from tempfile import NamedTemporaryFile, TemporaryDirectory from tempfile import NamedTemporaryFile, TemporaryDirectory
@ -10,22 +9,11 @@ from typing import TYPE_CHECKING, NamedTuple
import numpy as np import numpy as np
import pytest import pytest
import torch
import torch.multiprocessing as mp
from PIL import Image, ImageChops from PIL import Image, ImageChops
from tests.utils import multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions, from vllm.multimodal.utils import MediaConnector, argsort_mm_positions
get_load_balance_assignment,
run_dp_sharded_mrope_vision_model,
run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.multimodal.inputs import MultiModalPlaceholderDict
@ -404,415 +392,3 @@ def test_argsort_mm_positions():
modality_idxs = argsort_mm_positions(mm_positions) modality_idxs = argsort_mm_positions(mm_positions)
assert modality_idxs == expected_modality_idxs assert modality_idxs == expected_modality_idxs
class SimpleLinearModel(torch.nn.Module):
"""A simple linear vision model for testing."""
def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
super().__init__()
self.flatten = torch.nn.Flatten()
self.linear = torch.nn.Linear(input_dim, output_dim)
def forward(self, x: torch.Tensor):
# Flatten the input and apply linear transformation
x = self.flatten(x)
return self.linear(x)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"batch_size",
[
1, # Single image
4, # Small batch
5, # Odd batch size (for testing padding)
],
)
def test_run_dp_sharded_vision_model(batch_size: int):
world_size = 2
# Launch processes
mp.spawn(
run_dp_sharded_vision_model_vs_direct,
args=(
world_size,
batch_size,
get_open_port(),
),
nprocs=world_size,
)
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
batch_size: int, master_port: int):
"""
Test that run_dp_sharded_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform.seed_everything(0)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create a test input tensor
image_input = torch.randn(batch_size, 3, 224, 224)
# Create a simple linear model
vision_model = SimpleLinearModel()
# Run the model directly on the full input
with torch.inference_mode():
direct_output = vision_model(image_input)
# Run the model through the sharded function
with torch.inference_mode():
sharded_output = run_dp_sharded_vision_model(image_input, vision_model)
# Check that the world size is set up correctly
assert get_tensor_model_parallel_world_size() == world_size
# Check that the outputs have the same shape
assert direct_output.shape == sharded_output.shape
# Check that the outputs are close (they should be identical)
assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
@pytest.mark.parametrize(
"sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
"expected_grouped_sizes_per_gpu,test_description",
[
# Empty input
([], 2, [], [0, 0], [0, 0], "empty input"),
# Fewer samples than GPUs
([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0
], "fewer samples than GPUs"),
# Single GPU
([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
# Balanced assignment
([100, 100, 100, 100
], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"),
# Unbalanced sizes - this one is trickier since the algorithm is greedy
([1000, 100, 200, 50], 2, [0, 2, 1, 3
], [1, 3], [1000, 350], "unbalanced sizes"),
],
)
def test_get_load_balance_assignment_cases(sizes, num_gpus,
expected_shuffle_indices,
expected_gpu_sample_counts,
expected_grouped_sizes_per_gpu,
test_description):
"""Test get_load_balance_assignment with various input cases."""
result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
(shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
# Common assertions for all cases
assert len(shuffle_indices) == len(sizes)
assert len(gpu_sample_counts) == num_gpus
assert len(grouped_sizes_per_gpu) == num_gpus
assert sum(gpu_sample_counts) == len(sizes)
assert shuffle_indices == expected_shuffle_indices
assert gpu_sample_counts == expected_gpu_sample_counts
assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
class SimpleMRopeVisionModel(torch.nn.Module):
"""A simple vision model for testing mrope functionality."""
def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
super().__init__()
self.spatial_merge_size = spatial_merge_size
self.out_hidden_size = out_hidden_size
self.linear = torch.nn.Linear(768, out_hidden_size)
def forward(self, pixel_values: torch.Tensor,
grid_thw_list: list[list[int]]):
"""Simple forward pass that simulates spatial merging."""
# Apply linear transformation
embeddings = self.linear(pixel_values)
# Simulate spatial merging by reducing the number of patches
merge_factor = self.spatial_merge_size * self.spatial_merge_size
# Group patches and merge spatially
merged_embeddings = []
start_idx = 0
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
end_idx = start_idx + num_patches
# Get patches for this image
image_patches = embeddings[start_idx:end_idx]
# Simulate spatial merging by averaging groups of patches
merged_patches = num_patches // merge_factor
if merged_patches > 0:
# Reshape and average to simulate merging
reshaped = image_patches[:merged_patches * merge_factor].view(
merged_patches, merge_factor, -1)
merged = reshaped.mean(dim=1)
merged_embeddings.append(merged)
start_idx = end_idx
if merged_embeddings:
return torch.cat(merged_embeddings, dim=0)
else:
return torch.empty((0, self.out_hidden_size),
device=pixel_values.device,
dtype=pixel_values.dtype)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"batch_size",
[
1, # Single image
3, # Small batch
5, # Odd batch size (for testing padding)
],
)
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
world_size = 2
# Launch processes
mp.spawn(
run_dp_sharded_mrope_vision_model_vs_direct,
args=(
world_size,
batch_size,
get_open_port(),
),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
world_size: int,
batch_size: int,
master_port: int):
"""
Test that run_dp_sharded_mrope_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform.seed_everything(0)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create test data
grid_thw_list = []
pixel_values_list = []
for i in range(batch_size):
# Varying image sizes for better testing
t, h, w = 1, 4 + i, 4 + i
grid_thw_list.append([t, h, w])
num_patches = t * h * w
# Create random pixel values for this image
image_pixels = torch.randn(num_patches, 768)
pixel_values_list.append(image_pixels)
# Concatenate all pixel values
pixel_values = torch.cat(pixel_values_list, dim=0)
# Create a simple mrope vision model
vision_model = SimpleMRopeVisionModel()
# Run the model directly on the full input (only on rank 0)
if local_rank == 0:
with torch.inference_mode():
direct_output = vision_model(pixel_values, grid_thw_list)
# Run the model through the sharded function
with torch.inference_mode():
sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
sharded_output = torch.cat(sharded_output, dim=0)
# Check that the world size is set up correctly
assert get_tensor_model_parallel_world_size() == world_size
# Compare outputs (only on rank 0)
if local_rank == 0:
# Check that the outputs have the same shape
assert direct_output.shape == sharded_output.shape
# Check that the outputs are close (they should be identical)
assert torch.allclose(direct_output,
sharded_output,
rtol=1e-5,
atol=1e-5)
@multi_gpu_test(num_gpus=2)
def test_run_dp_sharded_mrope_vision_model_empty_input():
world_size = 2
mp.spawn(
run_dp_sharded_mrope_vision_model_empty_input_worker,
args=(world_size, get_open_port()),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_empty_input_worker(
local_rank: int, world_size: int, master_port: int):
"""Test run_dp_sharded_mrope_vision_model with empty input."""
# Set up distributed environment
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create empty inputs
pixel_values = torch.empty((0, 768))
grid_thw_list: list[list[int]] = []
vision_model = SimpleMRopeVisionModel()
# Should handle empty input gracefully
with torch.inference_mode():
output = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
assert len(output) == 0
@multi_gpu_test(num_gpus=4)
def test_run_dp_sharded_mrope_vision_model_uneven_load():
world_size = 4
mp.spawn(
run_dp_sharded_mrope_vision_model_uneven_load_worker,
args=(world_size, get_open_port()),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_uneven_load_worker(
local_rank: int, world_size: int, master_port: int):
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
# Set up distributed environment
current_platform.seed_everything(123)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create images with very different sizes
grid_thw_list = [
[1, 2, 2], # Small: 4 patches
[1, 8, 8], # Large: 64 patches
[1, 3, 3], # Medium: 9 patches
]
pixel_values_list = []
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
image_pixels = torch.randn(num_patches, 768)
pixel_values_list.append(image_pixels)
pixel_values = torch.cat(pixel_values_list, dim=0)
vision_model = SimpleMRopeVisionModel()
# Should handle uneven distribution without errors
with torch.inference_mode():
output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
# Verify output shape is reasonable
merge_factor = vision_model.spatial_merge_size**2
expected_output_patches = list(
math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
for i, output in enumerate(output_tuple):
assert output.shape[0] == expected_output_patches[i]
assert output.shape[1] == vision_model.out_hidden_size
@pytest.mark.parametrize("spatial_merge_size", [2, 4])
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
"""Test SimpleMRopeVisionModel with different spatial merge sizes."""
device = current_platform.device_type
grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images
pixel_values_list = []
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
image_pixels = torch.randn(num_patches, 768, device=device)
pixel_values_list.append(image_pixels)
pixel_values = torch.cat(pixel_values_list, dim=0)
vision_model = SimpleMRopeVisionModel(
spatial_merge_size=spatial_merge_size).to(device)
with torch.inference_mode():
output = vision_model(pixel_values, grid_thw_list)
# Verify output dimensions based on spatial merging
total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
merge_factor = spatial_merge_size**2
expected_output_patches = total_patches // merge_factor
assert output.shape[0] == expected_output_patches
assert output.shape[1] == vision_model.out_hidden_size

216
tests/test_envs.py Normal file
View File

@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from unittest.mock import patch
import pytest
from vllm.envs import env_list_with_choices, env_with_choices
class TestEnvWithChoices:
"""Test cases for env_with_choices function."""
def test_default_value_returned_when_env_not_set(self):
"""Test default is returned when env var is not set."""
env_func = env_with_choices("NONEXISTENT_ENV", "default",
["option1", "option2"])
assert env_func() == "default"
def test_none_default_returned_when_env_not_set(self):
"""Test that None is returned when env not set and default is None."""
env_func = env_with_choices("NONEXISTENT_ENV", None,
["option1", "option2"])
assert env_func() is None
def test_valid_value_returned_case_sensitive(self):
"""Test that valid value is returned in case sensitive mode."""
with patch.dict(os.environ, {"TEST_ENV": "option1"}):
env_func = env_with_choices("TEST_ENV",
"default", ["option1", "option2"],
case_sensitive=True)
assert env_func() == "option1"
def test_valid_lowercase_value_returned_case_insensitive(self):
"""Test that lowercase value is accepted in case insensitive mode."""
with patch.dict(os.environ, {"TEST_ENV": "option1"}):
env_func = env_with_choices("TEST_ENV",
"default", ["OPTION1", "OPTION2"],
case_sensitive=False)
assert env_func() == "option1"
def test_valid_uppercase_value_returned_case_insensitive(self):
"""Test that uppercase value is accepted in case insensitive mode."""
with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}):
env_func = env_with_choices("TEST_ENV",
"default", ["option1", "option2"],
case_sensitive=False)
assert env_func() == "OPTION1"
def test_invalid_value_raises_error_case_sensitive(self):
"""Test that invalid value raises ValueError in case sensitive mode."""
with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
env_func = env_with_choices("TEST_ENV",
"default", ["option1", "option2"],
case_sensitive=True)
with pytest.raises(ValueError,
match="Invalid value 'invalid' for TEST_ENV"):
env_func()
def test_case_mismatch_raises_error_case_sensitive(self):
"""Test that case mismatch raises ValueError in case sensitive mode."""
with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}):
env_func = env_with_choices("TEST_ENV",
"default", ["option1", "option2"],
case_sensitive=True)
with pytest.raises(ValueError,
match="Invalid value 'OPTION1' for TEST_ENV"):
env_func()
def test_invalid_value_raises_error_case_insensitive(self):
"""Test that invalid value raises ValueError when case insensitive."""
with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
env_func = env_with_choices("TEST_ENV",
"default", ["option1", "option2"],
case_sensitive=False)
with pytest.raises(ValueError,
match="Invalid value 'invalid' for TEST_ENV"):
env_func()
def test_callable_choices_resolved_correctly(self):
"""Test that callable choices are resolved correctly."""
def get_choices():
return ["dynamic1", "dynamic2"]
with patch.dict(os.environ, {"TEST_ENV": "dynamic1"}):
env_func = env_with_choices("TEST_ENV", "default", get_choices)
assert env_func() == "dynamic1"
def test_callable_choices_with_invalid_value(self):
"""Test that callable choices raise error for invalid values."""
def get_choices():
return ["dynamic1", "dynamic2"]
with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
env_func = env_with_choices("TEST_ENV", "default", get_choices)
with pytest.raises(ValueError,
match="Invalid value 'invalid' for TEST_ENV"):
env_func()
class TestEnvListWithChoices:
"""Test cases for env_list_with_choices function."""
def test_default_list_returned_when_env_not_set(self):
"""Test that default list is returned when env var is not set."""
env_func = env_list_with_choices("NONEXISTENT_ENV",
["default1", "default2"],
["option1", "option2"])
assert env_func() == ["default1", "default2"]
def test_empty_default_list_returned_when_env_not_set(self):
"""Test that empty default list is returned when env not set."""
env_func = env_list_with_choices("NONEXISTENT_ENV", [],
["option1", "option2"])
assert env_func() == []
def test_single_valid_value_parsed_correctly(self):
"""Test that single valid value is parsed correctly."""
with patch.dict(os.environ, {"TEST_ENV": "option1"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
assert env_func() == ["option1"]
def test_multiple_valid_values_parsed_correctly(self):
"""Test that multiple valid values are parsed correctly."""
with patch.dict(os.environ, {"TEST_ENV": "option1,option2"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
assert env_func() == ["option1", "option2"]
def test_values_with_whitespace_trimmed(self):
"""Test that values with whitespace are trimmed correctly."""
with patch.dict(os.environ, {"TEST_ENV": " option1 , option2 "}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
assert env_func() == ["option1", "option2"]
def test_empty_values_filtered_out(self):
"""Test that empty values are filtered out."""
with patch.dict(os.environ, {"TEST_ENV": "option1,,option2,"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
assert env_func() == ["option1", "option2"]
def test_empty_string_returns_default(self):
"""Test that empty string returns default."""
with patch.dict(os.environ, {"TEST_ENV": ""}):
env_func = env_list_with_choices("TEST_ENV", ["default"],
["option1", "option2"])
assert env_func() == ["default"]
def test_only_commas_returns_default(self):
"""Test that string with only commas returns default."""
with patch.dict(os.environ, {"TEST_ENV": ",,,"}):
env_func = env_list_with_choices("TEST_ENV", ["default"],
["option1", "option2"])
assert env_func() == ["default"]
def test_case_sensitive_validation(self):
"""Test case sensitive validation."""
with patch.dict(os.environ, {"TEST_ENV": "option1,OPTION2"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"],
case_sensitive=True)
with pytest.raises(ValueError,
match="Invalid value 'OPTION2' in TEST_ENV"):
env_func()
def test_case_insensitive_validation(self):
"""Test case insensitive validation."""
with patch.dict(os.environ, {"TEST_ENV": "OPTION1,option2"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"],
case_sensitive=False)
assert env_func() == ["OPTION1", "option2"]
def test_invalid_value_in_list_raises_error(self):
"""Test that invalid value in list raises ValueError."""
with patch.dict(os.environ, {"TEST_ENV": "option1,invalid,option2"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
with pytest.raises(ValueError,
match="Invalid value 'invalid' in TEST_ENV"):
env_func()
def test_callable_choices_resolved_correctly(self):
"""Test that callable choices are resolved correctly."""
def get_choices():
return ["dynamic1", "dynamic2"]
with patch.dict(os.environ, {"TEST_ENV": "dynamic1,dynamic2"}):
env_func = env_list_with_choices("TEST_ENV", [], get_choices)
assert env_func() == ["dynamic1", "dynamic2"]
def test_callable_choices_with_invalid_value(self):
"""Test that callable choices raise error for invalid values."""
def get_choices():
return ["dynamic1", "dynamic2"]
with patch.dict(os.environ, {"TEST_ENV": "dynamic1,invalid"}):
env_func = env_list_with_choices("TEST_ENV", [], get_choices)
with pytest.raises(ValueError,
match="Invalid value 'invalid' in TEST_ENV"):
env_func()
def test_duplicate_values_preserved(self):
"""Test that duplicate values in the list are preserved."""
with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
assert env_func() == ["option1", "option1", "option2"]

View File

@ -13,6 +13,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ToolCall) ToolCall)
from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser) Qwen3CoderToolParser)
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import (
Qwen3XMLToolParser)
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
@ -29,6 +31,21 @@ def qwen3_tool_parser(qwen3_tokenizer):
return Qwen3CoderToolParser(qwen3_tokenizer) return Qwen3CoderToolParser(qwen3_tokenizer)
@pytest.fixture
def qwen3_xml_tool_parser(qwen3_tokenizer):
return Qwen3XMLToolParser(qwen3_tokenizer)
@pytest.fixture(params=["original", "xml"])
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser,
request):
"""Parameterized fixture that provides both parser types for testing"""
if request.param == "original":
return qwen3_tool_parser
else:
return qwen3_xml_tool_parser
@pytest.fixture @pytest.fixture
def sample_tools(): def sample_tools():
return [ return [
@ -95,7 +112,7 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
def stream_delta_message_generator( def stream_delta_message_generator(
qwen3_tool_parser: Qwen3CoderToolParser, qwen3_tool_parser,
qwen3_tokenizer: AnyTokenizer, qwen3_tokenizer: AnyTokenizer,
model_output: str, model_output: str,
request: Optional[ChatCompletionRequest] = None request: Optional[ChatCompletionRequest] = None
@ -144,9 +161,9 @@ def stream_delta_message_generator(
read_offset = new_read_offset read_offset = new_read_offset
def test_extract_tool_calls_no_tools(qwen3_tool_parser): def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized):
model_output = "This is a test response without any tool calls" model_output = "This is a test response without any tool calls"
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type] model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.tool_calls == []
@ -294,12 +311,13 @@ circle
], "Let me calculate that area for you."), ], "Let me calculate that area for you."),
], ],
) )
def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output, def test_extract_tool_calls(qwen3_tool_parser_parametrized, sample_tools,
expected_tool_calls, expected_content): model_output, expected_tool_calls,
expected_content):
request = ChatCompletionRequest(model=MODEL, request = ChatCompletionRequest(model=MODEL,
messages=[], messages=[],
tools=sample_tools) tools=sample_tools)
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request) model_output, request=request)
assert extracted_tool_calls.tools_called assert extracted_tool_calls.tools_called
@ -308,7 +326,8 @@ def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output,
assert extracted_tool_calls.content == expected_content assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser, sample_tools): def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser_parametrized,
sample_tools):
"""Test fallback parsing when XML tags are missing""" """Test fallback parsing when XML tags are missing"""
model_output = '''<function=get_current_weather> model_output = '''<function=get_current_weather>
<parameter=city> <parameter=city>
@ -322,7 +341,7 @@ TX
request = ChatCompletionRequest(model=MODEL, request = ChatCompletionRequest(model=MODEL,
messages=[], messages=[],
tools=sample_tools) tools=sample_tools)
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request) model_output, request=request)
assert extracted_tool_calls.tools_called assert extracted_tool_calls.tools_called
@ -331,7 +350,7 @@ TX
"get_current_weather") "get_current_weather")
def test_extract_tool_calls_type_conversion(qwen3_tool_parser): def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized):
"""Test parameter type conversion based on tool schema""" """Test parameter type conversion based on tool schema"""
tools = [ tools = [
ChatCompletionToolsParam(type="function", ChatCompletionToolsParam(type="function",
@ -381,7 +400,7 @@ hello world
</tool_call>''' </tool_call>'''
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request) model_output, request=request)
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
@ -536,9 +555,10 @@ circle
], "Let me calculate that area for you."), ], "Let me calculate that area for you."),
], ],
) )
def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized,
sample_tools, model_output, qwen3_tokenizer, sample_tools,
expected_tool_calls, expected_content): model_output, expected_tool_calls,
expected_content):
"""Test incremental streaming behavior including typed parameters""" """Test incremental streaming behavior including typed parameters"""
request = ChatCompletionRequest(model=MODEL, request = ChatCompletionRequest(model=MODEL,
messages=[], messages=[],
@ -548,7 +568,8 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
tool_states = {} # Track state per tool index tool_states = {} # Track state per tool index
for delta_message in stream_delta_message_generator( for delta_message in stream_delta_message_generator(
qwen3_tool_parser, qwen3_tokenizer, model_output, request): qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
request):
# role should never be streamed from tool parser # role should never be streamed from tool parser
assert not delta_message.role assert not delta_message.role
@ -609,7 +630,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
def test_extract_tool_calls_missing_closing_parameter_tag( def test_extract_tool_calls_missing_closing_parameter_tag(
qwen3_tool_parser, sample_tools): qwen3_tool_parser_parametrized, sample_tools):
"""Test handling of missing closing </parameter> tag""" """Test handling of missing closing </parameter> tag"""
# Using get_current_weather from sample_tools but with malformed XML # Using get_current_weather from sample_tools but with malformed XML
model_output = '''Let me check the weather for you: model_output = '''Let me check the weather for you:
@ -629,7 +650,7 @@ fahrenheit
request = ChatCompletionRequest(model=MODEL, request = ChatCompletionRequest(model=MODEL,
messages=[], messages=[],
tools=sample_tools) tools=sample_tools)
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request) model_output, request=request)
# The parser should handle the malformed XML gracefully # The parser should handle the malformed XML gracefully
@ -652,7 +673,7 @@ fahrenheit
def test_extract_tool_calls_streaming_missing_closing_tag( def test_extract_tool_calls_streaming_missing_closing_tag(
qwen3_tool_parser, qwen3_tokenizer, sample_tools): qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools):
"""Test streaming with missing closing </parameter> tag""" """Test streaming with missing closing </parameter> tag"""
# Using get_current_weather from sample_tools but with malformed XML # Using get_current_weather from sample_tools but with malformed XML
model_output = '''Let me check the weather for you: model_output = '''Let me check the weather for you:
@ -677,7 +698,8 @@ fahrenheit
tool_states = {} tool_states = {}
for delta_message in stream_delta_message_generator( for delta_message in stream_delta_message_generator(
qwen3_tool_parser, qwen3_tokenizer, model_output, request): qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
request):
if delta_message.content: if delta_message.content:
other_content += delta_message.content other_content += delta_message.content
@ -727,9 +749,8 @@ fahrenheit
assert args["unit"] == "fahrenheit" assert args["unit"] == "fahrenheit"
def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser, def test_extract_tool_calls_streaming_incremental(
qwen3_tokenizer, qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools):
sample_tools):
"""Test that streaming is truly incremental""" """Test that streaming is truly incremental"""
model_output = '''I'll check the weather.<tool_call> model_output = '''I'll check the weather.<tool_call>
<function=get_current_weather> <function=get_current_weather>
@ -748,7 +769,8 @@ TX
chunks = [] chunks = []
for delta_message in stream_delta_message_generator( for delta_message in stream_delta_message_generator(
qwen3_tool_parser, qwen3_tokenizer, model_output, request): qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
request):
chunks.append(delta_message) chunks.append(delta_message)
# Should have multiple chunks # Should have multiple chunks
@ -784,3 +806,49 @@ TX
parsed_args = json.loads(full_args) parsed_args = json.loads(full_args)
assert parsed_args["city"] == "Dallas" assert parsed_args["city"] == "Dallas"
assert parsed_args["state"] == "TX" assert parsed_args["state"] == "TX"
def test_extract_tool_calls_complex_type_with_single_quote(
qwen3_tool_parser_parametrized):
"""Test parameter type conversion based on tool schema"""
tools = [
ChatCompletionToolsParam(type="function",
function={
"name": "test_types",
"parameters": {
"type": "object",
"properties": {
"int_param": {
"type": "integer"
},
"float_param": {
"type": "float"
},
"bool_param": {
"type": "boolean"
},
"str_param": {
"type": "string"
},
"obj_param": {
"type": "object"
}
}
}
})
]
model_output = '''<tool_call>
<function=test_types>
<parameter=obj_param>
{'key': 'value'}
</parameter>
</function>
</tool_call>'''
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request)
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args["obj_param"] == {"key": "value"}

View File

@ -6,6 +6,7 @@ Run `pytest tests/kernels/moe/test_moe_pallas.py`.
""" """
import pytest import pytest
import torch import torch
import torch_xla
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
@ -77,7 +78,7 @@ def test_pallas_moe(
expert_map=e_map, expert_map=e_map,
renormalize=False, renormalize=False,
) )
xm.mark_step() torch_xla.sync(wait=False)
# Compare outputs # Compare outputs
torch.testing.assert_close( torch.testing.assert_close(

View File

@ -47,7 +47,10 @@ backend_configs = {
# FA3 on Hopper # FA3 on Hopper
"FA3": "FA3":
BackendConfig(name="FA3", BackendConfig(name="FA3",
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, env_vars={
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL", "cudagraph_mode": "FULL",
}, },
@ -67,6 +70,7 @@ backend_configs = {
BackendConfig(name="FlashAttentionMLA", BackendConfig(name="FlashAttentionMLA",
env_vars={ env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_mode": "FULL_DECODE_ONLY",
@ -75,7 +79,10 @@ backend_configs = {
# FA2 # FA2
"FA2": "FA2":
BackendConfig(name="FA2", BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, env_vars={
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}), }),

View File

@ -85,7 +85,10 @@ run_tests_for_model() {
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
# Build the command with or without model-specific args # Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \
--port $PORT \ --port $PORT \
--enforce-eager \ --enforce-eager \
--gpu-memory-utilization 0.2 \ --gpu-memory-utilization 0.2 \
@ -117,7 +120,10 @@ run_tests_for_model() {
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
# Build the command with or without model-specific args # Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \
--port $PORT \ --port $PORT \
--enforce-eager \ --enforce-eager \
--gpu-memory-utilization 0.2 \ --gpu-memory-utilization 0.2 \

View File

@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501
SharedStorageConnectorMetadata)
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_initialized, get_kv_transfer_group)
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
# Importing utils registers TestSharedStorageConnector with the factory
from .utils import create_vllm_config
def _make_empty_scheduler_output():
return SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
kv_connector_metadata=SharedStorageConnectorMetadata(),
)
def test_kv_connector_mixin_clears_metadata():
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector"
vllm_config.kv_transfer_config.kv_role = "kv_both"
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = ("unit")
# Initialize the global connector instance
ensure_kv_transfer_initialized(vllm_config)
try:
# Minimal scheduler output with empty metadata; mixin should still
# bind/clear metadata even if no loads happen
scheduler_output = _make_empty_scheduler_output()
# Invoke the no-forward path which uses the mixin context manager
KVConnectorModelRunnerMixin.kv_connector_no_forward(
scheduler_output, vllm_config)
# Verify clear_connector_metadata was called on the connector
connector = get_kv_transfer_group()
assert connector._connector_metadata is None
# Test connector wrapper records method calls
assert connector.call_record.get("bind_connector_metadata", 0) == 1
assert connector.call_record.get("clear_connector_metadata", 0) == 1
finally:
# Ensure we clean up the global connector between tests
KVConnectorModelRunnerMixin.ensure_kv_transfer_shutdown()

View File

@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker, NixlKVConnectorStats) NixlConnectorWorker, NixlKVConnectorStats)
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@ -56,7 +57,10 @@ class FakeNixlWrapper:
def get_reg_descs(self, caches_data, memory_type: str) -> list: def get_reg_descs(self, caches_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in caches_data] return [str(uuid.uuid4()) for _ in caches_data]
def register_memory(self, descs) -> None: def register_memory(self, descs, backends) -> None:
pass
def deregister_memory(self, descs) -> None:
pass pass
def get_xfer_descs(self, blocks_data, memory_type: str) -> list: def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
@ -85,6 +89,12 @@ class FakeNixlWrapper:
def release_xfer_handle(self, handle: int) -> None: def release_xfer_handle(self, handle: int) -> None:
pass pass
def release_dlist_handle(self, handle: int) -> None:
pass
def remove_remote_agent(self, agent: str) -> None:
pass
def send_notif(self, agent_name: str, notif_msg: bytes) -> None: def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
pass pass
@ -855,3 +865,95 @@ def test_register_kv_caches(dist_init):
assert block_len == expected_block_len, \ assert block_len == expected_block_len, \
f"Block entry {i}: Expected block len {expected_block_len}, " \ f"Block entry {i}: Expected block len {expected_block_len}, " \
f"got {block_len}" f"got {block_len}"
class FakePlatform(Platform):
device_type: str = "oot"
@classmethod
def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
"""
Returns a mapping from device_type to a tuple of supported
kv_buffer_device for nixl.
"""
return {'oot': ('oot', )}
@classmethod
def get_nixl_memory_type(cls) -> Optional[str]:
"""
Returns the nixl memory type for the current platform.
"""
return 'VRAM'
@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [
("oot", "VRAM"),
])
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
nixl_memory_type):
"""
Test that register_kv_caches() passes the correct memory types from the
config to the nixl_wrapper.
"""
vllm_config = create_vllm_config()
# Override the default memory types in the config
vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
_NIXL_SUPPORTED_DEVICE)
_NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices())
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE): # noqa: E501
# Create connector and replace its worker with a fake one for isolation
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
# Verify get_reg_descs was called with the correct memory_type
assert connector.connector_worker.kv_buffer_device == kv_buffer_device
assert connector.connector_worker.nixl_memory_type == nixl_memory_type
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_shutdown_cleans_up_resources(dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()
worker = NixlConnectorWorker(vllm_config,
vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper
with patch.object(worker, '_handshake_initiation_executor') as mock_exec, \
patch.object(worker, '_nixl_handshake_listener_t') as mock_listener, \
patch.object(nixl_wrapper, 'release_xfer_handle') as mock_rel_xfer, \
patch.object(nixl_wrapper, 'release_dlist_handle') as mock_rel_dlist, \
patch.object(nixl_wrapper, 'remove_remote_agent') as mock_rem_agent, \
patch.object(nixl_wrapper, 'deregister_memory') as mock_dereg:
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
worker.src_xfer_side_handle = 456
worker.dst_xfer_side_handles = {"engine1": 789}
worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"]
worker.shutdown()
# Test idempotency
worker.shutdown()
worker.shutdown()
mock_exec.shutdown.assert_called_with(wait=False)
mock_listener.join.assert_called_once_with(timeout=0)
mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
mock_rel_dlist.assert_any_call(456) # src handle
mock_rel_dlist.assert_any_call(789) # dst handle
mock_rem_agent.assert_called_once_with("agent1")
assert mock_dereg.call_count == 2
mock_dereg.assert_any_call("desc1")
mock_dereg.assert_any_call("desc2")

View File

@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
import pytest
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
CPU_BLOCK_SIZES = [16, 48]
@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
def test_cpu_offloading(cpu_block_size: int) -> None:
"""
Tests OffloadingConnector with CPUOffloadingSpec.
"""
# configure OffloadingConnector (spec_name=CPUOffloadingSpec by default)
kv_transfer_config = KVTransferConfig(
kv_connector="OffloadingConnector",
kv_role="kv_both",
kv_connector_extra_config={
"num_cpu_blocks": 100,
"block_size": cpu_block_size
},
)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_transfer_config=kv_transfer_config,
)
prompts = ["Hi " * 100]
sampling_params = SamplingParams(temperature=0, max_tokens=20)
# run generation - this should trigger saving KV cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cold_time = time.time() - start_time
# run generation again - should hit the GPU prefix cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
gpu_hit_time = time.time() - start_time
# reset prefix cache to avoid GPU hit.
llm.reset_prefix_cache()
# sleep for a sec to make sure CPU finished storing
time.sleep(1)
# run generation again - this should trigger loading from CPU
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cpu_hit_time = time.time() - start_time
print("Generation times:")
print(f" Cold: {cold_time * 1000:.2f}ms")
print(f" GPU hit: {gpu_hit_time * 1000:.2f}ms")
print(f" CPU hit: {cpu_hit_time * 1000:.2f}ms")

View File

@ -13,7 +13,6 @@ from vllm import SamplingParams
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient from vllm.v1.engine.core_client import DPAsyncMPClient
@ -29,10 +28,6 @@ engine_args = AsyncEngineArgs(
data_parallel_size=DP_SIZE, data_parallel_size=DP_SIZE,
) )
if not current_platform.supports_v1(engine_args.create_model_config()):
pytest.skip(reason="Requires V1-supporting platform.",
allow_module_level=True)
async def generate( async def generate(
engine: AsyncLLM, engine: AsyncLLM,

View File

@ -4,6 +4,7 @@ import math
import pytest import pytest
import torch import torch
import torch_xla
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
@ -63,7 +64,7 @@ def test_topp_result_sums_past_p():
probs.masked_fill_(logits_masked.isinf(), 0) probs.masked_fill_(logits_masked.isinf(), 0)
masked_prob_sum = probs.sum(dim=-1) masked_prob_sum = probs.sum(dim=-1)
xm.mark_step() torch_xla.sync()
# Perform assertion on CPU. # Perform assertion on CPU.
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu())) assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
@ -82,7 +83,7 @@ def test_topp_basic():
k=torch.tensor([3, 3]), k=torch.tensor([3, 3]),
p=torch.tensor([0.79, 0.79])) p=torch.tensor([0.79, 0.79]))
xm.mark_step() torch_xla.sync()
# Expect the smallest elements to be dropped. # Expect the smallest elements to be dropped.
expected_result = logits.clone().cpu() expected_result = logits.clone().cpu()
@ -104,7 +105,7 @@ def test_topp_select_all():
k=torch.tensor([3, 3]), k=torch.tensor([3, 3]),
p=torch.tensor([1.0, 1.0])) p=torch.tensor([1.0, 1.0]))
xm.mark_step() torch_xla.sync()
assert torch.allclose(logits.cpu(), result.cpu()) assert torch.allclose(logits.cpu(), result.cpu())
@ -122,7 +123,7 @@ def test_topp_with_ties():
k=torch.tensor([4]), k=torch.tensor([4]),
p=torch.tensor([0.2])) p=torch.tensor([0.2]))
xm.mark_step() torch_xla.sync()
# All tie values are included in the top-p set. Tie breaking is left # All tie values are included in the top-p set. Tie breaking is left
# to be done during final sampling (all tie tokens have equal # to be done during final sampling (all tie tokens have equal
@ -146,7 +147,7 @@ def test_both_topk_topp():
k=torch.tensor([1, 3]), k=torch.tensor([1, 3]),
p=torch.tensor([0.79, 0.79])) p=torch.tensor([0.79, 0.79]))
xm.mark_step() torch_xla.sync()
# Since for the first batch k=1, expect only the largest element gets # Since for the first batch k=1, expect only the largest element gets
# selected. # selected.

View File

@ -1,35 +0,0 @@
#!/bin/bash
CI=${1:-0}
PYTHON_VERSION=${2:-local}
if [ "$CI" -eq 1 ]; then
set -e
fi
if [ $PYTHON_VERSION == "local" ]; then
PYTHON_VERSION=$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
fi
run_mypy() {
echo "Running mypy on $1"
if [ "$CI" -eq 1 ] && [ -z "$1" ]; then
mypy --python-version "${PYTHON_VERSION}" "$@"
return
fi
mypy --follow-imports skip --python-version "${PYTHON_VERSION}" "$@"
}
run_mypy # Note that this is less strict than CI
run_mypy tests
run_mypy vllm/attention
run_mypy vllm/compilation
run_mypy vllm/distributed
run_mypy vllm/engine
run_mypy vllm/executor
run_mypy vllm/inputs
run_mypy vllm/lora
run_mypy --exclude 'vllm/model_executor/layers/fla/ops' vllm/model_executor
run_mypy vllm/plugins
run_mypy vllm/worker
run_mypy vllm/v1

View File

@ -1,20 +1,10 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys import sys
import regex as re import regex as re
try:
import pathspec
except ImportError:
print(
"ERROR: The 'pathspec' library is required. "
"Install it with 'pip install pathspec'.",
file=sys.stderr)
sys.exit(2)
# List of files (relative to repo root) that are allowed to import pickle or # List of files (relative to repo root) that are allowed to import pickle or
# cloudpickle # cloudpickle
# #
@ -25,7 +15,7 @@ except ImportError:
# Before adding new uses of pickle/cloudpickle, please consider safer # Before adding new uses of pickle/cloudpickle, please consider safer
# alternatives like msgpack or pydantic that are already in use in vLLM. Only # alternatives like msgpack or pydantic that are already in use in vLLM. Only
# add to this list if absolutely necessary and after careful security review. # add to this list if absolutely necessary and after careful security review.
ALLOWED_FILES = set([ ALLOWED_FILES = {
# pickle # pickle
'vllm/v1/serial_utils.py', 'vllm/v1/serial_utils.py',
'vllm/v1/executor/multiproc_executor.py', 'vllm/v1/executor/multiproc_executor.py',
@ -36,11 +26,9 @@ ALLOWED_FILES = set([
'tests/tokenization/test_cached_tokenizer.py', 'tests/tokenization/test_cached_tokenizer.py',
'vllm/distributed/utils.py', 'vllm/distributed/utils.py',
'vllm/distributed/parallel_state.py', 'vllm/distributed/parallel_state.py',
'vllm/engine/multiprocessing/client.py',
'vllm/distributed/device_communicators/all_reduce_utils.py', 'vllm/distributed/device_communicators/all_reduce_utils.py',
'vllm/distributed/device_communicators/shm_broadcast.py', 'vllm/distributed/device_communicators/shm_broadcast.py',
'vllm/distributed/device_communicators/shm_object_storage.py', 'vllm/distributed/device_communicators/shm_object_storage.py',
'vllm/engine/multiprocessing/engine.py',
'benchmarks/kernels/graph_machete_bench.py', 'benchmarks/kernels/graph_machete_bench.py',
'benchmarks/kernels/benchmark_lora.py', 'benchmarks/kernels/benchmark_lora.py',
'benchmarks/kernels/benchmark_machete.py', 'benchmarks/kernels/benchmark_machete.py',
@ -55,65 +43,30 @@ ALLOWED_FILES = set([
'tests/utils.py', 'tests/utils.py',
# pickle and cloudpickle # pickle and cloudpickle
'vllm/utils/__init__.py', 'vllm/utils/__init__.py',
'vllm/v1/serial_utils.py', }
'vllm/v1/executor/multiproc_executor.py',
'vllm/transformers_utils/config.py',
'vllm/model_executor/models/registry.py',
'vllm/engine/multiprocessing/client.py',
'vllm/engine/multiprocessing/engine.py',
])
PICKLE_RE = re.compile(r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)" PICKLE_RE = re.compile(r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
r"|from\s+(pickle|cloudpickle)\s+import\b)") r"|from\s+(pickle|cloudpickle)\s+import\b)")
def is_python_file(path): def scan_file(path: str) -> int:
return path.endswith('.py')
def scan_file(path):
with open(path, encoding='utf-8') as f: with open(path, encoding='utf-8') as f:
for line in f: for i, line in enumerate(f, 1):
if PICKLE_RE.match(line): if PICKLE_RE.match(line):
return True print(f"{path}:{i}: "
return False "\033[91merror:\033[0m " # red color
"Found pickle/cloudpickle import")
return 1
def load_gitignore(repo_root): return 0
gitignore_path = os.path.join(repo_root, '.gitignore')
patterns = []
if os.path.exists(gitignore_path):
with open(gitignore_path, encoding='utf-8') as f:
patterns = f.read().splitlines()
# Always ignore .git directory
patterns.append('.git/')
return pathspec.PathSpec.from_lines('gitwildmatch', patterns)
def main(): def main():
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) returncode = 0
spec = load_gitignore(repo_root) for filename in sys.argv[1:]:
bad_files = [] if filename in ALLOWED_FILES:
for dirpath, _, filenames in os.walk(repo_root): continue
for filename in filenames: returncode |= scan_file(filename)
if not is_python_file(filename): return returncode
continue
abs_path = os.path.join(dirpath, filename)
rel_path = os.path.relpath(abs_path, repo_root)
# Skip ignored files
if spec.match_file(rel_path):
continue
if scan_file(abs_path) and rel_path not in ALLOWED_FILES:
bad_files.append(rel_path)
if bad_files:
print("\nERROR: The following files import 'pickle' or 'cloudpickle' "
"but are not in the allowed list:")
for f in bad_files:
print(f" {f}")
print("\nIf this is intentional, update the allowed list in "
"tools/check_pickle_imports.py.")
sys.exit(1)
sys.exit(0)
def test_regex(): def test_regex():
@ -149,4 +102,4 @@ if __name__ == '__main__':
if '--test-regex' in sys.argv: if '--test-regex' in sys.argv:
test_regex() test_regex()
else: else:
main() sys.exit(main())

140
tools/pre_commit/mypy.py Executable file
View File

@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Run mypy on changed files.
This script is designed to be used as a pre-commit hook. It runs mypy
on files that have been changed. It groups files into different mypy calls
based on their directory to avoid import following issues.
Usage:
python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>
Args:
ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to
"silent" for the main group of files.
python_version: Python version to use (e.g., "3.10") or "local" to use
the local Python version.
changed_files: List of changed files to check.
"""
import subprocess
import sys
from typing import Optional
import regex as re
FILES = [
"vllm/*.py",
"vllm/assets",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging_utils",
"vllm/multimodal",
"vllm/platforms",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
]
# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES
SEPARATE_GROUPS = [
"tests",
"vllm/attention",
"vllm/compilation",
"vllm/distributed",
"vllm/engine",
"vllm/executor",
"vllm/inputs",
"vllm/lora",
"vllm/model_executor",
"vllm/plugins",
"vllm/worker",
"vllm/v1",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
EXCLUDE = [
"vllm/model_executor/parallel_utils",
"vllm/model_executor/models",
"vllm/model_executor/layers/fla/ops",
# Ignore triton kernels in ops.
"vllm/attention/ops",
]
def group_files(changed_files: list[str]) -> dict[str, list[str]]:
"""
Group changed files into different mypy calls.
Args:
changed_files: List of changed files.
Returns:
A dictionary mapping file group names to lists of changed files.
"""
exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*")
files_pattern = re.compile(f"^({'|'.join(FILES)}).*")
file_groups = {"": []}
file_groups.update({k: [] for k in SEPARATE_GROUPS})
for changed_file in changed_files:
# Skip files which should be ignored completely
if exclude_pattern.match(changed_file):
continue
# Group files by mypy call
if files_pattern.match(changed_file):
file_groups[""].append(changed_file)
continue
else:
for directory in SEPARATE_GROUPS:
if re.match(f"^{directory}.*", changed_file):
file_groups[directory].append(changed_file)
break
return file_groups
def mypy(targets: list[str], python_version: Optional[str],
follow_imports: Optional[str], file_group: str) -> int:
"""
Run mypy on the given targets.
Args:
targets: List of files or directories to check.
python_version: Python version to use (e.g., "3.10") or None to use
the default mypy version.
follow_imports: Value for the --follow-imports option or None to use
the default mypy behavior.
file_group: The file group name for logging purposes.
Returns:
The return code from mypy.
"""
args = ["mypy"]
if python_version is not None:
args += ["--python-version", python_version]
if follow_imports is not None:
args += ["--follow-imports", follow_imports]
print(f"$ {' '.join(args)} {file_group}")
return subprocess.run(args + targets, check=False).returncode
def main():
ci = sys.argv[1] == "1"
python_version = sys.argv[2]
file_groups = group_files(sys.argv[3:])
if python_version == "local":
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
returncode = 0
for file_group, changed_files in file_groups.items():
follow_imports = None if ci and file_group == "" else "skip"
if changed_files:
returncode |= mypy(changed_files, python_version, follow_imports,
file_group)
return returncode
if __name__ == "__main__":
sys.exit(main())

View File

@ -10,7 +10,6 @@ from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
import torch import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.multimodal import MultiModalPlaceholderMap
class AttentionType: class AttentionType:
@ -116,15 +115,6 @@ class AttentionMetadata:
# in block 0, and 1st slot in block 1, respectively. # in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
# The index maps that relate multi-modal embeddings to the corresponding
# placeholders.
#
# N.B. These aren't really related to attention and don't belong on this
# type -- this is just a temporary solution to make them available to
# `model_executable`.
multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]]
# Enable/disable KV scales calculation. This is so that we can disable the # Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture. # calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool enable_kv_scales_calculation: bool

View File

@ -1,10 +1,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from itertools import accumulate from itertools import accumulate
from typing import Dict, List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
import torch import torch
@ -12,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder) AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d from vllm.utils import async_tensor_h2d
# Placeholder attention backend for models like Mamba and pooling models that # Placeholder attention backend for models like Mamba and pooling models that
@ -141,8 +139,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation, enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
@ -178,7 +174,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True, enable_kv_scales_calculation=True,
seq_lens=None, seq_lens=None,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
@ -210,9 +205,6 @@ class PlaceholderAttentionMetadataBuilder(
self.prefill_seq_lens: List[int] = [] self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = [] self.context_lens: List[int] = []
self.curr_seq_lens: List[int] = [] self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0 self.num_prefills = 0
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
@ -232,12 +224,6 @@ class PlaceholderAttentionMetadataBuilder(
self.context_lens.append(context_len) self.context_lens.append(context_len)
if is_prompt: if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1 self.num_prefills += 1
self.num_prefill_tokens += token_len self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len) self.prefill_seq_lens.append(seq_len)
@ -295,12 +281,6 @@ class PlaceholderAttentionMetadataBuilder(
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory) device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
# Placeholders # Placeholders
slot_mapping_tensor = torch.empty(0) slot_mapping_tensor = torch.empty(0)
block_tables = torch.empty(0) block_tables = torch.empty(0)
@ -308,7 +288,6 @@ class PlaceholderAttentionMetadataBuilder(
return PlaceholderAttentionMetadata( return PlaceholderAttentionMetadata(
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True, enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils""" """Attention backend utils"""
from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from itertools import accumulate from itertools import accumulate
@ -15,16 +14,10 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
logger = init_logger(__name__) logger = init_logger(__name__)
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
"with encoder/decoder models.")
PAD_SLOT_ID = -1 PAD_SLOT_ID = -1
# Switch to numpy implementation of compute_slot_mapping # Switch to numpy implementation of compute_slot_mapping
@ -135,9 +128,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.context_lens: List[int] = [] self.context_lens: List[int] = []
self.block_tables: List[List[int]] = [] self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = [] self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0 self.num_prefills = 0
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
@ -154,12 +144,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
inter_data.curr_sliding_window_blocks): inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len) self.context_lens.append(context_len)
if is_prompt: if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1 self.num_prefills += 1
self.num_prefill_tokens += token_len self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len) self.prefill_seq_lens.append(seq_len)
@ -254,16 +238,10 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.runner.pin_memory) self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory) device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
return self._metadata_cls( # type: ignore return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True, enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
@ -320,7 +298,6 @@ class CommonAttentionState(AttentionState):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=batch_size, num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size], slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True, enable_kv_scales_calculation=True,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size], seq_lens_tensor=self._graph_seq_lens[:batch_size],

View File

@ -134,6 +134,5 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
cp_attn_lse = cp_attn_lse.contiguous() cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
out = cp_group.reduce_scatter(out, dim=1) out = cp_group.reduce_scatter(out, dim=1)
return out return out

View File

@ -531,18 +531,22 @@ async def benchmark(
extra_body=extra_body, extra_body=extra_body,
) )
test_output = await wait_for_endpoint( if ready_check_timeout_sec > 0:
request_func, test_output = await wait_for_endpoint(
test_input, request_func,
session, test_input,
timeout_seconds=ready_check_timeout_sec, session,
) timeout_seconds=ready_check_timeout_sec,
if not test_output.success: )
raise ValueError( if not test_output.success:
"Initial test run failed - Please make sure benchmark arguments " raise ValueError(
f"are correctly specified. Error: {test_output.error}") "Initial test run failed - Please make sure benchmark "
"arguments are correctly specified. "
f"Error: {test_output.error}")
else:
print("Initial test run completed. Starting main benchmark run...")
else: else:
print("Initial test run completed. Starting main benchmark run...") print("Skipping endpoint ready check.")
if lora_modules: if lora_modules:
# For each input request, choose a LoRA module at random. # For each input request, choose a LoRA module at random.
@ -1151,7 +1155,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
type=int, type=int,
default=600, default=600,
help="Maximum time to wait for the endpoint to become ready " help="Maximum time to wait for the endpoint to become ready "
"in seconds (default: 600 seconds / 10 minutes).", "in seconds (default: 600 seconds / 10 minutes). If set to 0, "
"the ready check will be skipped."
) )

View File

@ -17,7 +17,7 @@ from vllm.platforms import current_platform
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__) logger = init_logger(__name__)
@ -152,7 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmInductorPass): class ActivationQuantFusionPass(VllmPatternMatcherPass):
""" """
This pass fuses a pre-defined set of custom ops into fused ops. This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them. It uses the torch pattern matcher to find the patterns and replace them.
@ -176,16 +176,12 @@ class ActivationQuantFusionPass(VllmInductorPass):
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern() pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns) pattern_silu_mul_nvfp4.register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph):
self.begin() self.matched_count = self.patterns.apply(graph)
self.dump_graph(graph, "before_act_quant_fusion") logger.debug("Replaced %s patterns", self.matched_count)
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
count)
self.dump_graph(graph, "after_act_quant_fusion")
self.end_and_log()
def uuid(self): def uuid(self):
return VllmInductorPass.hash_source(self, ActivationQuantPattern, return VllmInductorPass.hash_source(self, ActivationQuantPattern,

View File

@ -20,7 +20,7 @@ from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
@ -348,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
pm.fwd_only, pm_pass) pm.fwd_only, pm_pass)
class AsyncTPPass(VllmInductorPass): class AsyncTPPass(VllmPatternMatcherPass):
@enable_fake_mode @enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
@ -378,18 +378,17 @@ class AsyncTPPass(VllmInductorPass):
AllGatherCutlassScaledMMPattern( AllGatherCutlassScaledMMPattern(
self.model_dtype, self.device).register(self.patterns) self.model_dtype, self.device).register(self.patterns)
self.dump_patterns(config, self.patterns)
def is_applicable_for_shape(self, shape: Optional[int]) -> bool: def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
# only do replace for specific shapes # only do replace for specific shapes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0 return shape is not None and shape % tp_size == 0
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph):
self.begin() self.matched_count = self.patterns.apply(graph)
self.dump_graph(graph, "before_async_tp_pass") logger.debug("Replaced %s patterns", self.matched_count)
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns with async TP pass.", count)
self.dump_graph(graph, "after_async_tp_pass")
self.end_and_log()
if flashinfer_comm is not None: if flashinfer_comm is not None:
@ -1068,7 +1067,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
pm.fwd_only, pm_pass) pm.fwd_only, pm_pass)
class AllReduceFusionPass(VllmInductorPass): class AllReduceFusionPass(VllmPatternMatcherPass):
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
super().__init__(config) super().__init__(config)
@ -1124,6 +1123,7 @@ class AllReduceFusionPass(VllmInductorPass):
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
self.register_patterns() self.register_patterns()
self.dump_patterns(config, self.patterns)
@enable_fake_mode @enable_fake_mode
def register_patterns(self): def register_patterns(self):
@ -1172,15 +1172,14 @@ class AllReduceFusionPass(VllmInductorPass):
self.disabled = False self.disabled = False
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph):
if self.disabled: if self.disabled:
logger.debug("AllReduceFusionPass disabled")
return return
self.begin()
self.dump_graph(graph, "before_all_reduce_fusion_pass") self.matched_count = self.patterns.apply(graph)
count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", self.matched_count)
logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_all_reduce_fusion_pass")
self.end_and_log()
def __del__(self): def __del__(self):
if getattr(self, "disabled", True): if getattr(self, "disabled", True):

View File

@ -26,6 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass):
To add new nodes to defunctionalize, add to the if-elif chain in __call__. To add new nodes to defunctionalize, add to the if-elif chain in __call__.
""" """
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph):
# XPU does not support auto-functionalization yet. # XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels. # Will enable this when switch to vllm-xpu-kernels.
@ -34,9 +35,6 @@ class FixFunctionalizationPass(VllmInductorPass):
"pass currently.") "pass currently.")
return return
self.begin()
self.dump_graph(graph, "before_fix_functionalization")
self.nodes_to_remove: list[torch.fx.Node] = [] self.nodes_to_remove: list[torch.fx.Node] = []
count = 0 count = 0
for node in graph.nodes: for node in graph.nodes:
@ -111,7 +109,7 @@ class FixFunctionalizationPass(VllmInductorPass):
count += 1 count += 1
self.dump_graph(graph, "before_fix_functionalization_cleanup") self.dump_graph(graph, "before_cleanup")
# Remove the nodes all at once # Remove the nodes all at once
count_removed = len(self.nodes_to_remove) count_removed = len(self.nodes_to_remove)
@ -120,8 +118,7 @@ class FixFunctionalizationPass(VllmInductorPass):
logger.debug("De-functionalized %s nodes, removed %s nodes", count, logger.debug("De-functionalized %s nodes, removed %s nodes", count,
count_removed) count_removed)
self.dump_graph(graph, "after_fix_functionalization") self.nodes_to_remove.clear()
self.end_and_log()
def _remove(self, node_or_nodes: Union[torch.fx.Node, def _remove(self, node_or_nodes: Union[torch.fx.Node,
Iterable[torch.fx.Node]]): Iterable[torch.fx.Node]]):

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, NamedTuple, Optional from typing import Any, NamedTuple
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
@ -16,10 +16,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe
from .inductor_pass import enable_fake_mode from .inductor_pass import enable_fake_mode
from .multi_output_match import MultiOutputMatch from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__) logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
@ -50,8 +48,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
} }
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[ QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
class FusedRMSQuantKey(NamedTuple): class FusedRMSQuantKey(NamedTuple):
@ -80,68 +77,6 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
} }
class QuantMultiOutputMatch(MultiOutputMatch):
def __init__(self, match: pm.Match, quant_op, fused_op):
super().__init__(match)
assert isinstance(quant_op, OpOverload)
assert isinstance(fused_op, OpOverload)
self.QUANT_OP = quant_op # in-place quant op
self.FUSED_OP = fused_op # in-place fused quant op
def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node,
int]],
**kwargs):
"""
This utility function inserts an auto-functionalized node for FUSED_OP.
It also correctly sets its meta value and rebinds the users of the
unfused nodes to use the fused node instead.
:param fused_return_mapping: A dictionary, mapping from getitem indices
of the fused node result to a tuple of the old node and a getitem index.
:param kwargs: kwargs that get directly forwarded to the auto_fn node
Example:
If we want to replace this graph:
_, x1, x2 = auto_fn(op1)
_, y1, y2 = auto_fn(op2)
with
_, x1, y2, x2 = auto_fn(FUSED_OP)
we would call:
insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)}
Note that the 0th element is None for auto-functionalized in-place ops.
Hence, others appear 1-indexed.
"""
fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
indices = fused_return_mapping.keys()
getitem_nodes = self.insert_getitems(fused_node, indices)
# Prepare the meta value, use a list so it's mutable
meta_val = [None] * (max(indices) + 1)
# Iterate through elements of the tuple produced by fused_node
for idx, getitem_node in zip(indices, getitem_nodes):
old_node, old_idx = fused_return_mapping[idx]
# If the old value was never used, the old_getitem might not exist
old_getitem = find_getitem_maybe(old_node, old_idx)
if old_getitem is not None:
# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
old_getitem.replace_all_uses_with(getitem_node)
getitem_node.meta["val"] = old_getitem.meta["val"]
# Extract the appropriate meta value
# It is present even if the getitem node does not exist
meta_val[idx] = old_node.meta["val"][old_idx]
# Fix the meta value on the new fused node
fused_node.meta["val"] = tuple(meta_val)
class RMSNormQuantPattern: class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey): def __init__(self, epsilon: float, key: FusedRMSQuantKey):
@ -224,8 +159,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
symmetric=symmetric)) symmetric=symmetric))
super().__init__(epsilon, key) super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass, def register(self, pm_pass: PatternMatcherPass):
record_match: Callable[[MultiOutputMatch], bool]):
def pattern(result: torch.Tensor, input: torch.Tensor, def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
@ -271,36 +205,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
inputs, inputs,
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
extra_check=lambda m: record_match( )
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 1
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and residual.
# The auto_fn node returns a tuple of (None, result, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
# result_node_new = at[1]
# residual_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
# 0 is always None
fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
self.insert_fused_node(fused_return_mapping,
**kwargs,
epsilon=rms_node.kwargs["epsilon"])
class RMSNormDynamicQuantPattern(RMSNormQuantPattern): class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
@ -317,8 +222,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
symmetric=symmetric)) symmetric=symmetric))
super().__init__(epsilon, key) super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass, def register(self, pm_pass: PatternMatcherPass):
record_match: Callable[[MultiOutputMatch], bool]):
def pattern(result: torch.Tensor, result_rms: torch.Tensor, def pattern(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
@ -366,39 +270,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
inputs, inputs,
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
extra_check=lambda m: record_match( )
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 1
assert len(quant_node.users) == 2
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and scale.
# The auto_fn node returns a tuple of (None, result, scale).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
del kwargs["result_rms"] # not used in the fused op
fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
residual=None, # not used but required
**kwargs)
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
@ -415,8 +287,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
symmetric=symmetric)) symmetric=symmetric))
super().__init__(epsilon, key) super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass, def register(self, pm_pass: PatternMatcherPass):
record_match: Callable[[MultiOutputMatch], bool]):
def pattern(result: torch.Tensor, input: torch.Tensor, def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
@ -464,137 +335,49 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
inputs, inputs,
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
extra_check=lambda m: record_match( )
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 2
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract result, scale, and residual.
# The auto_fn node returns a tuple (None, result, scale, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
# residual_node_new = at[3]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
fused_return_mapping = {
1: (quant_node, 1), # result
2: (quant_node, 2), # scale
3: (rms_node, 2), # residual
}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
**kwargs)
class FusionPass(VllmInductorPass): class RMSNormQuantFusionPass(VllmPatternMatcherPass):
""" """
This pass fuses a pre-defined set of custom ops into fused ops. This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
It uses the torch pattern matcher to find the patterns and replace them. It also supports fused_add_rms_norm.
It also manually processes multi-output matches, as those are broken in
the torch pattern matcher.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
""" """
_instance: 'Optional[FusionPass]' = None
@classmethod
def instance(cls, config: VllmConfig):
"""
Get the singleton instance of the FusionPass.
If the instance exists, the config is updated but
initialization is not repeated.
"""
if cls._instance is None:
cls._instance = FusionPass(config)
else:
cls._instance.pass_config = config.compilation_config.pass_config
return cls._instance
@enable_fake_mode @enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
super().__init__(config) super().__init__(config)
self.matches: list[MultiOutputMatch] = []
self.patterns: PatternMatcherPass = PatternMatcherPass( self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="fusion_pass") pass_name="rmsnorm_quant_fusion_pass")
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant # Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon, RMSNormStaticQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns) FP8_DTYPE).register(self.patterns)
# Matches for patterns below have 2 or more outputs, # Fuse fused_add_rms_norm + static fp8 quant
# so we need to process them manually (see process_matches)
# Fuse rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match) self.patterns)
# Fuse rms_norm + dynamic per-token fp8 quant # Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( RMSNormDynamicQuantPattern(epsilon,
self.patterns, self.record_match) FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match) self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache self.dump_patterns(config, self.patterns)
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
def record_match(self, match: MultiOutputMatch) -> bool:
# Hijack the extra_check to record the match and
# save it for post-processing.
self.matches.append(match)
# Return False to prevent automatic replacement.
return False
def process_matches(self, graph: fx.Graph):
"""
Manually process multi-output matches and replace them with fused nodes.
See MultiOutputMatch for more details.
"""
for match in self.matches:
match.process()
# Finally, remove matched nodes
graph.eliminate_dead_code()
assert all(node not in graph.nodes for match in self.matches
for node in match.match.nodes)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph):
self.begin() self.matched_count = self.patterns.apply(graph)
self.dump_graph(graph, "before_fusion") logger.debug("Replaced %s patterns", self.matched_count)
count = self.patterns.apply(graph) def uuid(self) -> Any:
logger.debug("Replaced %s patterns", count) return self.hash_source(self, RMSNormQuantPattern,
self.dump_graph(graph, "after_pattern_match") RMSNormStaticQuantPattern,
RMSNormDynamicQuantPattern,
# Manually process multi-output matches (and run DCE) FusedAddRMSNormStaticQuantPattern,
self.process_matches(graph) FusedAddRMSNormDynamicQuantPattern)
logger.debug("Post-processed %s matches", len(self.matches))
self.dump_graph(graph, "after_fusion")
self.matches.clear()
self.end_and_log()

View File

@ -18,7 +18,7 @@ from vllm.utils import round_up
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__) logger = init_logger(__name__)
@ -245,7 +245,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
pm_pass) pm_pass)
class AttnFusionPass(VllmInductorPass): class AttnFusionPass(VllmPatternMatcherPass):
""" """
This pass fuses post-attention quantization onto attention if supported. This pass fuses post-attention quantization onto attention if supported.
@ -282,20 +282,12 @@ class AttnFusionPass(VllmInductorPass):
"were found in CompilationConfig.static_forward_context " "were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered.") "so no fusion patterns were registered.")
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.graph.Graph) -> None: def __call__(self, graph: torch.fx.graph.Graph) -> None:
self.begin() self.matched_count = self.patterns.apply(graph)
self.dump_graph(graph, "before_attn_fusion") logger.debug("Fused quant onto %s attention nodes", self.matched_count)
count = self.patterns.apply(graph)
# TODO: Move this to pass_manager.py after the fx graph broken issue
# has been resolved.
# see https://github.com/vllm-project/vllm/issues/23091
graph.eliminate_dead_code()
logger.debug("Fused quantization onto %s attention nodes", count)
self.dump_graph(graph, "after_attn_fusion")
self.end_and_log()
def uuid(self): def uuid(self):
return VllmInductorPass.hash_source(self, AttentionQuantPattern, return VllmInductorPass.hash_source(self, AttentionQuantPattern,

View File

@ -1,109 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import operator
from abc import abstractmethod
from collections.abc import Iterable
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor import pattern_matcher as pm
from torch._ops import OpOverload
from torch.fx import Node
from vllm.compilation.fx_utils import find_auto_fn
class MultiOutputMatch(abc.ABC):
"""
This class provides utilities to process multi-output matches and
manually insert replacements.
This is necessary because the automatic replacement for multi-output
matches is broken: https://github.com/pytorch/pytorch/issues/137280
"""
def __init__(self, match: pm.Match):
self.match = match
@abstractmethod
def process(self):
"""
Process a multi-output match and manually insert the replacement.
This method should:
1. Insert the replacement nodes after the last node in the match.
2. Rebind the users of nodes in the match to use the new nodes.
3. Set meta["val"] for de-functionalization.
The result of an auto-functionalized node is a tuple of tensors.
The first element is the return value of the function, usually None.
The remaining elements are the mutated args of the function.
All auto-functionalized nodes must contain a proper meta["val"],
as it is used by de-functionalization. meta["val"] has to contain the
value of the node (tuple of tensors) that would be returned by the
functionalized node during tracing.
Existing nodes in the graph all have this property set, but we have
to set it manually for new nodes we insert.
Example:
# op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None
at = auto_functionalized(torch.ops._C.foo.default, a, b, c)
# at.meta["val"] = (None, a, c)
"""
raise NotImplementedError
@property
def nodes(self) -> list[fx.Node]:
return self.match.nodes
@property
def graph(self) -> fx.Graph:
return self.match.graph
def find_auto_fn(self, op) -> fx.Node:
"""
Find the first auto_functionalized node with the given op in the match.
"""
return find_auto_fn(self.nodes, op)
def inserting_after_match(self):
"""
Insert nodes after the last node in the match.
This is done to avoid use-before-definition errors after inserting
replacement nodes.
"""
# match.nodes is not guaranteed to be sorted.
# Find the last node in the match.
for last_node_in_match in reversed(self.graph.nodes):
if last_node_in_match in self.match.nodes:
break
else:
raise ValueError("No nodes in graph")
return self.graph.inserting_after(last_node_in_match)
def insert_getitems(self, tuple_node: fx.Node,
indices: Iterable[int]) -> tuple[fx.Node, ...]:
"""
Insert operator.getitem nodes to extract elements from a tuple node.
:param tuple_node: The tuple node to extract elements from.
:param indices: The indices of the elements to extract.
:return: Tuple of the new getitem nodes, corresponding to the indices.
"""
with self.graph.inserting_after(tuple_node):
return tuple(
self.graph.call_function(operator.getitem, (tuple_node, idx))
for idx in indices)
def insert_auto_fn(self, op: OpOverload, kwargs) -> Node:
"""
Insert an auto_functionalized node with the given op and kwargs.
"""
return self.graph.call_function(auto_functionalized, (op, ),
kwargs=kwargs)

View File

@ -64,9 +64,8 @@ class NoOpEliminationPass(VllmInductorPass):
out: "f16[s0, 4096]" = at[1] out: "f16[s0, 4096]" = at[1]
""" """
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_noop_elimination")
count = 0 count = 0
# Remove no-op reshapes/views: # Remove no-op reshapes/views:
for node in graph.nodes: for node in graph.nodes:
@ -121,8 +120,6 @@ class NoOpEliminationPass(VllmInductorPass):
count += 1 count += 1
logger.debug("Removed %s no-op reshapes and slices", count) logger.debug("Removed %s no-op reshapes and slices", count)
self.dump_graph(graph, "after_noop_elimination")
self.end_and_log()
# ---------------------- Reshape helpers ---------------------- # ---------------------- Reshape helpers ----------------------
def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node], def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],

View File

@ -1,15 +1,21 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from torch import fx as fx from torch import fx as fx
from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import set_env_var
from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import FusionPass from .fusion import RMSNormQuantFusionPass
from .fusion_attn import AttnFusionPass from .fusion_attn import AttnFusionPass
if current_platform.is_cuda(): if current_platform.is_cuda():
@ -19,11 +25,28 @@ from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass from .noop_elimination import NoOpEliminationPass
from .sequence_parallelism import SequenceParallelismPass from .sequence_parallelism import SequenceParallelismPass
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__) logger = init_logger(__name__)
def with_pattern_match_debug(fn):
"""
Function decorator that turns on inductor pattern match debug
for the duration of the call.
Used to avoid logging builtin Inductor pattern matching.
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
# optionally check rank here
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
return fn(*args, **kwargs)
return fn(*args, **kwargs)
return wrapper
class PostGradPassManager(CustomGraphPass): class PostGradPassManager(CustomGraphPass):
""" """
The pass manager for post-grad passes. The pass manager for post-grad passes.
@ -40,16 +63,26 @@ class PostGradPassManager(CustomGraphPass):
""" """
def __init__(self): def __init__(self):
self.passes: list[VllmInductorPass] = [] self.passes: list[InductorPass] = []
@with_pattern_match_debug
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph):
VllmInductorPass.dump_prefix = 0 # reset dump index
shape = get_pass_context().runtime_shape shape = get_pass_context().runtime_shape
for pass_ in self.passes: for pass_ in self.passes:
if pass_.is_applicable_for_shape(shape): if pass_.is_applicable_for_shape(shape):
pass_(graph) pass_(graph)
VllmInductorPass.dump_prefix += 1
# post-cleanup goes before fix_functionalization
# because it requires a functional graph
self.post_cleanup(graph)
VllmInductorPass.dump_prefix += 1
# always run fix_functionalization last # always run fix_functionalization last
self.fix_functionalization(graph) self.fix_functionalization(graph)
VllmInductorPass.dump_prefix = None # Cleanup index
def configure(self, config: VllmConfig): def configure(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config self.pass_config = config.compilation_config.pass_config
@ -61,14 +94,18 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_async_tp: if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)] self.passes += [AsyncTPPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.enable_fusion: if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)] self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)] self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_attn_fusion: if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)] self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)] # needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config) self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass): def add(self, pass_: InductorPass):

View File

@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from torch import fx
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
class PostCleanupPass(VllmInductorPass):
"""
This pass performs cleanup after custom passes.
It topologically sorts the graph and removes unused nodes.
This is needed because the pattern matcher does not guarantee producing
a topologically sorted graph, and there may be unused nodes left around.
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
from torch._inductor.pattern_matcher import stable_topological_sort
stable_topological_sort(graph)
graph.eliminate_dead_code()

View File

@ -15,7 +15,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__) logger = init_logger(__name__)
@ -417,7 +417,7 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
pm.fwd_only, pm_pass) pm.fwd_only, pm_pass)
class SequenceParallelismPass(VllmInductorPass): class SequenceParallelismPass(VllmPatternMatcherPass):
""" """
This pass enables sequence parallelism for models. This pass enables sequence parallelism for models.
It identifies patterns where an AllReduce operation is followed by It identifies patterns where an AllReduce operation is followed by
@ -466,19 +466,13 @@ class SequenceParallelismPass(VllmInductorPass):
LastAllReduceRMSNormPattern(epsilon, self.model_dtype, LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns) self.device).register(self.patterns)
self.dump_patterns(config, self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
def is_applicable_for_shape(self, shape: Optional[int]) -> bool: def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0 return shape is not None and shape % tp_size == 0
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph):
self.begin() self.matched_count = self.patterns.apply(graph)
self.dump_graph(graph, "before_sequence_parallelism_pass") logger.debug("Replaced %s patterns", self.matched_count)
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns with sequence parallelism", count)
self.dump_graph(graph, "after_sequence_parallelism_pass")
self.end_and_log()

View File

@ -1,10 +1,16 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import operator
import time import time
from pathlib import Path
from typing import ClassVar, Optional
import regex as re
import torch import torch
from torch._dynamo.utils import lazy_format_graph_code from torch._dynamo.utils import lazy_format_graph_code
from torch._inductor.pattern_matcher import (PatternMatcherPass,
PatternPrettyPrinter)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
@ -19,6 +25,8 @@ class VllmInductorPass(InductorPass):
An inductor pass with access to vLLM PassConfig. An inductor pass with access to vLLM PassConfig.
It provides timing, logging, and dumping utilities. It provides timing, logging, and dumping utilities.
""" """
dump_prefix: ClassVar[Optional[int]] = None
"""Keep track of pass index for debug dump ordering."""
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config self.pass_config = config.compilation_config.pass_config
@ -28,8 +36,24 @@ class VllmInductorPass(InductorPass):
else None else None
self.pass_name = self.__class__.__name__ self.pass_name = self.__class__.__name__
@staticmethod
def time_and_log(call_fn):
@functools.wraps(call_fn)
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before")
call_fn(self, graph)
self.dump_graph(graph, "after")
self.end_and_log()
return wrapped
def dump_graph(self, graph: torch.fx.Graph, stage: str): def dump_graph(self, graph: torch.fx.Graph, stage: str):
lazy_format_graph_code(stage, graph.owning_module) i = VllmInductorPass.dump_prefix
i_str = "" if i is None else f".{i}"
lazy_format_graph_code(f"post_grad{i_str}.{self.pass_name}.{stage}",
graph.owning_module)
def begin(self): def begin(self):
self._start_time = time.perf_counter_ns() self._start_time = time.perf_counter_ns()
@ -40,6 +64,88 @@ class VllmInductorPass(InductorPass):
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
class VllmPatternMatcherPass(VllmInductorPass):
"""
A VllmInductorPass that uses the Inductor pattern matcher.
Its main use is providing the dump_patterns utility that dumps the
Inductor pattern matcher patterns into a file, which greatly aids debugging.
TODO(luka) move more utilities to this pass.
"""
matched_count: int = 0
"""The number of matched patterns in the pass."""
_OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")
def _replace_op_overloads(self, string: str) -> str:
"""Replace <OpOverload(..., ...)> with nicer formulations"""
return self._OP_OVERLOAD_PATTERN.sub(
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
string,
)
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
"""
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
into the debug_dump_path folder next to the dumped fx graphs.
This method does its best to print something that looks like Python code
for easier debugging and potentially navigation. If any errors appear in
the output, please add to this method.
TODO(luka): use pattern object to manually produce pattern graph
"""
debug_dump_path = config.compilation_config.debug_dump_path
if not debug_dump_path:
return
rank = config.parallel_config.rank
debug_dump_path = Path(debug_dump_path) / f"rank_{rank}"
debug_dump_path.mkdir(parents=True, exist_ok=True)
from vllm.utils import unique_filepath
file_path = unique_filepath(
lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py")
with file_path.open("w") as f:
print(
f'# This file was produced by VllmPatternMatcherPass.'
f'dump_patterns for {self.pass_name}.\n'
f'# It does its best to produce valid-Python-looking code but'
f' please add to dump_patterns if there are any errors.\n\n'
f'from torch._higher_order_ops.auto_functionalize import '
f'auto_functionalized as auto_functionalized\n'
f'from torch._inductor.pattern_matcher import *',
file=f)
for node, patterns in pm_pass.patterns.items():
# fix the operator.getitem repr
if node[1] == operator.getitem:
node_repr = f"({repr(node[0])}, operator.getitem)"
else:
node_repr = repr(node)
node_repr = self._replace_op_overloads(node_repr)
print(f"\n\n# Patterns for op: {node_repr}", file=f)
for i, pattern in enumerate(patterns):
# reserve auto_functionalized ahead of time
pp = PatternPrettyPrinter()
pp.namespace.create_name("auto_functionalized", None)
# Assemble pattern
out_node = pp.pretty_print(pattern.pattern)
pattern_repr = "\n".join([f"def pattern_{i}():"] + [
f"{pp.memoized_objs_names[key]} = "
f"{pp.memoized_objs_pp[key]}"
for key in pp.memoized_objs_names
] + [f"return {out_node}"]).replace("\n", "\n ")
pattern_repr = self._replace_op_overloads(pattern_repr)
print(f"{pattern_repr}\n", file=f)
class PrinterInductorPass(VllmInductorPass): class PrinterInductorPass(VllmInductorPass):
def __init__(self, name: str, config: VllmConfig): def __init__(self, name: str, config: VllmConfig):

View File

@ -503,7 +503,7 @@ class VllmConfig:
if self.compilation_config.pass_config.enable_sequence_parallelism: if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm") self.compilation_config.custom_ops.append("+rms_norm")
if current_platform.is_cuda_alike() or current_platform.is_xpu(): if current_platform.support_static_graph_mode():
# if cudagraph_mode is not explicitly set by users, set default # if cudagraph_mode is not explicitly set by users, set default
# value # value
if self.compilation_config.cudagraph_mode is None: if self.compilation_config.cudagraph_mode is None:
@ -905,10 +905,9 @@ def set_current_vllm_config(vllm_config: VllmConfig,
except Exception: except Exception:
raise raise
else: else:
logger.debug("enabled custom ops: %s", if check_compile:
vllm_config.compilation_config.enabled_custom_ops) vllm_config.compilation_config.custom_op_log_check()
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if check_compile and \ if check_compile and \
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen: and compilation_counter.num_models_seen == num_models_seen:

View File

@ -487,6 +487,12 @@ class CompilationConfig:
"supported with torch>=2.9.0.dev. Set " "supported with torch>=2.9.0.dev. Set "
"use_inductor_graph_partition=False instead.") "use_inductor_graph_partition=False instead.")
for op in self.custom_ops:
if op[0] not in {'+', '-'} and op not in {'all', 'none'}:
raise ValueError(f"Invalid syntax '{op}' for custom op, "
"must be 'all', 'none', '+op' or '-op' "
"(where 'op' is the registered op name)")
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION: if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.") raise ValueError("No compilation level is set.")
@ -532,8 +538,8 @@ class CompilationConfig:
for x in self.compile_sizes: for x in self.compile_sizes:
if isinstance(x, str): if isinstance(x, str):
assert x == "cudagraph_capture_sizes", \ assert x == "cudagraph_capture_sizes", \
"Unrecognized size type in compile_sizes, " \ "Unrecognized size type in compile_sizes, " \
f"expect 'cudagraph_capture_sizes', got {x}" f"expect 'cudagraph_capture_sizes', got {x}"
computed_compile_sizes.extend(self.cudagraph_capture_sizes) computed_compile_sizes.extend(self.cudagraph_capture_sizes)
else: else:
assert isinstance(x, int) assert isinstance(x, int)
@ -628,3 +634,41 @@ class CompilationConfig:
return use_fx_graph_piecewise_compilation or \ return use_fx_graph_piecewise_compilation or \
use_inductor_piecewise_compilation use_inductor_piecewise_compilation
def custom_op_log_check(self):
"""
This method logs the enabled/disabled custom ops and checks that the
passed custom_ops field only contains relevant ops.
It is called at the end of set_current_vllm_config,
after the custom ops have been instantiated.
"""
if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0:
logger.debug("No custom ops found in model.")
return
logger.debug("enabled custom ops: %s", self.enabled_custom_ops)
logger.debug("disabled custom ops: %s", self.disabled_custom_ops)
all_ops_in_model = (self.enabled_custom_ops | self.disabled_custom_ops)
for op in self.custom_ops:
if op in {"all", "none"}:
continue
assert op[0] in {'+', '-'}, "Invalid custom op syntax " \
"(should be checked during init)"
# check if op name exists in model
op_name = op[1:]
if op_name not in all_ops_in_model:
from vllm.model_executor.custom_op import CustomOp
# Does op exist at all or is it just not present in this model?
# Note: Only imported op classes appear in the registry.
missing_str = "doesn't exist (or wasn't imported/registered)" \
if op_name not in CustomOp.op_registry \
else "not present in model"
enable_str = "enabling" if op[0] == '+' else "disabling"
logger.warning_once("Op '%s' %s, %s with '%s' has no effect",
op_name, missing_str, enable_str, op)

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib import hashlib
import os
from dataclasses import field from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal, Optional, Union from typing import TYPE_CHECKING, Any, Literal, Optional, Union
@ -351,6 +352,10 @@ class ParallelConfig:
self.world_size = self.pipeline_parallel_size * \ self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size self.tensor_parallel_size
if self.distributed_executor_backend == "external_launcher":
logger.info("Using external launcher for distributed inference.")
self.world_size *= self.data_parallel_size
if self.data_parallel_size_local > self.data_parallel_size: if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError( raise ValueError(
f"data_parallel_size_local ({self.data_parallel_size_local}) " f"data_parallel_size_local ({self.data_parallel_size_local}) "
@ -358,6 +363,13 @@ class ParallelConfig:
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args. # Data parallel was specified in the engine args.
if self.distributed_executor_backend == "external_launcher":
# For external launcher,
# we need to set the data parallel rank automatically
self.data_parallel_rank = int(os.environ["RANK"]) \
// (self.world_size // self.data_parallel_size)
logger.info("Set data_parallel_rank to %d automatically.",
self.data_parallel_rank)
if not self._data_parallel_master_port_list: if not self._data_parallel_master_port_list:
self._data_parallel_master_port_list = get_open_ports_list(5) self._data_parallel_master_port_list = get_open_ports_list(5)
self.data_parallel_master_port = \ self.data_parallel_master_port = \
@ -380,7 +392,6 @@ class ParallelConfig:
"be set when data_parallel_size > 1") "be set when data_parallel_size > 1")
if self.distributed_executor_backend == "external_launcher": if self.distributed_executor_backend == "external_launcher":
import os
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.") logger.info("Disabling V1 multiprocessing for external launcher.")

View File

@ -527,7 +527,7 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got " "speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}") f"{self.disable_by_batch_size=}")
eagle3_target_supported = ["llama", "qwen"] eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
if self.method == "eagle3" and self.target_model_config and not any( if self.method == "eagle3" and self.target_model_config and not any(
supported_model in supported_model in
self.target_model_config.hf_text_config.model_type self.target_model_config.hf_text_config.model_type

View File

@ -25,6 +25,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
super().__init__(cpu_group, device, device_group, unique_name) super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all: if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend != "naive":
logger.warning(
"`%s` all2all manager is not supported on XPU."
"Falling back to `naive` all2all manager for XPU.",
all2all_backend)
all2all_backend = "naive"
if all2all_backend == "naive": if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group) self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
@ -67,3 +73,16 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group) dist.broadcast(input_, src=src, group=self.device_group)
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states

View File

@ -58,6 +58,12 @@ except ImportError:
logger.warning("NIXL is not available") logger.warning("NIXL is not available")
NixlWrapper = None NixlWrapper = None
try:
from nixl._api import nixl_agent_config
except ImportError:
nixl_agent_config = None
logger.warning("NIXL agent config is not available")
# Supported platforms and types of kv transfer buffer. # Supported platforms and types of kv transfer buffer.
# {device: tuple of supported kv buffer types} # {device: tuple of supported kv buffer types}
_NIXL_SUPPORTED_DEVICE = { _NIXL_SUPPORTED_DEVICE = {
@ -65,6 +71,8 @@ _NIXL_SUPPORTED_DEVICE = {
"tpu": ("cpu", ), "tpu": ("cpu", ),
"xpu": ("cpu", ), "xpu": ("cpu", ),
} }
# support for oot platform by providing mapping in current_platform
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
class NixlAgentMetadata( class NixlAgentMetadata(
@ -242,6 +250,10 @@ class NixlConnector(KVConnectorBase_V1):
self.connector_worker.copy_blocks: self.connector_worker.copy_blocks:
self.connector_worker.save_kv_to_host(self._connector_metadata) self.connector_worker.save_kv_to_host(self._connector_metadata)
def shutdown(self):
if self.connector_worker is not None:
self.connector_worker.shutdown()
class NixlConnectorScheduler: class NixlConnectorScheduler:
"""Implementation of Scheduler side methods""" """Implementation of Scheduler side methods"""
@ -448,8 +460,15 @@ class NixlConnectorWorker:
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
self.nixl_backends = \
vllm_config.kv_transfer_config.get_from_extra_config(
"backends", ["UCX"])
# Agent. # Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
config = nixl_agent_config(backends=self.nixl_backends) if len(
non_ucx_backends) > 0 and nixl_agent_config is not None else None
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
@ -486,11 +505,15 @@ class NixlConnectorWorker:
# used when device memory can not be registered under nixl # used when device memory can not be registered under nixl
self.host_xfer_buffers: dict[str, torch.Tensor] = {} self.host_xfer_buffers: dict[str, torch.Tensor] = {}
self.use_host_buffer = self.kv_buffer_device == "cpu" self.use_host_buffer = self.kv_buffer_device == "cpu"
if self.kv_buffer_device == "cuda": # support for oot platform which can't register nixl memory
self.nixl_memory_type = "VRAM" # type based on kv_buffer_device
elif self.kv_buffer_device == "cpu": self.nixl_memory_type = current_platform.get_nixl_memory_type()
self.nixl_memory_type = "DRAM" if self.nixl_memory_type is None:
else: if self.kv_buffer_device == "cuda":
self.nixl_memory_type = "VRAM"
elif self.kv_buffer_device == "cpu":
self.nixl_memory_type = "DRAM"
if self.nixl_memory_type is None:
raise RuntimeError( raise RuntimeError(
f"{self.device_type} with {self.kv_buffer_device} kv_buffer " f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
"is not supported.") "is not supported.")
@ -567,13 +590,6 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats() self.xfer_stats = NixlKVConnectorStats()
def __del__(self):
"""Cleanup background threads on destruction."""
if executor := getattr(self, "_handshake_initiation_executor", None):
executor.shutdown(wait=False)
if listener_t := getattr(self, "_nixl_handshake_listener_t", None):
listener_t.join(timeout=0)
@staticmethod @staticmethod
def _nixl_handshake_listener(metadata: NixlAgentMetadata, def _nixl_handshake_listener(metadata: NixlAgentMetadata,
ready_event: threading.Event, base_port: int, ready_event: threading.Event, base_port: int,
@ -766,7 +782,7 @@ class NixlConnectorWorker:
descs = self.nixl_wrapper.get_reg_descs(caches_data, descs = self.nixl_wrapper.get_reg_descs(caches_data,
self.nixl_memory_type) self.nixl_memory_type)
logger.debug("Registering descs: %s", caches_data) logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs) self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends)
logger.debug("Done registering descs") logger.debug("Done registering descs")
self._registered_descs.append(descs) self._registered_descs.append(descs)
@ -1327,6 +1343,30 @@ class NixlConnectorWorker:
return self.xfer_stats.clone_and_reset() return self.xfer_stats.clone_and_reset()
return None return None
def shutdown(self):
"""Shutdown the connector worker."""
self._handshake_initiation_executor.shutdown(wait=False)
if self._nixl_handshake_listener_t is not None:
self._nixl_handshake_listener_t.join(timeout=0)
self._nixl_handshake_listener_t = None
for handles in self._recving_transfers.values():
for handle, _ in handles:
self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear()
if self.src_xfer_side_handle:
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
self.src_xfer_side_handle = 0
for dst_xfer_side_handle in self.dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values():
self.nixl_wrapper.remove_remote_agent(agent_name)
self._remote_agents.clear()
for desc in self._registered_descs:
self.nixl_wrapper.deregister_memory(desc)
self._registered_descs.clear()
@contextlib.contextmanager @contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:

View File

@ -178,6 +178,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
# Load the KV for each request each layer # Load the KV for each request each layer
for request in metadata.requests: for request in metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank)
for layer_name in forward_context.no_compile_layers: for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name] layer = forward_context.no_compile_layers[layer_name]
@ -191,7 +194,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
layer = kv_cache[forward_context.virtual_engine] layer = kv_cache[forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor( kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name) request.request_id + "#" + layer_name, remote_address)
if kv_cache is None: if kv_cache is None:
logger.warning("🚧kv_cache is None, %s", request.request_id) logger.warning("🚧kv_cache is None, %s", request.request_id)

View File

@ -134,7 +134,6 @@ class P2pNcclEngine:
# PUT or PUT_ASYNC # PUT or PUT_ASYNC
# tensor_id: torch.Tensor # tensor_id: torch.Tensor
self.send_queue: deque[SendQueueItem] = deque() self.send_queue: deque[SendQueueItem] = deque()
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self.send_async, self._send_thread = threading.Thread(target=self.send_async,
daemon=True) daemon=True)
@ -143,6 +142,7 @@ class P2pNcclEngine:
# tensor_id: torch.Tensor/(addr, dtype, shape) # tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {} self.recv_store: dict[str, Any] = {}
self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {} self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.socks: dict[str, Any] = {} # remote_address: client socket self.socks: dict[str, Any] = {} # remote_address: client socket
self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank) self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank)
@ -223,18 +223,26 @@ class P2pNcclEngine:
# GET # GET
with self.send_store_cv: with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel() tensor_size = tensor.element_size() * tensor.numel()
if tensor_size > self.buffer_size_threshold:
logger.warning(
"❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
"buffer size threshold :%d, skip send to %s, rank:%d",
tensor_id, tensor_size, self.buffer_size_threshold,
remote_address, self.rank)
return False
while (self.buffer_size + tensor_size while (self.buffer_size + tensor_size
> self.buffer_size_threshold): > self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store)) assert len(self.send_store) > 0
oldest_tenser = self.send_store.pop(oldest_tenser_id) oldest_tensor_id = next(iter(self.send_store))
oldest_tenser_size = oldest_tenser.element_size( oldest_tensor = self.send_store.pop(oldest_tensor_id)
) * oldest_tenser.numel() oldest_tensor_size = oldest_tensor.element_size(
self.buffer_size -= oldest_tenser_size ) * oldest_tensor.numel()
logger.info( self.buffer_size -= oldest_tensor_size
logger.debug(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d", " buffer_size:%d, oldest_tensor_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, self.buffer_size, remote_address, tensor_id, tensor_size, self.buffer_size,
oldest_tenser_size, self.rank) oldest_tensor_size, self.rank)
self.send_store[tensor_id] = tensor self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size self.buffer_size += tensor_size

View File

@ -1032,7 +1032,9 @@ def init_distributed_environment(world_size: int = -1,
distributed_init_method, backend) distributed_init_method, backend)
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
config = get_current_vllm_config() config = get_current_vllm_config()
if config is not None and config.parallel_config.data_parallel_size > 1: if config is not None and config.parallel_config.data_parallel_size > 1 \
and config.parallel_config.distributed_executor_backend \
!= "external_launcher":
parallel_config = config.parallel_config parallel_config = config.parallel_config
# adjust to take into account data parallelism # adjust to take into account data parallelism
# offset the rank by the data parallel rank # offset the rank by the data parallel rank

View File

@ -1147,20 +1147,15 @@ class EngineArgs:
else: else:
envs.set_vllm_use_v1(use_v1) envs.set_vllm_use_v1(use_v1)
# Set default arguments for V0 or V1 Engine. # Set default arguments for V1 Engine.
if use_v1: self._set_default_args(usage_context, model_config)
self._set_default_args_v1(usage_context, model_config) # Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1
# Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1 if current_platform.is_cpu() and current_platform.get_cpu_architecture(
if current_platform.is_cpu( ) in (CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM):
) and current_platform.get_cpu_architecture() in ( logger.info("Chunked prefill is not supported for ARM and POWER "
CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM): "and S390X CPUs; "
logger.info( "disabling it for V1 backend.")
"Chunked prefill is not supported for ARM and POWER " self.enable_chunked_prefill = False
"and S390X CPUs; "
"disabling it for V1 backend.")
self.enable_chunked_prefill = False
else:
self._set_default_args_v0(model_config)
assert self.enable_chunked_prefill is not None assert self.enable_chunked_prefill is not None
sliding_window: Optional[int] = None sliding_window: Optional[int] = None
@ -1494,6 +1489,7 @@ class EngineArgs:
"FLEX_ATTENTION", "FLEX_ATTENTION",
"TREE_ATTN", "TREE_ATTN",
"XFORMERS_VLLM_V1", "XFORMERS_VLLM_V1",
"ROCM_ATTN_VLLM_V1",
] ]
if (envs.is_set("VLLM_ATTENTION_BACKEND") if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
@ -1501,12 +1497,6 @@ class EngineArgs:
_raise_or_fallback(feature_name=name, recommend_to_remove=True) _raise_or_fallback(feature_name=name, recommend_to_remove=True)
return False return False
# Platforms must decide if they can support v1 for this model
if not current_platform.supports_v1(model_config=model_config):
_raise_or_fallback(
feature_name=f"device type={current_platform.device_type}",
recommend_to_remove=False)
return False
############################################################# #############################################################
# Experimental Features - allow users to opt in. # Experimental Features - allow users to opt in.
@ -1523,12 +1513,6 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# The platform may be supported on V1, but off by default for now.
if not current_platform.default_v1( # noqa: SIM103
model_config=model_config) and _warn_or_fallback(
current_platform.device_name):
return False
if (current_platform.is_cpu() if (current_platform.is_cpu()
and model_config.get_sliding_window() is not None): and model_config.get_sliding_window() is not None):
_raise_or_fallback(feature_name="sliding window (CPU backend)", _raise_or_fallback(feature_name="sliding window (CPU backend)",
@ -1539,64 +1523,8 @@ class EngineArgs:
return True return True
def _set_default_args_v0(self, model_config: ModelConfig) -> None: def _set_default_args(self, usage_context: UsageContext,
"""Set Default Arguments for V0 Engine.""" model_config: ModelConfig) -> None:
max_model_len = model_config.max_model_len
use_long_context = max_model_len > 32768
if self.enable_chunked_prefill is None:
# Chunked prefill not supported for Multimodal or MLA in V0.
if model_config.is_multimodal_model or model_config.use_mla:
self.enable_chunked_prefill = False
# Enable chunked prefill by default for long context (> 32K)
# models to avoid OOM errors in initial memory profiling phase.
elif use_long_context:
is_gpu = current_platform.is_cuda()
use_sliding_window = (model_config.get_sliding_window()
is not None)
use_spec_decode = self.speculative_config is not None
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora):
self.enable_chunked_prefill = True
logger.warning(
"Chunked prefill is enabled by default for models "
"with max_model_len > 32K. Chunked prefill might "
"not work with some features or models. If you "
"encounter any issues, please disable by launching "
"with --enable-chunked-prefill=False.")
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = False
if not self.enable_chunked_prefill and use_long_context:
logger.warning(
"The model has a long context length (%s). This may cause"
"OOM during the initial memory profiling phase, or result "
"in low performance due to small KV cache size. Consider "
"setting --max-model-len to a smaller value.", max_model_len)
# Disable prefix caching for multimodal models for VLLM_V0.
if self.enable_prefix_caching and model_config.is_multimodal_model:
logger.warning(
"--enable-prefix-caching is not supported for multimodal "
"models in V0 and has been disabled.")
self.enable_prefix_caching = False
if self.enable_prompt_embeds:
logger.warning(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V0. Prefix caching has "
"been disabled.")
self.enable_prefix_caching = False
# Set max_num_seqs to 256 for VLLM_V0.
if self.max_num_seqs is None:
self.max_num_seqs = 256
def _set_default_args_v1(self, usage_context: UsageContext,
model_config: ModelConfig) -> None:
"""Set Default Arguments for V1 Engine.""" """Set Default Arguments for V1 Engine."""
# V1 always uses chunked prefills and prefix caching # V1 always uses chunked prefills and prefix caching
@ -1795,21 +1723,6 @@ def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
logger.warning(msg) logger.warning(msg)
def _warn_or_fallback(feature_name: str) -> bool:
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
logger.warning(
"Detected VLLM_USE_V1=1 with %s. Usage should "
"be considered experimental. Please report any "
"issues on Github.", feature_name)
should_exit = False
else:
logger.info(
"%s is experimental on VLLM_USE_V1=1. "
"Falling back to V0 Engine.", feature_name)
should_exit = True
return should_exit
def human_readable_int(value): def human_readable_int(value):
"""Parse human-readable integers like '1k', '2M', etc. """Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers. Including decimal values with decimal multipliers.

View File

@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm.entrypoints.harmony_utils import ( from vllm.entrypoints.harmony_utils import (
@ -21,6 +22,24 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# This is currently needed as the tool type doesn't 1:1 match the
# tool namespace, which is what is used to look up the
# connection to the tool server
_TOOL_NAME_TO_TYPE_MAP = {
"browser": "web_search_preview",
"python": "code_interpreter",
"container": "container",
}
def _map_tool_name_to_tool_type(tool_name: str) -> str:
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys())
raise ValueError(
f"Built-in tool name '{tool_name}' not defined in mapping. "
f"Available tools: {available_tools}")
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
class TurnTokens: class TurnTokens:
"""Tracks token counts for a single conversation turn.""" """Tracks token counts for a single conversation turn."""
@ -59,8 +78,8 @@ class ConversationContext(ABC):
@abstractmethod @abstractmethod
async def init_tool_sessions(self, tool_server: Optional[ToolServer], async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack, exit_stack: AsyncExitStack, request_id: str,
request_id: str) -> None: mcp_tools: dict[str, Mcp]) -> None:
pass pass
@abstractmethod @abstractmethod
@ -96,8 +115,8 @@ class SimpleContext(ConversationContext):
raise NotImplementedError("Should not be called.") raise NotImplementedError("Should not be called.")
async def init_tool_sessions(self, tool_server: Optional[ToolServer], async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack, exit_stack: AsyncExitStack, request_id: str,
request_id: str) -> None: mcp_tools: dict[str, Mcp]) -> None:
pass pass
async def cleanup_session(self) -> None: async def cleanup_session(self) -> None:
@ -318,13 +337,17 @@ class HarmonyContext(ConversationContext):
] ]
async def init_tool_sessions(self, tool_server: Optional[ToolServer], async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack, exit_stack: AsyncExitStack, request_id: str,
request_id: str) -> None: mcp_tools: dict[str, Mcp]):
if tool_server: if tool_server:
for tool_name in self.available_tools: for tool_name in self.available_tools:
if tool_name not in self._tool_sessions: if tool_name not in self._tool_sessions:
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = mcp_tools[
tool_type].headers if tool_type in mcp_tools else None
tool_session = await exit_stack.enter_async_context( tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id)) tool_server.new_session(tool_name, request_id,
headers))
self._tool_sessions[tool_name] = tool_session self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session) exit_stack.push_async_exit(self.cleanup_session)

View File

@ -126,8 +126,10 @@ def get_developer_message(
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = [] function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
for tool in tools: for tool in tools:
if tool.type in ("web_search_preview", "code_interpreter", if tool.type in ("web_search_preview", "code_interpreter",
"container"): "container", "mcp"):
# These are built-in tools that are added to the system message. # These are built-in tools that are added to the system message.
# Adding in MCP for now until we support MCP tools executed
# server side
pass pass
elif tool.type == "function": elif tool.type == "function":

View File

@ -1468,7 +1468,7 @@ class LLM:
def _validate_and_add_requests( def _validate_and_add_requests(
self, self,
prompts: Union[PromptType, Sequence[PromptType]], prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]], Sequence[PoolingParams]],
*, *,
@ -1478,7 +1478,7 @@ class LLM:
) -> None: ) -> None:
if isinstance(prompts, (str, dict)): if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list. # Convert a single prompt to a list.
prompts = [prompts] prompts = [prompts] # type: ignore[list-item]
num_requests = len(prompts) num_requests = len(prompts)
if isinstance(params, Sequence) and len(params) != num_requests: if isinstance(params, Sequence) and len(params) != num_requests:

View File

@ -460,8 +460,12 @@ class OpenAIServingResponses(OpenAIServing):
async with AsyncExitStack() as exit_stack: async with AsyncExitStack() as exit_stack:
try: try:
mcp_tools = {
tool.server_label: tool
for tool in request.tools if tool.type == "mcp"
}
await context.init_tool_sessions(self.tool_server, exit_stack, await context.init_tool_sessions(self.tool_server, exit_stack,
request.request_id) request.request_id, mcp_tools)
async for _ in result_generator: async for _ in result_generator:
pass pass
except asyncio.CancelledError: except asyncio.CancelledError:
@ -748,11 +752,16 @@ class OpenAIServingResponses(OpenAIServing):
# New conversation. # New conversation.
reasoning_effort = (request.reasoning.effort reasoning_effort = (request.reasoning.effort
if request.reasoning else None) if request.reasoning else None)
# Temporary: OpenAI types doesn't have container tool
# so we used MCP to cover that, up for change
tool_types = [tool.type for tool in request.tools] tool_types = [tool.type for tool in request.tools]
if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
tool_types.append("container") # Allow the MCP Tool type to enable built in tools if the
# server_label is allowlisted in
# envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS
if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
for tool in request.tools:
if (tool.type == "mcp" and tool.server_label
in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS):
tool_types.append(tool.server_label)
enable_browser = ("web_search_preview" in tool_types enable_browser = ("web_search_preview" in tool_types
and self.tool_server is not None and self.tool_server is not None
and self.tool_server.has_tool("browser")) and self.tool_server.has_tool("browser"))
@ -1653,8 +1662,12 @@ class OpenAIServingResponses(OpenAIServing):
async with AsyncExitStack() as exit_stack: async with AsyncExitStack() as exit_stack:
processer = None processer = None
if self.use_harmony: if self.use_harmony:
mcp_tools = {
tool.server_label: tool
for tool in request.tools if tool.type == "mcp"
}
await context.init_tool_sessions(self.tool_server, exit_stack, await context.init_tool_sessions(self.tool_server, exit_stack,
request.request_id) request.request_id, mcp_tools)
processer = self._process_harmony_streaming_events processer = self._process_harmony_streaming_events
else: else:
processer = self._process_simple_streaming_events processer = self._process_simple_streaming_events

View File

@ -20,6 +20,7 @@ from .openai_tool_parser import OpenAIToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser from .qwen3coder_tool_parser import Qwen3CoderToolParser
from .qwen3xml_tool_parser import Qwen3XMLToolParser
from .seed_oss_tool_parser import SeedOssToolParser from .seed_oss_tool_parser import SeedOssToolParser
from .step3_tool_parser import Step3ToolParser from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser from .xlam_tool_parser import xLAMToolParser
@ -45,6 +46,7 @@ __all__ = [
"HunyuanA13BToolParser", "HunyuanA13BToolParser",
"Glm4MoeModelToolParser", "Glm4MoeModelToolParser",
"Qwen3CoderToolParser", "Qwen3CoderToolParser",
"Qwen3XMLToolParser",
"SeedOssToolParser", "SeedOssToolParser",
"Step3ToolParser", "Step3ToolParser",
"OpenAIToolParser", "OpenAIToolParser",

View File

@ -368,16 +368,32 @@ class Hermes2ProToolParser(ToolParser):
# case -- we now have the first info about arguments available from # case -- we now have the first info about arguments available from
# autocompleting the JSON # autocompleting the JSON
elif cur_arguments and not prev_arguments: elif cur_arguments and not prev_arguments:
# extract the content after {"name": ..., "arguments":
# directly from tool_call_portion as cur_arguments_json,
# since cur_arguments may differ from the original text
# due to partial JSON parsing
# for example, tool_call_portion =
# {"name": "search", "arguments": {"search_request": {"
# but cur_arguments =
# {"search_request": {}}
function_name = current_tool_call.get("name")
match = re.search(
r'\{"name":\s*"' +
re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
tool_call_portion.strip(), re.DOTALL)
if match:
cur_arguments_json = match.group(1)
else:
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
logger.debug("finding %s in %s", delta_text, logger.debug("finding %s in %s", delta_text,
cur_arguments_json) cur_arguments_json)
# get the location where previous args differ from current # get the location where previous args differ from current.
if (delta_text not in cur_arguments_json[:-2]): if (delta_text not in cur_arguments_json):
return None return None
args_delta_start_loc = cur_arguments_json[:-2]. \ args_delta_start_loc = cur_arguments_json. \
rindex(delta_text) + \ rindex(delta_text) + \
len(delta_text) len(delta_text)
@ -397,8 +413,20 @@ class Hermes2ProToolParser(ToolParser):
# last case -- we have an update to existing arguments. # last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments: elif cur_arguments and prev_arguments:
if isinstance(delta_text, str) and len(delta_text.rstrip( # judge whether the tool_call_portion is a complete JSON
)) >= 1 and delta_text.rstrip()[-1] == '}': try:
json.loads(tool_call_portion)
is_complete_json = True
except Exception:
is_complete_json = False
# if the delta_text ends with a '}' and tool_call_portion is a
# complete JSON, then the last '}' does not belong to the
# arguments, so we should trim it off
if isinstance(delta_text, str) \
and len(delta_text.rstrip()) >= 1 \
and delta_text.rstrip()[-1] == '}' \
and is_complete_json:
delta_text = delta_text.rstrip()[:-1] delta_text = delta_text.rstrip()[:-1]
logger.debug("got diff %s", delta_text) logger.debug("got diff %s", delta_text)

File diff suppressed because it is too large Load Diff

View File

@ -280,7 +280,7 @@ class CompletionRenderer(BaseRenderer):
if truncate_prompt_tokens < 0: if truncate_prompt_tokens < 0:
truncate_prompt_tokens = self.model_config.max_model_len truncate_prompt_tokens = self.model_config.max_model_len
if max_length is not None and truncate_prompt_tokens > max_length: if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator]
raise ValueError( raise ValueError(
f"truncate_prompt_tokens ({truncate_prompt_tokens}) " f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
f"cannot be greater than max_length ({max_length}). " f"cannot be greater than max_length ({max_length}). "

View File

@ -18,7 +18,6 @@ if TYPE_CHECKING:
async def list_server_and_tools(server_url: str): async def list_server_and_tools(server_url: str):
from mcp import ClientSession from mcp import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
async with sse_client(url=server_url) as streams, ClientSession( async with sse_client(url=server_url) as streams, ClientSession(
*streams) as session: *streams) as session:
initialize_response = await session.initialize() initialize_response = await session.initialize()
@ -86,8 +85,12 @@ class ToolServer(ABC):
pass pass
@abstractmethod @abstractmethod
def new_session(self, tool_name: str, def new_session(
session_id: str) -> AbstractAsyncContextManager[Any]: self,
tool_name: str,
session_id: str,
headers: Optional[dict[str, str]] = None
) -> AbstractAsyncContextManager[Any]:
""" """
Create a session for the tool. Create a session for the tool.
""" """
@ -144,16 +147,21 @@ class MCPToolServer(ToolServer):
return self.harmony_tool_descriptions.get(tool_name) return self.harmony_tool_descriptions.get(tool_name)
@asynccontextmanager @asynccontextmanager
async def new_session(self, tool_name: str, session_id: str): async def new_session(self,
tool_name: str,
session_id: str,
headers: Optional[dict[str, str]] = None):
from mcp import ClientSession from mcp import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
url = self.urls.get(tool_name) url = self.urls.get(tool_name)
headers = {"x-session-id": session_id} request_headers = {"x-session-id": session_id}
if headers is not None:
request_headers.update(headers)
if not url: if not url:
raise KeyError(f"Tool '{tool_name}' is not supported") raise KeyError(f"Tool '{tool_name}' is not supported")
async with sse_client(url=url, async with sse_client(
headers=headers) as streams, ClientSession( url=url, headers=request_headers) as streams, ClientSession(
*streams) as session: *streams) as session:
await session.initialize() await session.initialize()
yield session yield session
@ -189,7 +197,10 @@ class DemoToolServer(ToolServer):
raise ValueError(f"Unknown tool {tool_name}") raise ValueError(f"Unknown tool {tool_name}")
@asynccontextmanager @asynccontextmanager
async def new_session(self, tool_name: str, session_id: str): async def new_session(self,
tool_name: str,
session_id: str,
headers: Optional[dict[str, str]] = None):
if tool_name not in self.tools: if tool_name not in self.tools:
raise KeyError(f"Tool '{tool_name}' is not supported") raise KeyError(f"Tool '{tool_name}' is not supported")
yield self.tools[tool_name] yield self.tools[tool_name]

View File

@ -119,12 +119,14 @@ if TYPE_CHECKING:
VLLM_SERVER_DEV_MODE: bool = False VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
VLLM_MLA_DISABLE: bool = False VLLM_MLA_DISABLE: bool = False
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_PER_WORKER_GPUS: float = 1.0
VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_CUDART_SO_PATH: Optional[str] = None VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_DP_RANK: int = 0 VLLM_DP_RANK: int = 0
VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1 VLLM_DP_SIZE: int = 1
VLLM_USE_STANDALONE_COMPILE: bool = False
VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0 VLLM_DP_MASTER_PORT: int = 0
VLLM_MOE_DP_CHUNK_SIZE: int = 256 VLLM_MOE_DP_CHUNK_SIZE: int = 256
@ -183,11 +185,12 @@ if TYPE_CHECKING:
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
def get_default_cache_root(): def get_default_cache_root():
@ -258,6 +261,58 @@ def env_with_choices(
return _get_validated_env return _get_validated_env
def env_list_with_choices(
env_name: str,
default: list[str],
choices: Union[list[str], Callable[[], list[str]]],
case_sensitive: bool = True) -> Callable[[], list[str]]:
"""
Create a lambda that validates environment variable
containing comma-separated values against allowed choices
Args:
env_name: Name of the environment variable
default: Default list of values if not set
choices: List of valid string options or callable that returns list
case_sensitive: Whether validation should be case sensitive
Returns:
Lambda function for environment_variables
dict that returns list of strings
"""
def _get_validated_env_list() -> list[str]:
value = os.getenv(env_name)
if value is None:
return default
# Split comma-separated values and strip whitespace
values = [v.strip() for v in value.split(",") if v.strip()]
if not values:
return default
# Resolve choices if it's a callable (for lazy loading)
actual_choices = choices() if callable(choices) else choices
# Validate each value
for val in values:
if not case_sensitive:
check_value = val.lower()
check_choices = [choice.lower() for choice in actual_choices]
else:
check_value = val
check_choices = actual_choices
if check_value not in check_choices:
raise ValueError(f"Invalid value '{val}' in {env_name}. "
f"Valid options: {actual_choices}.")
return values
return _get_validated_env_list
def get_vllm_port() -> Optional[int]: def get_vllm_port() -> Optional[int]:
"""Get the port from VLLM_PORT environment variable. """Get the port from VLLM_PORT environment variable.
@ -436,9 +491,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Feature flag to enable/disable Inductor standalone compile. # Feature flag to enable/disable Inductor standalone compile.
# In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is # In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is
# enabled by default. # disabled by default.
"VLLM_USE_STANDALONE_COMPILE": "VLLM_USE_STANDALONE_COMPILE":
lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1", lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "0") == "1",
# Debug pattern matching inside custom passes.
# Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3').
"VLLM_PATTERN_MATCH_DEBUG":
lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id
@ -946,6 +1006,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MLA_DISABLE": "VLLM_MLA_DISABLE":
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
# If set, vLLM will pick up the provided Flash Attention MLA
# max number splits for cuda graph decode
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
"16")),
# Number of GPUs per worker in Ray, if it is set to be a fraction, # Number of GPUs per worker in Ray, if it is set to be a fraction,
# it allows ray to schedule multiple actors on a single GPU, # it allows ray to schedule multiple actors on a single GPU,
# so that users can colocate other actors on the same GPUs as vLLM. # so that users can colocate other actors on the same GPUs as vLLM.
@ -1306,10 +1372,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TUNED_CONFIG_FOLDER": "VLLM_TUNED_CONFIG_FOLDER":
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
# Allows vllm use container tool
"VLLM_GPT_OSS_USE_CONTAINER_TOOL":
lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))),
# Allows harmony instructions to be injected on system messages # Allows harmony instructions to be injected on system messages
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
lambda: bool( lambda: bool(
@ -1329,6 +1391,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME":
lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
"VLLM_OBJECT_STORAGE_SHM_BUFFER"), "VLLM_OBJECT_STORAGE_SHM_BUFFER"),
# Valid values are container,code_interpreter,web_search_preview
# ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
"GPT_OSS_SYSTEM_TOOL_MCP_LABELS":
env_list_with_choices("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", [],
["container",
"code_interpreter",
"web_search_preview"]),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
@ -1379,6 +1449,7 @@ def compute_hash() -> str:
environment_variables_to_hash = [ environment_variables_to_hash = [
"VLLM_PP_LAYER_PARTITION", "VLLM_PP_LAYER_PARTITION",
"VLLM_MLA_DISABLE", "VLLM_MLA_DISABLE",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
"VLLM_USE_TRITON_FLASH_ATTN", "VLLM_USE_TRITON_FLASH_ATTN",
"VLLM_USE_TRITON_AWQ", "VLLM_USE_TRITON_AWQ",
"VLLM_DP_RANK", "VLLM_DP_RANK",

View File

@ -121,18 +121,18 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
lora_bias = self.slice_bias(lora_bias) lora_bias = self.slice_bias(lora_bias)
self.lora_a_stacked[0][index, self.lora_a_stacked[0][index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_( 0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a.T, non_blocking=True) lora_a, non_blocking=True)
self.lora_b_stacked[0][index, self.lora_b_stacked[0][index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_( 0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b.T, non_blocking=True) lora_b, non_blocking=True)
if lora_bias is not None: if lora_bias is not None:
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked) self.lora_bias_stacked)
assert len(self.lora_bias_stacked) assert len(self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
lora_bias.T, non_blocking=True) lora_bias, non_blocking=True)
def apply(self, def apply(self,
x: torch.Tensor, x: torch.Tensor,

View File

@ -99,13 +99,13 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
if self.is_merged_col_linear: if self.is_merged_col_linear:
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size // 2 shard_size = self.output_size // 2
offset = lora_b.shape[-1] // 2 offset = lora_b.shape[0] // 2
left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
shard_size] shard_size, :]
right_weight = lora_b[:, offset + tp_rank * shard_size:offset + right_weight = lora_b[offset + tp_rank * shard_size:offset +
(tp_rank + 1) * shard_size] (tp_rank + 1) * shard_size, :]
lora_b = torch.cat([left_weight, right_weight], dim=1) lora_b = torch.cat([left_weight, right_weight], dim=0)
# Applicable to cases where the base_layer is # Applicable to cases where the base_layer is
# ColumnParallelLinear. # ColumnParallelLinear.
else: else:
@ -113,7 +113,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
shard_size = self.output_size shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx] lora_b = lora_b[start_idx:end_idx, :]
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
@ -251,9 +251,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
for i, (shard_id, shard_size) in enumerate( for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)): zip(self.output_ids, self.output_slices)):
if (lora_b_i := lora_b[i]) is not None: if (lora_b_i := lora_b[i]) is not None:
sliced_lora_b[i] = lora_b_i[:, sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size *
shard_size * shard_id:shard_size * (shard_id + 1), :]
(shard_id + 1)]
return sliced_lora_b return sliced_lora_b
def slice_bias( def slice_bias(
@ -285,12 +284,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
for i in range(self.n_slices): for i in range(self.n_slices):
if (lora_a_i := lora_a[i]) is not None: if (lora_a_i := lora_a[i]) is not None:
self.lora_a_stacked[i][ self.lora_a_stacked[i][
index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( index, 0, :lora_a_i.shape[0], :lora_a_i.shape[1]].copy_(
lora_a_i.T, non_blocking=True) lora_a_i, non_blocking=True)
if (lora_b_i := lora_b[i]) is not None: if (lora_b_i := lora_b[i]) is not None:
self.lora_b_stacked[i][ self.lora_b_stacked[i][
index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_(
lora_b_i.T, non_blocking=True) lora_b_i, non_blocking=True)
if lora_bias is not None: if lora_bias is not None:
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
@ -299,7 +298,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
if (lora_bias_i := lora_bias[i]) is not None: if (lora_bias_i := lora_bias[i]) is not None:
self.lora_bias_stacked[i][index, self.lora_bias_stacked[i][index,
0, :lora_bias_i.shape[0]].copy_( 0, :lora_bias_i.shape[0]].copy_(
lora_bias_i.T, lora_bias_i,
non_blocking=True) non_blocking=True)
@classmethod @classmethod
@ -345,18 +344,18 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size * lora_b_q = lora_b[self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)] (self.q_shard_id + 1), :]
k_offset = self.q_proj_total_size k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset + lora_b_k = lora_b[k_offset +
self.kv_proj_shard_size * self.kv_shard_id:k_offset + self.kv_proj_shard_size * self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)] self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
v_offset = k_offset + self.kv_proj_total_size v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset + lora_b_v = lora_b[v_offset +
self.kv_proj_shard_size * self.kv_shard_id:v_offset + self.kv_proj_shard_size * self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)] self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
@ -465,7 +464,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2] shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size] lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a return lora_a
def apply(self, def apply(self,
@ -508,10 +507,10 @@ class MergedColumnParallelLinearWithShardedLoRA(
output_shard_size = self.lora_a_stacked[0].shape[2] output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size output_start_idx = self.tp_rank * output_shard_size
lora_a = [ lora_a = [
lora_a[0][:, output_start_idx:output_start_idx + lora_a[0][output_start_idx:output_start_idx +
output_shard_size] if lora_a[0] is not None else None, output_shard_size, :] if lora_a[0] is not None else None,
lora_a[1][:, output_start_idx:output_start_idx + lora_a[1][output_start_idx:output_start_idx +
output_shard_size] if lora_a[1] is not None else None, output_shard_size, :] if lora_a[1] is not None else None,
] ]
return lora_a return lora_a
@ -551,7 +550,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2] shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size] lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a return lora_a
def apply(self, def apply(self,
@ -589,12 +588,12 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)] start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [ lora_a = [
lora_a[0][:, start_idx[0]:start_idx[0] + lora_a[0][start_idx[0]:start_idx[0] +
shard_size[0]] if lora_a[0] is not None else None, shard_size[0], :] if lora_a[0] is not None else None,
lora_a[1][:, start_idx[1]:start_idx[1] + lora_a[1][start_idx[1]:start_idx[1] +
shard_size[1]] if lora_a[1] is not None else None, shard_size[1], :] if lora_a[1] is not None else None,
lora_a[2][:, start_idx[2]:start_idx[2] + lora_a[2][start_idx[2]:start_idx[2] +
shard_size[2]] if lora_a[2] is not None else None, shard_size[2], :] if lora_a[2] is not None else None,
] ]
return lora_a return lora_a

View File

@ -140,11 +140,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
): ):
self.reset_lora(index) self.reset_lora(index)
self.lora_a_stacked[index, self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_( 0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a.T, non_blocking=True) lora_a, non_blocking=True)
self.lora_b_stacked[index, self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_( 0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b.T, non_blocking=True) lora_b, non_blocking=True)
if embeddings_tensor is not None: if embeddings_tensor is not None:
self.embeddings_tensors[ self.embeddings_tensors[
index, index,

View File

@ -39,7 +39,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
shard_size = self.input_size shard_size = self.input_size
start_idx = self.tp_rank * shard_size start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size end_idx = (self.tp_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :] lora_a = lora_a[:,start_idx:end_idx]
return lora_a return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
@ -122,7 +122,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
shard_size = self.lora_b_stacked[0].shape[2] shard_size = self.lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx] lora_b = lora_b[ start_idx:end_idx,:]
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:

View File

@ -95,11 +95,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
lora_a, non_blocking=True) # so we need transpose here
self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index, self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_( 0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b.T, non_blocking=True) lora_b, non_blocking=True)
if embeddings_tensor is not None: if embeddings_tensor is not None:
self.embeddings_tensors[ self.embeddings_tensors[
index, index,

View File

@ -86,11 +86,11 @@ class LoRALayerWeights:
embeddings_tensor_dim: Optional[int] = None, embeddings_tensor_dim: Optional[int] = None,
bias_enabled: Optional[bool] = False) -> "LoRALayerWeights": bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank], lora_a = torch.zeros([rank, input_dim],
dtype=dtype, dtype=dtype,
device=device, device=device,
pin_memory=pin_memory) pin_memory=pin_memory)
lora_b = torch.zeros([rank, output_dim], lora_b = torch.zeros([output_dim, rank],
dtype=dtype, dtype=dtype,
device=device, device=device,
pin_memory=pin_memory) pin_memory=pin_memory)

View File

@ -152,30 +152,29 @@ class LoRAModel:
module_name, peft_helper, lora_embeddings_tensor) module_name, peft_helper, lora_embeddings_tensor)
if is_bias: if is_bias:
loras[module_name].bias = tensor.to(device=device, loras[module_name].bias = tensor.to(device=device, dtype=dtype)
dtype=dtype).t() bias = tensor.to(device=device, dtype=dtype)
bias = tensor.to(device=device, dtype=dtype).t()
if pin_memory: if pin_memory:
bias = bias.pin_memory() bias = bias.pin_memory()
loras[module_name].bias = bias loras[module_name].bias = bias
elif is_lora_a: elif is_lora_a:
loras[module_name].lora_a = tensor.to(device=device, loras[module_name].lora_a = tensor.to(device=device,
dtype=dtype).t() dtype=dtype)
if pin_memory: if pin_memory:
loras[module_name].lora_a = loras[ loras[module_name].lora_a = loras[
module_name].lora_a.pin_memory() module_name].lora_a.pin_memory()
else: else:
loras[module_name].lora_b = tensor.to(device=device, loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t() dtype=dtype)
assert embedding_padding_modules is not None assert embedding_padding_modules is not None
if any(name in module_name if any(name in module_name
for name in embedding_padding_modules for name in embedding_padding_modules
) and target_embedding_padding is not None: ) and target_embedding_padding is not None:
lora_b = loras[module_name].lora_b lora_b = loras[module_name].lora_b
assert target_embedding_padding >= lora_b.shape[1] assert target_embedding_padding >= lora_b.shape[0]
addition = target_embedding_padding - lora_b.shape[1] addition = target_embedding_padding - lora_b.shape[0]
loras[module_name].lora_b = torch.nn.functional.pad( loras[module_name].lora_b = torch.nn.functional.pad(
lora_b, (0, addition)) lora_b, (0, 0, 0, addition))
if pin_memory: if pin_memory:
loras[module_name].lora_b = loras[ loras[module_name].lora_b = loras[
module_name].lora_b.pin_memory() module_name].lora_b.pin_memory()
@ -585,7 +584,6 @@ class LoRAModelManager:
"cpu", "cpu",
bias_enabled=bias_enabled, bias_enabled=bias_enabled,
) )
lora.optimize()
else: else:
parts = module_name.split(".") parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]] replacements = self.packed_modules_mapping[parts[-1]]
@ -600,7 +598,6 @@ class LoRAModelManager:
"cpu", "cpu",
bias_enabled=bias_enabled, bias_enabled=bias_enabled,
) )
lora.optimize()
subloras.append(lora) subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras) lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora model.loras[module_name] = lora

Some files were not shown because too many files have changed in this diff Show More