mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 01:27:03 +08:00
Merge branch 'main' into woosuk/model-runner-v2
This commit is contained in:
commit
17c2c106b1
@ -62,7 +62,7 @@ echo "--- Installing Python dependencies ---"
|
||||
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
|
||||
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
|
||||
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
|
||||
&& python3 -m pip install --progress-bar off hf-transfer
|
||||
&& python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
|
||||
echo "--- Python dependencies installed ---"
|
||||
export VLLM_USE_V1=1
|
||||
export VLLM_XLA_CHECK_RECOMPILATION=1
|
||||
|
||||
@ -62,7 +62,7 @@ echo "--- Installing Python dependencies ---"
|
||||
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
|
||||
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
|
||||
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
|
||||
&& python3 -m pip install --progress-bar off hf-transfer
|
||||
&& python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
|
||||
echo "--- Python dependencies installed ---"
|
||||
export VLLM_USE_V1=1
|
||||
export VLLM_XLA_CHECK_RECOMPILATION=1
|
||||
|
||||
@ -165,10 +165,18 @@ steps:
|
||||
- tests/v1/test_hybrid_lb_dp.py
|
||||
- tests/v1/engine/test_engine_core_client.py
|
||||
commands:
|
||||
# test with tp=2 and external_dp=2
|
||||
# test with torchrun tp=2 and external_dp=2
|
||||
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with tp=2 and pp=2
|
||||
# test with torchrun tp=2 and pp=2
|
||||
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with torchrun tp=4 and dp=1
|
||||
- TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||
# test with torchrun tp=2, pp=2 and dp=1
|
||||
- PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||
# test with torchrun tp=1 and dp=4 with ep
|
||||
- DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||
# test with torchrun tp=2 and dp=2 with ep
|
||||
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
|
||||
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@ -72,6 +72,7 @@ mkdocs.yaml @hmellor
|
||||
# Linting
|
||||
.markdownlint.yaml @hmellor
|
||||
.pre-commit-config.yaml @hmellor
|
||||
/tools/pre_commit @hmellor
|
||||
|
||||
# CPU
|
||||
/vllm/v1/worker/cpu* @bigPYJ1151
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/750-RFC.yml
vendored
4
.github/ISSUE_TEMPLATE/750-RFC.yml
vendored
@ -43,10 +43,6 @@ body:
|
||||
Any other things you would like to mention.
|
||||
validations:
|
||||
required: false
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉! The vLLM core team hosts a biweekly RFC review session at 9:30AM Pacific Time, while most RFCs can be discussed online, you can optionally sign up for a slot to discuss your RFC online [here](https://docs.google.com/document/d/1CiLVBZeIVfR7_PNAKVSusxpceywkoOOB78qoWqHvSZc/edit).
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
|
||||
@ -60,38 +60,32 @@ repos:
|
||||
files: ^requirements/test\.(in|txt)$
|
||||
- id: mypy-local
|
||||
name: Run mypy for local Python installation
|
||||
entry: tools/mypy.sh 0 "local"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic]
|
||||
entry: python tools/pre_commit/mypy.py 0 "local"
|
||||
stages: [pre-commit] # Don't run in CI
|
||||
<<: &mypy_common
|
||||
language: python
|
||||
types_or: [python, pyi]
|
||||
require_serial: true
|
||||
additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
|
||||
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.9
|
||||
entry: tools/mypy.sh 1 "3.9"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
entry: python tools/pre_commit/mypy.py 1 "3.9"
|
||||
<<: *mypy_common
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.10
|
||||
entry: tools/mypy.sh 1 "3.10"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
entry: python tools/pre_commit/mypy.py 1 "3.10"
|
||||
<<: *mypy_common
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.11
|
||||
entry: tools/mypy.sh 1 "3.11"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
entry: python tools/pre_commit/mypy.py 1 "3.11"
|
||||
<<: *mypy_common
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.12
|
||||
entry: tools/mypy.sh 1 "3.12"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
entry: python tools/pre_commit/mypy.py 1 "3.12"
|
||||
<<: *mypy_common
|
||||
stages: [manual] # Only run in CI
|
||||
- id: shellcheck
|
||||
name: Lint shell scripts
|
||||
@ -155,11 +149,10 @@ repos:
|
||||
additional_dependencies: [regex]
|
||||
- id: check-pickle-imports
|
||||
name: Prevent new pickle/cloudpickle imports
|
||||
entry: python tools/check_pickle_imports.py
|
||||
entry: python tools/pre_commit/check_pickle_imports.py
|
||||
language: python
|
||||
types: [python]
|
||||
pass_filenames: false
|
||||
additional_dependencies: [pathspec, regex]
|
||||
additional_dependencies: [regex]
|
||||
- id: validate-config
|
||||
name: Validate configuration has default values and that each field has a docstring
|
||||
entry: python tools/validate_config.py
|
||||
|
||||
@ -680,7 +680,7 @@ vllm bench serve \
|
||||
--save-result \
|
||||
--result-dir ~/vllm_benchmark_results \
|
||||
--save-detailed \
|
||||
--endpoint /v1/chat/completion
|
||||
--endpoint /v1/chat/completions
|
||||
```
|
||||
|
||||
##### Videos (ShareGPT4Video)
|
||||
@ -707,7 +707,7 @@ vllm bench serve \
|
||||
--save-result \
|
||||
--result-dir ~/vllm_benchmark_results \
|
||||
--save-detailed \
|
||||
--endpoint /v1/chat/completion
|
||||
--endpoint /v1/chat/completions
|
||||
```
|
||||
|
||||
##### Synthetic Random Images (random-mm)
|
||||
|
||||
@ -23,7 +23,7 @@ Now supports 5 types of connectors:
|
||||
|
||||
- **SharedStorageConnector**: refer to <gh-file:examples/offline_inference/disaggregated-prefill-v1/run.sh> for the example usage of SharedStorageConnector disaggregated prefilling.
|
||||
- **LMCacheConnectorV1**: refer to <gh-file:examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh> for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
|
||||
- **NixlConnector**: refer to <gh-file:tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh> for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv.
|
||||
- **NixlConnector**: refer to <gh-file:tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh> for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md).
|
||||
- **P2pNcclConnector**: refer to <gh-file:examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh> for the example usage of P2pNcclConnector disaggregated prefilling.
|
||||
- **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as:
|
||||
|
||||
@ -31,6 +31,18 @@ Now supports 5 types of connectors:
|
||||
--kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
|
||||
```
|
||||
|
||||
For NixlConnector, you may also specify one or multiple NIXL_Backend. Such as:
|
||||
|
||||
```bash
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'
|
||||
```
|
||||
|
||||
- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):
|
||||
|
||||
```bash
|
||||
--kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 64, "num_cpu_blocks": 1000}}'
|
||||
```
|
||||
|
||||
## Benchmarks
|
||||
|
||||
Please refer to <gh-file:benchmarks/disagg_benchmarks> for disaggregated prefilling benchmarks.
|
||||
|
||||
159
docs/features/nixl_connector_usage.md
Normal file
159
docs/features/nixl_connector_usage.md
Normal file
@ -0,0 +1,159 @@
|
||||
# NixlConnector Usage Guide
|
||||
|
||||
NixlConnector is a high-performance KV cache transfer connector for vLLM's disaggregated prefilling feature. It provides fully asynchronous send/receive operations using the NIXL library for efficient cross-process KV cache transfer.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Installation
|
||||
|
||||
Install the NIXL library: `uv pip install nixl`, as a quick start.
|
||||
|
||||
- Refer to [NIXL official repository](https://github.com/ai-dynamo/nixl) for more installation instructions
|
||||
- The specified required NIXL version can be found in [requirements/kv_connectors.txt](../../requirements/kv_connectors.txt) and other relevant config files
|
||||
|
||||
### Transport Configuration
|
||||
|
||||
NixlConnector uses NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables:
|
||||
|
||||
```bash
|
||||
# Example UCX configuration, adjust according to your enviroment
|
||||
export UCX_TLS=all # or specify specific transports like "rc,ud,sm,^cuda_ipc" ..etc
|
||||
export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1"
|
||||
```
|
||||
|
||||
!!! tip
|
||||
When using UCX as the transport backend, NCCL environment variables (like `NCCL_IB_HCA`, `NCCL_SOCKET_IFNAME`) are not applicable to NixlConnector, so configure UCX-specific environment variables instead of NCCL variables.
|
||||
|
||||
## Basic Usage (on the same host)
|
||||
|
||||
### Producer (Prefiller) Configuration
|
||||
|
||||
Start a prefiller instance that produces KV caches
|
||||
|
||||
```bash
|
||||
# 1st GPU as prefiller
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
|
||||
vllm serve Qwen/Qwen3-0.6B \
|
||||
--port 8100 \
|
||||
--enforce-eager \
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
|
||||
```
|
||||
|
||||
### Consumer (Decoder) Configuration
|
||||
|
||||
Start a decoder instance that consumes KV caches:
|
||||
|
||||
```bash
|
||||
# 2nd GPU as decoder
|
||||
CUDA_VISIBLE_DEVICES=1 \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \
|
||||
vllm serve Qwen/Qwen3-0.6B \
|
||||
--port 8200 \
|
||||
--enforce-eager \
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
|
||||
```
|
||||
|
||||
### Proxy Server
|
||||
|
||||
Use a proxy server to route requests between prefiller and decoder:
|
||||
|
||||
```bash
|
||||
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
|
||||
--port 8192 \
|
||||
--prefiller-hosts localhost \
|
||||
--prefiller-ports 8100 \
|
||||
--decoder-hosts localhost \
|
||||
--decoder-ports 8200
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
- `VLLM_NIXL_SIDE_CHANNEL_PORT`: Port for NIXL handshake communication
|
||||
- Default: 5600
|
||||
- **Required for both prefiller and decoder instances**
|
||||
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
|
||||
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
|
||||
- Used for the initial NIXL handshake between the prefiller and the decoder
|
||||
|
||||
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
|
||||
- Default: "localhost"
|
||||
- Set when prefiller and decoder are on different machines
|
||||
- Connection info is passed via KVTransferParams from prefiller to decoder for handshake
|
||||
|
||||
- `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional)
|
||||
- Default: 120
|
||||
- If a request is aborted and the decoder has not yet read the KV-cache blocks through the nixl channel, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
|
||||
|
||||
## Multi-Instance Setup
|
||||
|
||||
### Multiple Prefiller Instances on Different Machines
|
||||
|
||||
```bash
|
||||
# Prefiller 1 on Machine A (example IP: ${IP1})
|
||||
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP1} \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
|
||||
UCX_NET_DEVICES=all \
|
||||
vllm serve Qwen/Qwen3-0.6B --port 8000 \
|
||||
--tensor-parallel-size 8 \
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}'
|
||||
|
||||
# Prefiller 2 on Machine B (example IP: ${IP2})
|
||||
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP2} \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
|
||||
UCX_NET_DEVICES=all \
|
||||
vllm serve Qwen/Qwen3-0.6B --port 8000 \
|
||||
--tensor-parallel-size 8 \
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}'
|
||||
```
|
||||
|
||||
### Multiple Decoder Instances on Different Machines
|
||||
|
||||
```bash
|
||||
# Decoder 1 on Machine C (example IP: ${IP3})
|
||||
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP3} \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
|
||||
UCX_NET_DEVICES=all \
|
||||
vllm serve Qwen/Qwen3-0.6B --port 8000 \
|
||||
--tensor-parallel-size 8 \
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}'
|
||||
|
||||
# Decoder 2 on Machine D (example IP: ${IP4})
|
||||
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP4} \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
|
||||
UCX_NET_DEVICES=all \
|
||||
vllm serve Qwen/Qwen3-0.6B --port 8000 \
|
||||
--tensor-parallel-size 8 \
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}'
|
||||
```
|
||||
|
||||
### Proxy for Multiple Instances
|
||||
|
||||
```bash
|
||||
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
|
||||
--port 8192 \
|
||||
--prefiller-hosts ${IP1} ${IP2} \
|
||||
--prefiller-ports 8000 8000 \
|
||||
--decoder-hosts ${IP3} ${IP4} \
|
||||
--decoder-ports 8000 8000
|
||||
```
|
||||
|
||||
### KV Role Options
|
||||
|
||||
- **kv_producer**: For prefiller instances that generate KV caches
|
||||
- **kv_consumer**: For decoder instances that consume KV caches from prefiller
|
||||
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
|
||||
|
||||
!!! tip
|
||||
NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`).
|
||||
Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
|
||||
|
||||
## Example Scripts/Code
|
||||
|
||||
Refer to these example scripts in the vLLM repository:
|
||||
|
||||
- [run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh)
|
||||
- [toy_proxy_server.py](../../tests/v1/kv_connector/nixl_integration/toy_proxy_server.py)
|
||||
- [test_accuracy.py](../../tests/v1/kv_connector/nixl_integration/test_accuracy.py)
|
||||
@ -319,6 +319,15 @@ Supported models:
|
||||
|
||||
Flags: `--tool-call-parser glm45`
|
||||
|
||||
### Qwen3-Coder Models (`qwen3_xml`)
|
||||
|
||||
Supported models:
|
||||
|
||||
* `Qwen/Qwen3-480B-A35B-Instruct`
|
||||
* `Qwen/Qwen3-Coder-30B-A3B-Instruct`
|
||||
|
||||
Flags: `--tool-call-parser qwen3_xml`
|
||||
|
||||
### Models with Pythonic Tool Calls (`pythonic`)
|
||||
|
||||
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
|
||||
|
||||
@ -352,6 +352,7 @@ th {
|
||||
| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ |
|
||||
| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | | ✅︎ | ✅︎ |
|
||||
| `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ |
|
||||
| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
|
||||
@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok
|
||||
|
||||
1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.
|
||||
|
||||
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`
|
||||
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`. Noted, you may also specify one or multiple NIXL_Backend. Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'`
|
||||
|
||||
3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.
|
||||
|
||||
|
||||
@ -101,6 +101,13 @@ def parse_args():
|
||||
"--quantization",
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-expert-parallel",
|
||||
dest="enable_expert_parallel",
|
||||
action="store_false",
|
||||
help="Disable expert parallel (default: enabled).",
|
||||
)
|
||||
parser.set_defaults(enable_expert_parallel=True)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -113,6 +120,7 @@ def main(
|
||||
dp_master_port,
|
||||
GPUs_per_dp_rank,
|
||||
enforce_eager,
|
||||
enable_expert_parallel,
|
||||
trust_remote_code,
|
||||
max_num_seqs,
|
||||
max_model_len,
|
||||
@ -168,7 +176,7 @@ def main(
|
||||
model=model,
|
||||
tensor_parallel_size=GPUs_per_dp_rank,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_expert_parallel=True,
|
||||
enable_expert_parallel=enable_expert_parallel,
|
||||
trust_remote_code=trust_remote_code,
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_model_len=max_model_len,
|
||||
@ -229,6 +237,7 @@ if __name__ == "__main__":
|
||||
dp_master_port,
|
||||
tp_size,
|
||||
args.enforce_eager,
|
||||
args.enable_expert_parallel,
|
||||
args.trust_remote_code,
|
||||
args.max_num_seqs,
|
||||
args.max_model_len,
|
||||
|
||||
81
examples/offline_inference/torchrun_dp_example.py
Normal file
81
examples/offline_inference/torchrun_dp_example.py
Normal file
@ -0,0 +1,81 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
experimental support for data-parallel inference with torchrun
|
||||
Note the data load balancing and distribution is done out of the vllm engine,
|
||||
no internal lb supported in external_launcher mode.
|
||||
"""
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Create prompts, the same across all ranks
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
] * 50
|
||||
|
||||
# Create sampling parameters, the same across all ranks
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Use `distributed_executor_backend="external_launcher"` so that
|
||||
# this llm engine/instance only creates one worker.
|
||||
# it is important to set an explicit seed to make sure that
|
||||
# all ranks have the same random seed, so that sampling can be
|
||||
# deterministic across ranks.
|
||||
llm = LLM(
|
||||
model="microsoft/Phi-mini-MoE-instruct",
|
||||
tensor_parallel_size=1,
|
||||
data_parallel_size=2,
|
||||
pipeline_parallel_size=1,
|
||||
enable_expert_parallel=False,
|
||||
distributed_executor_backend="external_launcher",
|
||||
max_model_len=4096,
|
||||
gpu_memory_utilization=0.6,
|
||||
seed=1,
|
||||
)
|
||||
|
||||
dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
|
||||
dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size
|
||||
|
||||
prompts = [
|
||||
f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank
|
||||
]
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
|
||||
# all ranks will have the same outputs
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
|
||||
print("-" * 50)
|
||||
"""
|
||||
Further tips:
|
||||
|
||||
1. to communicate control messages across all ranks, use the cpu group,
|
||||
a PyTorch ProcessGroup with GLOO backend.
|
||||
|
||||
```python
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
cpu_group = get_world_group().cpu_group
|
||||
torch_rank = dist.get_rank(group=cpu_group)
|
||||
if torch_rank == 0:
|
||||
# do something for rank 0, e.g. saving the results to disk.
|
||||
```
|
||||
|
||||
2. to communicate data across all ranks, use the model's device group,
|
||||
a PyTorch ProcessGroup with NCCL backend.
|
||||
```python
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
device_group = get_world_group().device_group
|
||||
```
|
||||
|
||||
3. to access the model directly in every rank, use the following code:
|
||||
```python
|
||||
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||
```
|
||||
"""
|
||||
@ -126,6 +126,23 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# Dots-OCR
|
||||
def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions]
|
||||
engine_args = EngineArgs(
|
||||
model="rednote-hilab/dots.ocr",
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
@ -1676,6 +1693,7 @@ model_example_map = {
|
||||
"aya_vision": run_aya_vision,
|
||||
"blip-2": run_blip2,
|
||||
"chameleon": run_chameleon,
|
||||
"dots_ocr": run_dots_ocr,
|
||||
"command_a_vision": run_command_a_vision,
|
||||
"deepseek_vl_v2": run_deepseek_vl2,
|
||||
"ernie45_vl": run_ernie45_vl,
|
||||
|
||||
@ -110,27 +110,6 @@ ignore_missing_imports = true
|
||||
check_untyped_defs = true
|
||||
follow_imports = "silent"
|
||||
|
||||
# After fixing type errors resulting from follow_imports: "skip" -> "silent",
|
||||
# move the directory here and remove it from tools/mypy.sh
|
||||
files = [
|
||||
"vllm/*.py",
|
||||
"vllm/assets",
|
||||
"vllm/entrypoints",
|
||||
"vllm/inputs",
|
||||
"vllm/logging_utils",
|
||||
"vllm/multimodal",
|
||||
"vllm/platforms",
|
||||
"vllm/transformers_utils",
|
||||
"vllm/triton_utils",
|
||||
"vllm/usage",
|
||||
]
|
||||
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
||||
exclude = [
|
||||
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
|
||||
# Ignore triton kernels in ops.
|
||||
'vllm/attention/ops/.*\.py$'
|
||||
]
|
||||
|
||||
[tool.isort]
|
||||
skip_glob = [
|
||||
".buildkite/*",
|
||||
|
||||
@ -14,14 +14,4 @@ nixl==0.3.0
|
||||
tpu_info==0.4.0
|
||||
|
||||
# Install torch_xla
|
||||
--pre
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
|
||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
torch==2.9.0.dev20250730
|
||||
torchvision==0.24.0.dev20250730
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp312-cp312-linux_x86_64.whl ; python_version == "3.12"
|
||||
|
||||
torch_xla[tpu, pallas]==2.8.0
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import weakref
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Union
|
||||
@ -10,7 +11,26 @@ from torch._ops import OpOverload
|
||||
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.inductor_pass import InductorPass
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.compilation.pass_manager import with_pattern_match_debug
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
|
||||
|
||||
class LazyInitPass(InductorPass):
|
||||
"""
|
||||
If there's a pass that we want to initialize lazily in a test,
|
||||
we can wrap it in LazyInitPass, which will initialize the pass when invoked
|
||||
and then immediately invoke it.
|
||||
"""
|
||||
|
||||
def __init__(self, pass_cls: type[VllmInductorPass],
|
||||
vllm_config: VllmConfig):
|
||||
self.pass_cls = pass_cls
|
||||
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
|
||||
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.pass_ = self.pass_cls(self.vllm_config)
|
||||
self.pass_(graph)
|
||||
|
||||
|
||||
class TestBackend:
|
||||
@ -40,10 +60,16 @@ class TestBackend:
|
||||
example_inputs,
|
||||
config_patches=self.inductor_config)
|
||||
|
||||
@with_pattern_match_debug
|
||||
def post_pass(self, graph: fx.Graph):
|
||||
self.graph_pre_pass = deepcopy(graph)
|
||||
|
||||
VllmInductorPass.dump_prefix = 0
|
||||
for pass_ in self.custom_passes:
|
||||
pass_(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
|
||||
VllmInductorPass.dump_prefix = None
|
||||
|
||||
self.graph_post_pass = deepcopy(graph)
|
||||
# assign by reference, will reflect the final state of the graph
|
||||
|
||||
@ -46,7 +46,10 @@ backend_configs = {
|
||||
# FA3 on Hopper
|
||||
"FA3":
|
||||
BackendConfig(name="FA3",
|
||||
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
|
||||
env_vars={
|
||||
"VLLM_FLASH_ATTN_VERSION": "3",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
},
|
||||
@ -66,6 +69,7 @@ backend_configs = {
|
||||
BackendConfig(name="FlashAttentionMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
@ -89,7 +93,10 @@ backend_configs = {
|
||||
# FA2
|
||||
"FA2":
|
||||
BackendConfig(name="FA2",
|
||||
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
|
||||
env_vars={
|
||||
"VLLM_FLASH_ATTN_VERSION": "2",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
}),
|
||||
|
||||
@ -294,6 +294,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
|
||||
assert async_tp_pass.matched_count == 1
|
||||
|
||||
# In pre-nodes, all gather or reduce scatter should exist,
|
||||
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
|
||||
@ -4,7 +4,7 @@ import pytest
|
||||
|
||||
import vllm
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import CompilationConfig, VllmConfig
|
||||
from vllm.utils import _is_torch_equal_or_newer
|
||||
|
||||
|
||||
@ -26,6 +26,14 @@ def test_use_cudagraphs_dynamic(monkeypatch):
|
||||
assert not vllm_config.compilation_config.use_cudagraph
|
||||
|
||||
|
||||
def test_custom_op():
|
||||
# proper syntax
|
||||
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid syntax '"):
|
||||
_ = CompilationConfig(custom_ops=["quant_fp8"])
|
||||
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
|
||||
|
||||
@ -8,9 +8,10 @@ import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import FUSED_OPS, FusionPass
|
||||
from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||
@ -58,11 +59,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = FusionPass.instance(vllm_config)
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||
|
||||
passes = [noop_pass, fusion_pass, act_quant_fusion_pass
|
||||
] if do_fusion else [noop_pass]
|
||||
passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
|
||||
] if do_fusion else [noop_pass, cleanup_pass]
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
backend_no_func = TestBackend(*passes)
|
||||
|
||||
@ -4,11 +4,11 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.plugins
|
||||
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
|
||||
FusionPass)
|
||||
RMSNormQuantFusionPass)
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
||||
VllmConfig)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -79,15 +79,15 @@ class TestModel(torch.nn.Module):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
|
||||
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
|
||||
@pytest.mark.parametrize("hidden_size", [64])
|
||||
@pytest.mark.parametrize("num_tokens", [257])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("static", [True, False])
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.parametrize("cuda_force_torch",
|
||||
[True, False] if cutlass_fp8_supported() else [True])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
||||
reason="Only test on CUDA and ROCm")
|
||||
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
||||
cuda_force_torch):
|
||||
@ -104,9 +104,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = FusionPass.instance(vllm_config)
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass)
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
model = TestModel(hidden_size, eps, static, cuda_force_torch)
|
||||
|
||||
# First dimension dynamic
|
||||
@ -128,6 +129,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
||||
|
||||
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
|
||||
|
||||
assert fusion_pass.matched_count == 2
|
||||
|
||||
# In pre-nodes, fp8 quant should be there and fused kernels should not
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import vllm.envs as envs
|
||||
from vllm.compilation.collective_fusion import AllReduceFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
|
||||
ModelConfig, PassConfig, VllmConfig)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
@ -215,8 +216,10 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
|
||||
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
|
||||
cleanup_pass)
|
||||
|
||||
token_num = batch_size * seq_len
|
||||
model = test_model_cls(hidden_size, token_num)
|
||||
@ -227,6 +230,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states, residual)
|
||||
|
||||
assert all_reduce_fusion_pass.matched_count == 1
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
del all_reduce_fusion_pass
|
||||
|
||||
@ -6,18 +6,19 @@ from typing import Optional
|
||||
import pytest
|
||||
import torch._dynamo
|
||||
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.compile.backend import LazyInitPass, TestBackend
|
||||
from tests.models.utils import check_outputs_equal
|
||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_common_attn_metadata)
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.compilation.fusion import QUANT_OPS
|
||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
||||
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
@ -104,7 +105,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
||||
|
||||
# AttnFusionPass needs attention layers to be registered in config upon init
|
||||
# so we initialize it during compilation.
|
||||
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
|
||||
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
|
||||
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
|
||||
llm2 = LLM(model,
|
||||
enforce_eager=True,
|
||||
@ -197,7 +198,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def build_attn_metadata(self, batch_size: int, use_hnd: bool):
|
||||
def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
|
||||
-> AttentionMetadata:
|
||||
"""Initialize attention metadata."""
|
||||
|
||||
# Create common attn metadata
|
||||
@ -447,9 +449,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
|
||||
# Create test backend with fusion passes enabled
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw
|
||||
)
|
||||
test_backend = TestBackend(noop_pass, attn_pass)
|
||||
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
|
||||
|
||||
# Compile model with fusion enabled
|
||||
model_compiled = torch.compile(model_fused,
|
||||
@ -485,6 +488,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
test_backend.check_before_ops([QUANT_OPS[quant_key]],
|
||||
fully_replaced=True)
|
||||
|
||||
# access the underlying `AttnFusionPass` on the `LazyInitPass`
|
||||
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
|
||||
|
||||
# Check attention ops in the graph before and after fusion
|
||||
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
|
||||
attn_nodes_post = list(find_op_nodes(ATTN_OP,
|
||||
|
||||
@ -6,10 +6,12 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import FusionPass
|
||||
from vllm.compilation.fusion import RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
||||
PassConfig, VllmConfig)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
@ -104,7 +106,7 @@ class TestQuantModel(torch.nn.Module):
|
||||
# Initialize weights
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
|
||||
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
||||
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
# Create a weight that is compatible with torch._scaled_mm,
|
||||
@ -137,8 +139,7 @@ class TestQuantModel(torch.nn.Module):
|
||||
# layer normalization
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
|
||||
# for static input quantization
|
||||
# self.fp8_linear is initialized with use_per_token_if_dynamic=False
|
||||
# scaled_mm with static input quantization
|
||||
fp8_linear_result = self.fp8_linear.apply(norm_output,
|
||||
self.w,
|
||||
self.wscale,
|
||||
@ -253,16 +254,20 @@ def sequence_parallelism_pass_on_test_model(
|
||||
dtype=dtype,
|
||||
seed=42)
|
||||
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
passes_for_backend = [noop_pass, sequence_parallelism_pass]
|
||||
passes_for_backend: list[VllmInductorPass] = \
|
||||
[noop_pass, sequence_parallelism_pass]
|
||||
|
||||
if enable_fusion:
|
||||
fusion_pass = FusionPass.instance(vllm_config)
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
|
||||
passes_for_backend.append(cleanup_pass)
|
||||
|
||||
backend_no_func = TestBackend(*passes_for_backend)
|
||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||
|
||||
@ -279,6 +284,8 @@ def sequence_parallelism_pass_on_test_model(
|
||||
compiled_model_func = torch.compile(model, backend=backend_func)
|
||||
compiled_model_func(hidden_states, residual)
|
||||
|
||||
assert sequence_parallelism_pass.matched_count == 1
|
||||
|
||||
# In pre-nodes, all reduce should be there,
|
||||
# reduce scatter and all gather should not
|
||||
backend_no_func.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
@ -15,6 +15,7 @@ from vllm.compilation.activation_quant_fusion import (
|
||||
# yapf: enable
|
||||
from vllm.compilation.fusion import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
@ -69,6 +70,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
|
||||
super().__init__()
|
||||
from vllm.compilation.activation_quant_fusion import (
|
||||
silu_and_mul_nvfp4_quant_supported)
|
||||
assert silu_and_mul_nvfp4_quant_supported
|
||||
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
|
||||
# create nvfp4 weight
|
||||
@ -127,7 +132,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
|
||||
fusion_pass = ActivationQuantFusionPass(config)
|
||||
|
||||
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
|
||||
passes = [
|
||||
NoOpEliminationPass(config), fusion_pass,
|
||||
PostCleanupPass(config)
|
||||
]
|
||||
backend = TestBackend(*passes)
|
||||
model = model_class(hidden_size=hidden_size,
|
||||
cuda_force_torch=cuda_force_torch,
|
||||
x=x)
|
||||
@ -151,6 +160,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
|
||||
assert fusion_pass.matched_count == 1
|
||||
|
||||
# In pre-nodes, quant op should be present and fused kernels should not
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
|
||||
81
tests/distributed/test_torchrun_example_moe.py
Normal file
81
tests/distributed/test_torchrun_example_moe.py
Normal file
@ -0,0 +1,81 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# unit test for `examples/offline_inference/torchrun_example.py`
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_world_group
|
||||
|
||||
dist.init_process_group(backend="gloo")
|
||||
|
||||
# Create prompts
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
] * 10
|
||||
dp_size = int(os.getenv("DP_SIZE", "1"))
|
||||
dp_rank = int(os.getenv("DP_RANK", "0"))
|
||||
|
||||
if dp_size > 1:
|
||||
# distribute the prompts across the data parallel ranks
|
||||
prompts = [
|
||||
prompt for idx, prompt in enumerate(prompts)
|
||||
if idx % dp_size == dp_rank
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
|
||||
# to test if all ranks agree on the same kv cache configuration.
|
||||
llm = LLM(model="microsoft/Phi-mini-MoE-instruct",
|
||||
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
|
||||
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
|
||||
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
|
||||
distributed_executor_backend="external_launcher",
|
||||
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
||||
swap_space=random.randint(1, 4),
|
||||
seed=0)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
group = get_world_group() if dp_size == 1 else get_tp_group()
|
||||
cpu_group = group.cpu_group
|
||||
group_rank = dist.get_rank(group=cpu_group)
|
||||
|
||||
|
||||
def test_consistent_across_ranks(obj):
|
||||
if group_rank == 0:
|
||||
dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group)
|
||||
else:
|
||||
container = [None]
|
||||
dist.broadcast_object_list(container,
|
||||
src=group.ranks[0],
|
||||
group=cpu_group)
|
||||
assert container[0] == obj
|
||||
|
||||
|
||||
test_consistent_across_ranks(
|
||||
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
|
||||
test_consistent_across_ranks(
|
||||
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
|
||||
|
||||
# make sure we can access the model parameters from the calling process
|
||||
# of the `LLM` instance.
|
||||
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
|
||||
model.parameters())
|
||||
test_consistent_across_ranks(len(params))
|
||||
|
||||
# all ranks should have the same outputs
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
test_consistent_across_ranks(prompt)
|
||||
test_consistent_across_ranks(generated_text)
|
||||
print(f"Rank {group_rank}, Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}")
|
||||
106
tests/entrypoints/openai/test_response_api_mcp_tools.py
Normal file
106
tests/entrypoints/openai/test_response_api_mcp_tools.py
Normal file
@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from openai import OpenAI
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "openai/gpt-oss-20b"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def monkeypatch_module():
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
mpatch = MonkeyPatch()
|
||||
yield mpatch
|
||||
mpatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch):
|
||||
args = ["--enforce-eager", "--tool-server", "demo"]
|
||||
|
||||
with monkeypatch_module.context() as m:
|
||||
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
|
||||
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
|
||||
args = ["--enforce-eager", "--tool-server", "demo"]
|
||||
|
||||
with monkeypatch_module.context() as m:
|
||||
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
|
||||
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
|
||||
m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS",
|
||||
"code_interpreter,container")
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mcp_disabled_client(mcp_disabled_server):
|
||||
async with mcp_disabled_server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mcp_enabled_client(mcp_enabled_server):
|
||||
async with mcp_enabled_server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
|
||||
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI,
|
||||
model_name: str):
|
||||
response = await mcp_enabled_client.responses.create(
|
||||
model=model_name,
|
||||
# TODO: Ideally should be able to set max tool calls
|
||||
# to prevent multi-turn, but it is not currently supported
|
||||
# would speed up the test
|
||||
input=("What's the first 4 digits after the decimal point of "
|
||||
"cube root of `19910212 * 20250910`? "
|
||||
"Show only the digits. The python interpreter is not stateful "
|
||||
"and you must print to see the output."),
|
||||
tools=[{
|
||||
"type": "mcp",
|
||||
"server_label": "code_interpreter",
|
||||
# URL unused for DemoToolServer
|
||||
"server_url": "http://localhost:8888"
|
||||
}],
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
assert response.usage.output_tokens_details.tool_output_tokens > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
|
||||
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI,
|
||||
model_name: str):
|
||||
response = await mcp_disabled_client.responses.create(
|
||||
model=model_name,
|
||||
# TODO: Ideally should be able to set max tool calls
|
||||
# to prevent multi-turn, but it is not currently supported
|
||||
# would speed up the test
|
||||
input=("What's the first 4 digits after the decimal point of "
|
||||
"cube root of `19910212 * 20250910`? "
|
||||
"Show only the digits. The python interpreter is not stateful "
|
||||
"and you must print to see the output."),
|
||||
tools=[{
|
||||
"type": "mcp",
|
||||
"server_label": "code_interpreter",
|
||||
# URL unused for DemoToolServer
|
||||
"server_url": "http://localhost:8888"
|
||||
}],
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
assert response.usage.output_tokens_details.tool_output_tokens == 0
|
||||
@ -454,7 +454,13 @@ async def test_web_search(client: OpenAI, model_name: str):
|
||||
async def test_code_interpreter(client: OpenAI, model_name: str):
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input="Multiply 64548*15151 using builtin python interpreter.",
|
||||
# TODO: Ideally should be able to set max tool calls
|
||||
# to prevent multi-turn, but it is not currently supported
|
||||
# would speed up the test
|
||||
input=("What's the first 4 digits after the decimal point of "
|
||||
"cube root of `19910212 * 20250910`? "
|
||||
"Show only the digits. The python interpreter is not stateful "
|
||||
"and you must print to see the output."),
|
||||
tools=[{
|
||||
"type": "code_interpreter",
|
||||
"container": {
|
||||
@ -464,6 +470,7 @@ async def test_code_interpreter(client: OpenAI, model_name: str):
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
assert response.usage.output_tokens_details.tool_output_tokens > 0
|
||||
|
||||
|
||||
def get_weather(latitude, longitude):
|
||||
|
||||
@ -5,6 +5,11 @@ import json
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
|
||||
Hermes2ProToolParser)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
@ -37,7 +42,7 @@ TOOLS = [{
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
@ -45,8 +50,39 @@ TOOLS = [{
|
||||
},
|
||||
}]
|
||||
|
||||
PRODUCT_TOOLS = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_product_info",
|
||||
"description": "Get detailed information of a product based on its "
|
||||
"product ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"inserted": {
|
||||
"type": "boolean",
|
||||
"description": "inserted.",
|
||||
},
|
||||
"product_id": {
|
||||
"type": "integer",
|
||||
"description": "The product ID of the product.",
|
||||
},
|
||||
},
|
||||
"required": ["product_id", "inserted"],
|
||||
},
|
||||
},
|
||||
}]
|
||||
|
||||
MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
|
||||
|
||||
PRODUCT_MESSAGES = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Hi! Do you have any detailed information about the product id "
|
||||
"7355608 and inserted true?",
|
||||
}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_tool_call():
|
||||
@ -113,8 +149,8 @@ async def test_streaming_tool_call():
|
||||
if tool_chunk.function.name:
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index][
|
||||
"arguments"] += tool_chunk.function.arguments
|
||||
tool_call_chunks[index]["arguments"] += (
|
||||
tool_chunk.function.arguments)
|
||||
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
@ -127,3 +163,295 @@ async def test_streaming_tool_call():
|
||||
print("\n[Streaming Test Passed]")
|
||||
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||
print(f"Reconstructed Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_product_tool_call():
|
||||
"""Test tool call integer and boolean parameters in non-streaming mode."""
|
||||
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
|
||||
client = server.get_async_client()
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=LORA_MODEL,
|
||||
messages=PRODUCT_MESSAGES,
|
||||
tools=PRODUCT_TOOLS,
|
||||
tool_choice="auto",
|
||||
temperature=0.66,
|
||||
)
|
||||
|
||||
assert response.choices
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
assert choice.finish_reason == "tool_calls"
|
||||
assert message.tool_calls is not None
|
||||
|
||||
tool_call = message.tool_calls[0]
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function.name == "get_product_info"
|
||||
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
assert "product_id" in arguments
|
||||
assert "inserted" in arguments
|
||||
|
||||
product_id = arguments.get("product_id")
|
||||
inserted = arguments.get("inserted")
|
||||
|
||||
assert isinstance(product_id, int)
|
||||
assert product_id == 7355608
|
||||
assert isinstance(inserted, bool)
|
||||
assert inserted is True
|
||||
|
||||
print("\n[Non-Streaming Product Test Passed]")
|
||||
print(f"Tool Call: {tool_call.function.name}")
|
||||
print(f"Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_product_tool_call():
|
||||
"""Test tool call integer and boolean parameters in streaming mode."""
|
||||
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
|
||||
client = server.get_async_client()
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
model=LORA_MODEL,
|
||||
messages=PRODUCT_MESSAGES,
|
||||
tools=PRODUCT_TOOLS,
|
||||
tool_choice="auto",
|
||||
temperature=0.66,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
tool_call_chunks = {}
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
if not delta or not delta.tool_calls:
|
||||
continue
|
||||
|
||||
for tool_chunk in delta.tool_calls:
|
||||
index = tool_chunk.index
|
||||
if index not in tool_call_chunks:
|
||||
tool_call_chunks[index] = {"name": "", "arguments": ""}
|
||||
|
||||
if tool_chunk.function.name:
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index]["arguments"] += (
|
||||
tool_chunk.function.arguments)
|
||||
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
|
||||
assert reconstructed_tool_call["name"] == "get_product_info"
|
||||
|
||||
arguments = json.loads(reconstructed_tool_call["arguments"])
|
||||
assert "product_id" in arguments
|
||||
assert "inserted" in arguments
|
||||
|
||||
# Handle type coercion for streaming test as well
|
||||
product_id = arguments.get("product_id")
|
||||
inserted = arguments.get("inserted")
|
||||
|
||||
assert isinstance(product_id, int)
|
||||
assert product_id == 7355608
|
||||
assert isinstance(inserted, bool)
|
||||
assert inserted is True
|
||||
|
||||
print("\n[Streaming Product Test Passed]")
|
||||
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||
print(f"Reconstructed Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qwen_tokenizer() -> AnyTokenizer:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
return get_tokenizer("Qwen/Qwen3-32B")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser:
|
||||
return Hermes2ProToolParser(qwen_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def any_chat_request() -> ChatCompletionRequest:
|
||||
return ChatCompletionRequest(
|
||||
seed=42,
|
||||
model="Qwen/Qwen3-32B",
|
||||
messages=[],
|
||||
)
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_just_forward_text(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = (
|
||||
"""This is some prior text that has nothing to do with tool calling."""
|
||||
)
|
||||
tokens = qwen_tokenizer.encode(text)
|
||||
previous_text = ""
|
||||
delta_messages = []
|
||||
for token in tokens:
|
||||
delta_text = qwen_tokenizer.decode([token])
|
||||
current_text = previous_text + delta_text
|
||||
delta = hermes_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=any_chat_request,
|
||||
)
|
||||
previous_text = current_text
|
||||
delta_messages.append(delta)
|
||||
|
||||
for delta in delta_messages:
|
||||
assert delta is not None
|
||||
assert not delta.tool_calls
|
||||
|
||||
print(delta_messages)
|
||||
assert "".join([delta.content for delta in delta_messages]) == text
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_failure_case_bug_19056(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}}
|
||||
</tool_call>"""
|
||||
tokens = qwen_tokenizer.encode(text)
|
||||
previous_text = ""
|
||||
delta_messages = []
|
||||
for token in tokens:
|
||||
text = qwen_tokenizer.decode([token])
|
||||
current_text = previous_text + text
|
||||
delta = hermes_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=any_chat_request,
|
||||
)
|
||||
previous_text = current_text
|
||||
if delta is not None:
|
||||
delta_messages.append(delta)
|
||||
|
||||
assert delta_messages[0].tool_calls[0].function.name == "final_answer"
|
||||
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
|
||||
for delta in delta_messages)
|
||||
assert tool_call_args == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_streaming(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = '<tool_call>\
|
||||
{"name": "get_current_temperature",\
|
||||
"arguments": {"location":\
|
||||
"San Francisco, California, United States", "unit": "celsius"}}\
|
||||
</tool_call>'
|
||||
|
||||
tokens = qwen_tokenizer.encode(text)
|
||||
previous_text = ""
|
||||
delta_messages = []
|
||||
for token in tokens:
|
||||
text = qwen_tokenizer.decode([token])
|
||||
current_text = previous_text + text
|
||||
delta = hermes_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=any_chat_request,
|
||||
)
|
||||
previous_text = current_text
|
||||
if delta is not None:
|
||||
delta_messages.append(delta)
|
||||
print(delta_messages)
|
||||
assert (delta_messages[0].tool_calls[0].function.name ==
|
||||
"get_current_temperature")
|
||||
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
|
||||
for delta in delta_messages)
|
||||
assert tool_call_args == (
|
||||
'{"location":"San Francisco, California, United States", '
|
||||
'"unit": "celsius"}')
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_no_tool_call(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """This is not a tool call."""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert not tool_call.tools_called
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_between_tags(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}}
|
||||
</tool_call>"""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert tool_call.tools_called
|
||||
assert tool_call.tool_calls[0].function.name == "final_answer"
|
||||
assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_until_eos(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}}"""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert tool_call.tools_called
|
||||
assert tool_call.tool_calls[0].function.name == "final_answer"
|
||||
assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_invalid_json(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
# Missing closing brace to trigger exception
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}"""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert not tool_call.tools_called
|
||||
|
||||
@ -19,7 +19,7 @@ pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \
|
||||
vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000
|
||||
|
||||
# Run evaluation
|
||||
python tests/gsm8k/gsm8k_eval.py --port 8000
|
||||
python tests/evals/gsm8k/gsm8k_eval.py --port 8000
|
||||
```
|
||||
|
||||
## Configuration Format
|
||||
|
||||
@ -67,7 +67,6 @@ def generate_params():
|
||||
return params
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
|
||||
@pytest.mark.parametrize("device, name, use_mla, block_size",
|
||||
generate_params())
|
||||
def test_env(
|
||||
@ -189,7 +188,7 @@ def test_env(
|
||||
# FlashMLA only supports block_size == 64
|
||||
pytest.skip("FlashMLA only supports block_size 64")
|
||||
else:
|
||||
from vllm.attention.backends.flashmla import (
|
||||
from vllm.v1.attention.backends.mla.flashmla import ( # noqa: E501
|
||||
is_flashmla_supported)
|
||||
is_supported, _ = is_flashmla_supported()
|
||||
if not is_supported:
|
||||
|
||||
@ -959,7 +959,6 @@ def make_test_metadata(
|
||||
return attn_backend_obj.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
@ -1009,7 +1008,6 @@ def make_test_metadata(
|
||||
return attn_backend_obj.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=kv_mmap.slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
|
||||
@ -164,8 +164,8 @@ def populate_loras(
|
||||
weight=layer_weights,
|
||||
generate_embeddings_tensor=generate_embeddings_tensor,
|
||||
)
|
||||
sublora.lora_b = sublora.lora_b[:, (sublora_len *
|
||||
i):(sublora_len * (i + 1))]
|
||||
sublora.lora_b = sublora.lora_b[(sublora_len *
|
||||
i):(sublora_len * (i + 1)), :]
|
||||
sublora.optimize()
|
||||
subloras.append(sublora)
|
||||
|
||||
@ -304,9 +304,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
result = embedding(input_)
|
||||
after_a = F.embedding(
|
||||
input_,
|
||||
lora.lora_a,
|
||||
lora.lora_a.T,
|
||||
)
|
||||
result += (after_a @ lora.lora_b)
|
||||
result += (after_a @ lora.lora_b.T)
|
||||
expected_results.append(result)
|
||||
expected_result = torch.cat(expected_results)
|
||||
|
||||
@ -445,9 +445,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
result = expanded_embedding(input_)
|
||||
after_a = F.embedding(
|
||||
original_input_,
|
||||
lora.lora_a,
|
||||
lora.lora_a.T,
|
||||
)
|
||||
result += (after_a @ lora.lora_b)
|
||||
result += (after_a @ lora.lora_b.T)
|
||||
expected_results.append(result)
|
||||
expected_result = torch.cat(expected_results)
|
||||
|
||||
@ -575,7 +575,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
lm_head=linear,
|
||||
embedding_bias=None)
|
||||
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
|
||||
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
|
||||
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
|
||||
expected_results.append(result)
|
||||
expected_result = torch.cat(expected_results)
|
||||
logits_processor.org_vocab_size = vocab_size
|
||||
@ -692,9 +692,10 @@ def test_linear_replicated(
|
||||
|
||||
expected_results: list[torch.Tensor] = []
|
||||
for input_, lora_id in zip(inputs, prompt_mapping):
|
||||
|
||||
lora = lora_dict[lora_id]
|
||||
result = linear(input_)[0]
|
||||
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
|
||||
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
|
||||
expected_results.append(result)
|
||||
expected_result = torch.cat(expected_results)
|
||||
|
||||
@ -817,7 +818,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
for input_, lora_id in zip(inputs, prompt_mapping):
|
||||
lora = lora_dict[lora_id]
|
||||
result = linear(input_)[0]
|
||||
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
|
||||
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
|
||||
expected_results.append(result)
|
||||
expected_result = torch.cat(expected_results)
|
||||
|
||||
@ -965,9 +966,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
||||
result = linear(input_)[0]
|
||||
subloras = sublora_dict[lora_id]
|
||||
for i, sublora in enumerate(subloras):
|
||||
result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
|
||||
(i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
|
||||
sublora.scaling)
|
||||
result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
|
||||
(i + 1)] += (
|
||||
input_ @ sublora.lora_a.T @ sublora.lora_b.T *
|
||||
sublora.scaling)
|
||||
expected_results.append(result)
|
||||
expected_result = torch.cat(expected_results)
|
||||
|
||||
|
||||
@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device):
|
||||
assert lora.lora_b is not None
|
||||
assert lora.lora_a.device == torch.device(device)
|
||||
assert lora.lora_b.device == torch.device(device)
|
||||
assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
|
||||
assert (lora.lora_a.shape[0] == lora.lora_b.shape[1]
|
||||
), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
|
||||
assert lora.lora_a.shape[1] == 8
|
||||
assert lora.lora_a.shape[0] == 8
|
||||
embeddings_module = next(
|
||||
(k for k in EMBEDDING_MODULES if k in module_name), None)
|
||||
if embeddings_module:
|
||||
@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
|
||||
name,
|
||||
8,
|
||||
16,
|
||||
torch.rand([w.shape[1], 8], device=device),
|
||||
torch.rand([8, w.shape[0]], device=device),
|
||||
torch.rand([8, w.shape[1]], device=device),
|
||||
torch.rand([w.shape[0], 8], device=device),
|
||||
)
|
||||
return LoRAModel(lora_id, 8, loras)
|
||||
|
||||
@ -109,8 +109,8 @@ def create_packed_lora(
|
||||
replaced_module_name,
|
||||
8,
|
||||
16,
|
||||
torch.rand([w.shape[1], 8], device=device),
|
||||
torch.rand([8, w.shape[0] // len(replaced_module_names)],
|
||||
torch.rand([8, w.shape[1]], device=device),
|
||||
torch.rand([w.shape[0] // len(replaced_module_names), 8],
|
||||
device=device),
|
||||
)
|
||||
return LoRAModel(lora_id, 8, loras)
|
||||
|
||||
@ -36,10 +36,10 @@ class DummyLoRAManager:
|
||||
module_name,
|
||||
rank=rank,
|
||||
lora_alpha=1,
|
||||
lora_a=torch.rand([weight.shape[1], rank],
|
||||
lora_a=torch.rand([rank, weight.shape[1]],
|
||||
dtype=weight.dtype,
|
||||
device=self._device),
|
||||
lora_b=torch.rand([rank, weight.shape[0]],
|
||||
lora_b=torch.rand([weight.shape[0], rank],
|
||||
dtype=weight.dtype,
|
||||
device=self._device),
|
||||
)
|
||||
@ -67,8 +67,8 @@ class DummyLoRAManager:
|
||||
module_name,
|
||||
rank=rank,
|
||||
lora_alpha=1,
|
||||
lora_a=torch.rand([input_dim, rank], device="cuda"),
|
||||
lora_b=torch.rand([rank, output_dim], device="cuda"),
|
||||
lora_a=torch.rand([rank, input_dim], device="cuda"),
|
||||
lora_b=torch.rand([output_dim, input_dim], device="cuda"),
|
||||
embeddings_tensor=embeddings_tensor,
|
||||
)
|
||||
self.set_module_lora(module_name, lora)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -34,15 +35,15 @@ class Relu3(ReLUSquaredActivation):
|
||||
[
|
||||
# Default values based on compile level
|
||||
# - All by default (no Inductor compilation)
|
||||
("", 0, False, [True] * 4, True),
|
||||
("", 1, True, [True] * 4, True),
|
||||
("", 2, False, [True] * 4, True),
|
||||
(None, 0, False, [True] * 4, True),
|
||||
(None, 1, True, [True] * 4, True),
|
||||
(None, 2, False, [True] * 4, True),
|
||||
# - None by default (with Inductor)
|
||||
("", 3, True, [False] * 4, False),
|
||||
("", 4, True, [False] * 4, False),
|
||||
(None, 3, True, [False] * 4, False),
|
||||
(None, 4, True, [False] * 4, False),
|
||||
# - All by default (without Inductor)
|
||||
("", 3, False, [True] * 4, True),
|
||||
("", 4, False, [True] * 4, True),
|
||||
(None, 3, False, [True] * 4, True),
|
||||
(None, 4, False, [True] * 4, True),
|
||||
# Explicitly enabling/disabling
|
||||
#
|
||||
# Default: all
|
||||
@ -54,7 +55,7 @@ class Relu3(ReLUSquaredActivation):
|
||||
# All but SiluAndMul
|
||||
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
|
||||
# All but ReLU3 (even if ReLU2 is on)
|
||||
("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
|
||||
("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
|
||||
# RMSNorm and SiluAndMul
|
||||
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
|
||||
# All but RMSNorm
|
||||
@ -67,12 +68,13 @@ class Relu3(ReLUSquaredActivation):
|
||||
# All but RMSNorm
|
||||
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
|
||||
])
|
||||
def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
|
||||
def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool,
|
||||
ops_enabled: list[int], default_on: bool):
|
||||
custom_ops = env.split(',') if env else []
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
|
||||
level=torch_level,
|
||||
custom_ops=env.split(",")))
|
||||
custom_ops=custom_ops))
|
||||
with set_current_vllm_config(vllm_config):
|
||||
assert CustomOp.default_on() == default_on
|
||||
|
||||
|
||||
@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model
|
||||
SSM_MODELS = [
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"tiiuae/falcon-mamba-tiny-dev",
|
||||
"yujiepan/mamba2-codestral-v0.1-tiny-random",
|
||||
# mamba2-codestral in transformers is broken pending:
|
||||
# https://github.com/huggingface/transformers/pull/40861
|
||||
#"yujiepan/mamba2-codestral-v0.1-tiny-random",
|
||||
]
|
||||
|
||||
HYBRID_MODELS = [
|
||||
@ -31,18 +33,7 @@ HYBRID_MODELS = [
|
||||
"ibm-granite/granite-4.0-tiny-preview",
|
||||
"tiiuae/Falcon-H1-0.5B-Base",
|
||||
"LiquidAI/LFM2-1.2B",
|
||||
]
|
||||
|
||||
V1_SUPPORTED_MODELS = [
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"ai21labs/Jamba-tiny-dev",
|
||||
"pfnet/plamo-2-1b",
|
||||
"yujiepan/mamba2-codestral-v0.1-tiny-random",
|
||||
"Zyphra/Zamba2-1.2B-instruct",
|
||||
"hmellor/tiny-random-BambaForCausalLM",
|
||||
"ibm-granite/granite-4.0-tiny-preview",
|
||||
"tiiuae/Falcon-H1-0.5B-Base",
|
||||
"LiquidAI/LFM2-1.2B",
|
||||
"tiny-random/qwen3-next-moe",
|
||||
]
|
||||
|
||||
FULL_CUDA_GRAPH_MODELS = [
|
||||
@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [
|
||||
"Zyphra/Zamba2-1.2B-instruct",
|
||||
]
|
||||
|
||||
V0_UNSUPPORTED_MODELS = [
|
||||
"LiquidAI/LFM2-1.2B",
|
||||
]
|
||||
|
||||
FP32_STATE_MODELS = [
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"Zyphra/Zamba2-1.2B-instruct",
|
||||
@ -88,20 +75,16 @@ def test_models(
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
if model in V1_SUPPORTED_MODELS:
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
vllm_v1_outputs = None
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
if model in V1_SUPPORTED_MODELS:
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_v1_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm-v1",
|
||||
)
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@ -299,14 +282,14 @@ def test_full_cuda_graph(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_v1_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm-v1",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@ -340,12 +323,12 @@ def test_fp32_cache_state(
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
**{cache_dtype_param: "float32"}) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_v1_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm-v1",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
@ -209,7 +209,6 @@ def batch_make_video_embeddings(
|
||||
return visual(pixel_values_on_device,
|
||||
grid_thw=video_grid_thw_on_device).cpu()
|
||||
|
||||
# V1 Test: this calls a V0 internal.
|
||||
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
|
||||
|
||||
# split into original batches
|
||||
|
||||
@ -312,14 +312,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
|
||||
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
|
||||
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
|
||||
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
|
||||
trust_remote_code=True,
|
||||
v0_only=True,
|
||||
max_model_len=10240),
|
||||
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
|
||||
trust_remote_code=True),
|
||||
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
|
||||
trust_remote_code=True),
|
||||
max_transformers_version="4.55.4",
|
||||
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
|
||||
max_transformers_version="4.53",
|
||||
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
|
||||
@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
|
||||
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
|
||||
"Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
min_transformers_version="4.56.2"),
|
||||
extras={"tiny-random": "tiny-random/qwen3-next-moe"}, # noqa: E501
|
||||
min_transformers_version="4.56.3"),
|
||||
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
|
||||
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
|
||||
trust_remote_code=True,
|
||||
@ -448,6 +447,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
max_transformers_version="4.48", # noqa: E501
|
||||
transformers_version_reason="HF model is not compatible.", # noqa: E501
|
||||
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
|
||||
"DotsOCRForCausalLM": _HfExamplesInfo("rednote-hilab/dots.ocr",
|
||||
trust_remote_code=True),
|
||||
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
|
||||
"Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
@ -560,10 +561,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
|
||||
"Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501
|
||||
max_model_len=4096,
|
||||
min_transformers_version="4.57"), # noqa: E501
|
||||
min_transformers_version="4.57",
|
||||
is_available_online=False),
|
||||
"Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501
|
||||
max_model_len=4096,
|
||||
min_transformers_version="4.57"),
|
||||
max_model_len=4096,
|
||||
min_transformers_version="4.57",
|
||||
is_available_online=False),
|
||||
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
|
||||
trust_remote_code=True),
|
||||
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
|
||||
@ -640,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
|
||||
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
min_transformers_version="4.56.2"),
|
||||
min_transformers_version="4.56.3"),
|
||||
}
|
||||
|
||||
_TRANSFORMERS_BACKEND_MODELS = {
|
||||
|
||||
@ -1,10 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from vllm.model_executor.models.vision import resolve_visual_encoder_outputs
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.model_executor.models.vision import (
|
||||
get_load_balance_assignment, resolve_visual_encoder_outputs,
|
||||
run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_open_port, update_environment_variables
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -33,3 +43,415 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers,
|
||||
post_layer_norm=None,
|
||||
max_possible_layers=max_possible_layers)
|
||||
assert torch.equal(torch.tensor(expected_features), output_tensor)
|
||||
|
||||
|
||||
class SimpleLinearModel(torch.nn.Module):
|
||||
"""A simple linear vision model for testing."""
|
||||
|
||||
def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
|
||||
super().__init__()
|
||||
self.flatten = torch.nn.Flatten()
|
||||
self.linear = torch.nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Flatten the input and apply linear transformation
|
||||
x = self.flatten(x)
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size",
|
||||
[
|
||||
1, # Single image
|
||||
4, # Small batch
|
||||
5, # Odd batch size (for testing padding)
|
||||
],
|
||||
)
|
||||
def test_run_dp_sharded_vision_model(batch_size: int):
|
||||
world_size = 2
|
||||
# Launch processes
|
||||
mp.spawn(
|
||||
run_dp_sharded_vision_model_vs_direct,
|
||||
args=(
|
||||
world_size,
|
||||
batch_size,
|
||||
get_open_port(),
|
||||
),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
|
||||
batch_size: int, master_port: int):
|
||||
"""
|
||||
Test that run_dp_sharded_vision_model produces the same results as
|
||||
calling the model directly.
|
||||
"""
|
||||
|
||||
# Set random seed for reproducibility
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create a test input tensor
|
||||
image_input = torch.randn(batch_size, 3, 224, 224)
|
||||
|
||||
# Create a simple linear model
|
||||
vision_model = SimpleLinearModel()
|
||||
|
||||
# Run the model directly on the full input
|
||||
with torch.inference_mode():
|
||||
direct_output = vision_model(image_input)
|
||||
|
||||
# Run the model through the sharded function
|
||||
with torch.inference_mode():
|
||||
sharded_output = run_dp_sharded_vision_model(image_input, vision_model)
|
||||
|
||||
# Check that the world size is set up correctly
|
||||
assert get_tensor_model_parallel_world_size() == world_size
|
||||
|
||||
# Check that the outputs have the same shape
|
||||
assert direct_output.shape == sharded_output.shape
|
||||
|
||||
# Check that the outputs are close (they should be identical)
|
||||
assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
|
||||
"expected_grouped_sizes_per_gpu,test_description",
|
||||
[
|
||||
# Empty input
|
||||
([], 2, [], [0, 0], [0, 0], "empty input"),
|
||||
|
||||
# Fewer samples than GPUs
|
||||
([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0
|
||||
], "fewer samples than GPUs"),
|
||||
|
||||
# Single GPU
|
||||
([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
|
||||
|
||||
# Balanced assignment
|
||||
([100, 100, 100, 100
|
||||
], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"),
|
||||
|
||||
# Unbalanced sizes - this one is trickier since the algorithm is greedy
|
||||
([1000, 100, 200, 50], 2, [0, 2, 1, 3
|
||||
], [1, 3], [1000, 350], "unbalanced sizes"),
|
||||
],
|
||||
)
|
||||
def test_get_load_balance_assignment_cases(sizes, num_gpus,
|
||||
expected_shuffle_indices,
|
||||
expected_gpu_sample_counts,
|
||||
expected_grouped_sizes_per_gpu,
|
||||
test_description):
|
||||
"""Test get_load_balance_assignment with various input cases."""
|
||||
result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
|
||||
(shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
|
||||
|
||||
# Common assertions for all cases
|
||||
assert len(shuffle_indices) == len(sizes)
|
||||
assert len(gpu_sample_counts) == num_gpus
|
||||
assert len(grouped_sizes_per_gpu) == num_gpus
|
||||
assert sum(gpu_sample_counts) == len(sizes)
|
||||
|
||||
assert shuffle_indices == expected_shuffle_indices
|
||||
|
||||
assert gpu_sample_counts == expected_gpu_sample_counts
|
||||
assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
|
||||
|
||||
|
||||
class SimpleMRopeVisionModel(torch.nn.Module):
|
||||
"""A simple vision model for testing mrope functionality."""
|
||||
|
||||
def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
|
||||
super().__init__()
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.linear = torch.nn.Linear(768, out_hidden_size)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor,
|
||||
grid_thw_list: list[list[int]]):
|
||||
"""Simple forward pass that simulates spatial merging."""
|
||||
# Apply linear transformation
|
||||
embeddings = self.linear(pixel_values)
|
||||
|
||||
# Simulate spatial merging by reducing the number of patches
|
||||
merge_factor = self.spatial_merge_size * self.spatial_merge_size
|
||||
|
||||
# Group patches and merge spatially
|
||||
merged_embeddings = []
|
||||
start_idx = 0
|
||||
|
||||
for grid_thw in grid_thw_list:
|
||||
num_patches = math.prod(grid_thw)
|
||||
end_idx = start_idx + num_patches
|
||||
|
||||
# Get patches for this image
|
||||
image_patches = embeddings[start_idx:end_idx]
|
||||
|
||||
# Simulate spatial merging by averaging groups of patches
|
||||
merged_patches = num_patches // merge_factor
|
||||
if merged_patches > 0:
|
||||
# Reshape and average to simulate merging
|
||||
reshaped = image_patches[:merged_patches * merge_factor].view(
|
||||
merged_patches, merge_factor, -1)
|
||||
merged = reshaped.mean(dim=1)
|
||||
merged_embeddings.append(merged)
|
||||
|
||||
start_idx = end_idx
|
||||
|
||||
if merged_embeddings:
|
||||
return torch.cat(merged_embeddings, dim=0)
|
||||
else:
|
||||
return torch.empty((0, self.out_hidden_size),
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size",
|
||||
[
|
||||
1, # Single image
|
||||
3, # Small batch
|
||||
5, # Odd batch size (for testing padding)
|
||||
],
|
||||
)
|
||||
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
|
||||
world_size = 2
|
||||
# Launch processes
|
||||
mp.spawn(
|
||||
run_dp_sharded_mrope_vision_model_vs_direct,
|
||||
args=(
|
||||
world_size,
|
||||
batch_size,
|
||||
get_open_port(),
|
||||
),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
|
||||
world_size: int,
|
||||
batch_size: int,
|
||||
master_port: int):
|
||||
"""
|
||||
Test that run_dp_sharded_mrope_vision_model produces the same results as
|
||||
calling the model directly.
|
||||
"""
|
||||
# Set random seed for reproducibility
|
||||
current_platform.seed_everything(0)
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create test data
|
||||
grid_thw_list = []
|
||||
pixel_values_list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Varying image sizes for better testing
|
||||
t, h, w = 1, 4 + i, 4 + i
|
||||
grid_thw_list.append([t, h, w])
|
||||
|
||||
num_patches = t * h * w
|
||||
# Create random pixel values for this image
|
||||
image_pixels = torch.randn(num_patches, 768)
|
||||
pixel_values_list.append(image_pixels)
|
||||
|
||||
# Concatenate all pixel values
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
|
||||
# Create a simple mrope vision model
|
||||
vision_model = SimpleMRopeVisionModel()
|
||||
|
||||
# Run the model directly on the full input (only on rank 0)
|
||||
if local_rank == 0:
|
||||
with torch.inference_mode():
|
||||
direct_output = vision_model(pixel_values, grid_thw_list)
|
||||
|
||||
# Run the model through the sharded function
|
||||
with torch.inference_mode():
|
||||
sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
sharded_output = torch.cat(sharded_output, dim=0)
|
||||
|
||||
# Check that the world size is set up correctly
|
||||
assert get_tensor_model_parallel_world_size() == world_size
|
||||
|
||||
# Compare outputs (only on rank 0)
|
||||
if local_rank == 0:
|
||||
# Check that the outputs have the same shape
|
||||
assert direct_output.shape == sharded_output.shape
|
||||
# Check that the outputs are close (they should be identical)
|
||||
assert torch.allclose(direct_output,
|
||||
sharded_output,
|
||||
rtol=1e-5,
|
||||
atol=1e-5)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_run_dp_sharded_mrope_vision_model_empty_input():
|
||||
world_size = 2
|
||||
mp.spawn(
|
||||
run_dp_sharded_mrope_vision_model_empty_input_worker,
|
||||
args=(world_size, get_open_port()),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_empty_input_worker(
        local_rank: int, world_size: int, master_port: int):
    """Worker: run_dp_sharded_mrope_vision_model must tolerate zero images."""
    # Bind this process to its device and make it the default.
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    env = {
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    }
    update_environment_variables(env)

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Zero patches and no grids at all.
    pixel_values = torch.empty((0, 768))
    grid_thw_list: list[list[int]] = []
    model = SimpleMRopeVisionModel()

    # The sharded helper should return an empty result, not crash.
    with torch.inference_mode():
        result = run_dp_sharded_mrope_vision_model(model,
                                                   pixel_values,
                                                   grid_thw_list,
                                                   rope_type="rope_3d")

    # Nothing in, nothing out.
    assert len(result) == 0
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
def test_run_dp_sharded_mrope_vision_model_uneven_load():
    """Spawn four workers over a deliberately unbalanced image set."""
    num_procs = 4
    mp.spawn(
        run_dp_sharded_mrope_vision_model_uneven_load_worker,
        args=(num_procs, get_open_port()),
        nprocs=num_procs,
    )
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_uneven_load_worker(
        local_rank: int, world_size: int, master_port: int):
    """Worker: shard images of very different sizes across ranks."""
    # Fixed seed so all ranks construct identical inputs.
    current_platform.seed_everything(123)
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Deliberately unbalanced workload: 4, 64 and 9 patches respectively.
    grid_thw_list = [
        [1, 2, 2],  # small
        [1, 8, 8],  # large
        [1, 3, 3],  # medium
    ]
    pixel_values = torch.cat(
        [torch.randn(math.prod(grid), 768) for grid in grid_thw_list], dim=0)

    model = SimpleMRopeVisionModel()

    # The uneven split must not raise.
    with torch.inference_mode():
        outputs = run_dp_sharded_mrope_vision_model(model,
                                                    pixel_values,
                                                    grid_thw_list,
                                                    rope_type="rope_3d")

    # Each per-image output must have the merged patch count and width.
    merge_factor = model.spatial_merge_size**2
    for out, grid in zip(outputs, grid_thw_list):
        assert out.shape[0] == math.prod(grid) // merge_factor
        assert out.shape[1] == model.out_hidden_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spatial_merge_size", [2, 4])
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
    """The toy model must shrink the patch count by merge_size**2."""
    device = current_platform.device_type

    grid_thw_list = [[1, 4, 4], [1, 6, 6]]  # two images
    pixel_values = torch.cat([
        torch.randn(math.prod(grid), 768, device=device)
        for grid in grid_thw_list
    ],
                             dim=0)

    model = SimpleMRopeVisionModel(
        spatial_merge_size=spatial_merge_size).to(device)

    with torch.inference_mode():
        out = model(pixel_values, grid_thw_list)

    # Output rows = total patches divided by the spatial merge factor.
    total_patches = sum(math.prod(grid) for grid in grid_thw_list)
    assert out.shape[0] == total_patches // spatial_merge_size**2
    assert out.shape[1] == model.out_hidden_size
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
import math
|
||||
import mimetypes
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
@ -10,22 +9,11 @@ from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
|
||||
get_load_balance_assignment,
|
||||
run_dp_sharded_mrope_vision_model,
|
||||
run_dp_sharded_vision_model)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_open_port, update_environment_variables
|
||||
from vllm.multimodal.utils import MediaConnector, argsort_mm_positions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal.inputs import MultiModalPlaceholderDict
|
||||
@ -404,415 +392,3 @@ def test_argsort_mm_positions():
|
||||
modality_idxs = argsort_mm_positions(mm_positions)
|
||||
|
||||
assert modality_idxs == expected_modality_idxs
|
||||
|
||||
|
||||
class SimpleLinearModel(torch.nn.Module):
    """Minimal stand-in vision model: flatten the image, apply one Linear."""

    def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
        super().__init__()
        self.flatten = torch.nn.Flatten()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor):
        # Collapse everything after the batch dim, then project.
        return self.linear(self.flatten(x))
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,  # single image
        4,  # small batch
        5,  # odd batch size (exercises padding)
    ],
)
def test_run_dp_sharded_vision_model(batch_size: int):
    """Compare sharded execution against a direct forward pass."""
    num_procs = 2
    mp.spawn(
        run_dp_sharded_vision_model_vs_direct,
        args=(num_procs, batch_size, get_open_port()),
        nprocs=num_procs,
    )
|
||||
|
||||
|
||||
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
                                          batch_size: int, master_port: int):
    """Worker: sharded output must match a plain forward pass exactly."""
    # Deterministic inputs so every rank builds the same tensors.
    current_platform.seed_everything(0)

    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # Bring up the process group and TP topology.
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    image_input = torch.randn(batch_size, 3, 224, 224)
    model = SimpleLinearModel()

    with torch.inference_mode():
        reference = model(image_input)

    with torch.inference_mode():
        sharded = run_dp_sharded_vision_model(image_input, model)

    # Sanity-check the parallel setup itself.
    assert get_tensor_model_parallel_world_size() == world_size

    # Same shape, numerically identical values.
    assert reference.shape == sharded.shape
    assert torch.allclose(reference, sharded, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
    "expected_grouped_sizes_per_gpu,test_description",
    [
        # Empty input
        ([], 2, [], [0, 0], [0, 0], "empty input"),
        # Fewer samples than GPUs
        ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0],
         "fewer samples than GPUs"),
        # Single GPU
        ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
        # Balanced assignment
        ([100, 100, 100, 100], 2, [0, 2, 1, 3], [2, 2], [200, 200],
         "balanced assignment"),
        # Unbalanced sizes - the algorithm is greedy, so order matters
        ([1000, 100, 200, 50], 2, [0, 2, 1, 3], [1, 3], [1000, 350],
         "unbalanced sizes"),
    ],
)
def test_get_load_balance_assignment_cases(sizes, num_gpus,
                                           expected_shuffle_indices,
                                           expected_gpu_sample_counts,
                                           expected_grouped_sizes_per_gpu,
                                           test_description):
    """get_load_balance_assignment must produce the expected partitioning."""
    (shuffle_indices, gpu_sample_counts,
     grouped_sizes_per_gpu) = get_load_balance_assignment(sizes,
                                                          num_gpus=num_gpus)

    # Structural invariants that hold for every case.
    assert len(shuffle_indices) == len(sizes)
    assert len(gpu_sample_counts) == num_gpus
    assert len(grouped_sizes_per_gpu) == num_gpus
    assert sum(gpu_sample_counts) == len(sizes)

    # Exact expected assignment.
    assert shuffle_indices == expected_shuffle_indices
    assert gpu_sample_counts == expected_gpu_sample_counts
    assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
|
||||
|
||||
|
||||
class SimpleMRopeVisionModel(torch.nn.Module):
    """Toy vision model that mimics mrope-style spatial patch merging."""

    def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
        super().__init__()
        self.spatial_merge_size = spatial_merge_size
        self.out_hidden_size = out_hidden_size
        self.linear = torch.nn.Linear(768, out_hidden_size)

    def forward(self, pixel_values: torch.Tensor,
                grid_thw_list: list[list[int]]):
        """Project patches, then average groups of merge_size**2 patches."""
        embeddings = self.linear(pixel_values)
        merge_factor = self.spatial_merge_size * self.spatial_merge_size

        merged_chunks = []
        offset = 0
        for grid_thw in grid_thw_list:
            patch_count = math.prod(grid_thw)
            image_embeds = embeddings[offset:offset + patch_count]
            offset += patch_count

            # Average each group of merge_factor patches; a trailing
            # remainder that does not fill a full group is dropped.
            groups = patch_count // merge_factor
            if groups > 0:
                grouped = image_embeds[:groups * merge_factor].view(
                    groups, merge_factor, -1)
                merged_chunks.append(grouped.mean(dim=1))

        if not merged_chunks:
            return torch.empty((0, self.out_hidden_size),
                               device=pixel_values.device,
                               dtype=pixel_values.dtype)
        return torch.cat(merged_chunks, dim=0)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,  # single image
        3,  # small batch
        5,  # odd batch size (exercises padding)
    ],
)
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
    """Compare sharded mrope execution against a direct forward pass."""
    num_procs = 2
    mp.spawn(
        run_dp_sharded_mrope_vision_model_vs_direct,
        args=(num_procs, batch_size, get_open_port()),
        nprocs=num_procs,
    )
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
                                                world_size: int,
                                                batch_size: int,
                                                master_port: int):
    """Worker: sharded mrope output must match the direct forward pass."""
    # Deterministic inputs so every rank builds identical tensors.
    current_platform.seed_everything(0)
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # Bring up the process group and TP topology.
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Images of increasing size: (1, 4+i, 4+i) patches each.
    grid_thw_list = [[1, 4 + i, 4 + i] for i in range(batch_size)]
    pixel_values = torch.cat(
        [torch.randn(math.prod(grid), 768) for grid in grid_thw_list], dim=0)

    vision_model = SimpleMRopeVisionModel()

    # Reference output is computed only on rank 0.
    if local_rank == 0:
        with torch.inference_mode():
            direct_output = vision_model(pixel_values, grid_thw_list)

    # Every rank participates in the sharded run.
    with torch.inference_mode():
        sharded_chunks = run_dp_sharded_mrope_vision_model(
            vision_model, pixel_values, grid_thw_list, rope_type="rope_3d")
        sharded_output = torch.cat(sharded_chunks, dim=0)

    # Sanity-check the parallel setup itself.
    assert get_tensor_model_parallel_world_size() == world_size

    # Compare against the reference (only on rank 0).
    if local_rank == 0:
        assert direct_output.shape == sharded_output.shape
        assert torch.allclose(direct_output,
                              sharded_output,
                              rtol=1e-5,
                              atol=1e-5)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_run_dp_sharded_mrope_vision_model_empty_input():
|
||||
world_size = 2
|
||||
mp.spawn(
|
||||
run_dp_sharded_mrope_vision_model_empty_input_worker,
|
||||
args=(world_size, get_open_port()),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_empty_input_worker(
|
||||
local_rank: int, world_size: int, master_port: int):
|
||||
"""Test run_dp_sharded_mrope_vision_model with empty input."""
|
||||
# Set up distributed environment
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create empty inputs
|
||||
pixel_values = torch.empty((0, 768))
|
||||
grid_thw_list: list[list[int]] = []
|
||||
|
||||
vision_model = SimpleMRopeVisionModel()
|
||||
|
||||
# Should handle empty input gracefully
|
||||
with torch.inference_mode():
|
||||
output = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
|
||||
assert len(output) == 0
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_run_dp_sharded_mrope_vision_model_uneven_load():
|
||||
world_size = 4
|
||||
mp.spawn(
|
||||
run_dp_sharded_mrope_vision_model_uneven_load_worker,
|
||||
args=(world_size, get_open_port()),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_uneven_load_worker(
|
||||
local_rank: int, world_size: int, master_port: int):
|
||||
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
|
||||
# Set up distributed environment
|
||||
current_platform.seed_everything(123)
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create images with very different sizes
|
||||
grid_thw_list = [
|
||||
[1, 2, 2], # Small: 4 patches
|
||||
[1, 8, 8], # Large: 64 patches
|
||||
[1, 3, 3], # Medium: 9 patches
|
||||
]
|
||||
|
||||
pixel_values_list = []
|
||||
for grid_thw in grid_thw_list:
|
||||
num_patches = math.prod(grid_thw)
|
||||
image_pixels = torch.randn(num_patches, 768)
|
||||
pixel_values_list.append(image_pixels)
|
||||
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
vision_model = SimpleMRopeVisionModel()
|
||||
|
||||
# Should handle uneven distribution without errors
|
||||
with torch.inference_mode():
|
||||
output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
|
||||
# Verify output shape is reasonable
|
||||
merge_factor = vision_model.spatial_merge_size**2
|
||||
expected_output_patches = list(
|
||||
math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
|
||||
|
||||
for i, output in enumerate(output_tuple):
|
||||
assert output.shape[0] == expected_output_patches[i]
|
||||
assert output.shape[1] == vision_model.out_hidden_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spatial_merge_size", [2, 4])
|
||||
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
|
||||
"""Test SimpleMRopeVisionModel with different spatial merge sizes."""
|
||||
device = current_platform.device_type
|
||||
|
||||
grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images
|
||||
pixel_values_list = []
|
||||
|
||||
for grid_thw in grid_thw_list:
|
||||
num_patches = math.prod(grid_thw)
|
||||
image_pixels = torch.randn(num_patches, 768, device=device)
|
||||
pixel_values_list.append(image_pixels)
|
||||
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
vision_model = SimpleMRopeVisionModel(
|
||||
spatial_merge_size=spatial_merge_size).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
output = vision_model(pixel_values, grid_thw_list)
|
||||
|
||||
# Verify output dimensions based on spatial merging
|
||||
total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
|
||||
merge_factor = spatial_merge_size**2
|
||||
expected_output_patches = total_patches // merge_factor
|
||||
|
||||
assert output.shape[0] == expected_output_patches
|
||||
assert output.shape[1] == vision_model.out_hidden_size
|
||||
|
||||
216
tests/test_envs.py
Normal file
216
tests/test_envs.py
Normal file
@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.envs import env_list_with_choices, env_with_choices
|
||||
|
||||
|
||||
class TestEnvWithChoices:
    """Behavioral tests for env_with_choices."""

    def test_default_value_returned_when_env_not_set(self):
        """Unset variable -> the provided default comes back."""
        getter = env_with_choices("NONEXISTENT_ENV", "default",
                                  ["option1", "option2"])
        assert getter() == "default"

    def test_none_default_returned_when_env_not_set(self):
        """Unset variable with a None default -> None comes back."""
        getter = env_with_choices("NONEXISTENT_ENV", None,
                                  ["option1", "option2"])
        assert getter() is None

    def test_valid_value_returned_case_sensitive(self):
        """Exact-case match is accepted when case_sensitive=True."""
        with patch.dict(os.environ, {"TEST_ENV": "option1"}):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["option1", "option2"],
                                      case_sensitive=True)
            assert getter() == "option1"

    def test_valid_lowercase_value_returned_case_insensitive(self):
        """Lowercase value matches uppercase choices when insensitive."""
        with patch.dict(os.environ, {"TEST_ENV": "option1"}):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["OPTION1", "OPTION2"],
                                      case_sensitive=False)
            assert getter() == "option1"

    def test_valid_uppercase_value_returned_case_insensitive(self):
        """Uppercase value matches lowercase choices when insensitive."""
        with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["option1", "option2"],
                                      case_sensitive=False)
            assert getter() == "OPTION1"

    def test_invalid_value_raises_error_case_sensitive(self):
        """A value outside the choices raises ValueError (sensitive)."""
        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["option1", "option2"],
                                      case_sensitive=True)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' for TEST_ENV"):
                getter()

    def test_case_mismatch_raises_error_case_sensitive(self):
        """Wrong case counts as invalid when case_sensitive=True."""
        with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["option1", "option2"],
                                      case_sensitive=True)
            with pytest.raises(ValueError,
                               match="Invalid value 'OPTION1' for TEST_ENV"):
                getter()

    def test_invalid_value_raises_error_case_insensitive(self):
        """A value outside the choices raises ValueError (insensitive)."""
        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["option1", "option2"],
                                      case_sensitive=False)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' for TEST_ENV"):
                getter()

    def test_callable_choices_resolved_correctly(self):
        """Choices supplied as a callable are evaluated lazily."""

        def get_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "dynamic1"}):
            getter = env_with_choices("TEST_ENV", "default", get_choices)
            assert getter() == "dynamic1"

    def test_callable_choices_with_invalid_value(self):
        """Callable choices still reject values outside the list."""

        def get_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            getter = env_with_choices("TEST_ENV", "default", get_choices)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' for TEST_ENV"):
                getter()
|
||||
|
||||
|
||||
class TestEnvListWithChoices:
    """Behavioral tests for env_list_with_choices."""

    def test_default_list_returned_when_env_not_set(self):
        """Unset variable -> the provided default list comes back."""
        getter = env_list_with_choices("NONEXISTENT_ENV",
                                       ["default1", "default2"],
                                       ["option1", "option2"])
        assert getter() == ["default1", "default2"]

    def test_empty_default_list_returned_when_env_not_set(self):
        """Unset variable with an empty default -> empty list."""
        getter = env_list_with_choices("NONEXISTENT_ENV", [],
                                       ["option1", "option2"])
        assert getter() == []

    def test_single_valid_value_parsed_correctly(self):
        """A lone valid entry parses to a one-element list."""
        with patch.dict(os.environ, {"TEST_ENV": "option1"}):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            assert getter() == ["option1"]

    def test_multiple_valid_values_parsed_correctly(self):
        """Comma-separated valid entries parse in order."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,option2"}):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            assert getter() == ["option1", "option2"]

    def test_values_with_whitespace_trimmed(self):
        """Surrounding whitespace around each entry is stripped."""
        with patch.dict(os.environ, {"TEST_ENV": " option1 , option2 "}):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            assert getter() == ["option1", "option2"]

    def test_empty_values_filtered_out(self):
        """Empty entries between commas are dropped."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,,option2,"}):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            assert getter() == ["option1", "option2"]

    def test_empty_string_returns_default(self):
        """An empty string falls back to the default."""
        with patch.dict(os.environ, {"TEST_ENV": ""}):
            getter = env_list_with_choices("TEST_ENV", ["default"],
                                           ["option1", "option2"])
            assert getter() == ["default"]

    def test_only_commas_returns_default(self):
        """A string of only separators also falls back to the default."""
        with patch.dict(os.environ, {"TEST_ENV": ",,,"}):
            getter = env_list_with_choices("TEST_ENV", ["default"],
                                           ["option1", "option2"])
            assert getter() == ["default"]

    def test_case_sensitive_validation(self):
        """Case mismatches are rejected when case_sensitive=True."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,OPTION2"}):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"],
                                           case_sensitive=True)
            with pytest.raises(ValueError,
                               match="Invalid value 'OPTION2' in TEST_ENV"):
                getter()

    def test_case_insensitive_validation(self):
        """Case mismatches are accepted (and preserved) when insensitive."""
        with patch.dict(os.environ, {"TEST_ENV": "OPTION1,option2"}):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"],
                                           case_sensitive=False)
            assert getter() == ["OPTION1", "option2"]

    def test_invalid_value_in_list_raises_error(self):
        """One bad entry in the list fails the whole parse."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,invalid,option2"}):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' in TEST_ENV"):
                getter()

    def test_callable_choices_resolved_correctly(self):
        """Choices supplied as a callable are evaluated lazily."""

        def get_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "dynamic1,dynamic2"}):
            getter = env_list_with_choices("TEST_ENV", [], get_choices)
            assert getter() == ["dynamic1", "dynamic2"]

    def test_callable_choices_with_invalid_value(self):
        """Callable choices still reject values outside the list."""

        def get_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "dynamic1,invalid"}):
            getter = env_list_with_choices("TEST_ENV", [], get_choices)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' in TEST_ENV"):
                getter()

    def test_duplicate_values_preserved(self):
        """Duplicates are kept; the parser does not deduplicate."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            assert getter() == ["option1", "option1", "option2"]
|
||||
@ -13,6 +13,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
|
||||
Qwen3CoderToolParser)
|
||||
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import (
|
||||
Qwen3XMLToolParser)
|
||||
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
|
||||
@ -29,6 +31,21 @@ def qwen3_tool_parser(qwen3_tokenizer):
|
||||
return Qwen3CoderToolParser(qwen3_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
def qwen3_xml_tool_parser(qwen3_tokenizer):
    """XML-variant tool parser built on the shared Qwen3 tokenizer."""
    return Qwen3XMLToolParser(qwen3_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture(params=["original", "xml"])
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser,
                                   request):
    """Parameterized fixture yielding each parser variant in turn."""
    variants = {
        "original": qwen3_tool_parser,
        "xml": qwen3_xml_tool_parser,
    }
    return variants[request.param]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools():
|
||||
return [
|
||||
@ -95,7 +112,7 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
qwen3_tool_parser: Qwen3CoderToolParser,
|
||||
qwen3_tool_parser,
|
||||
qwen3_tokenizer: AnyTokenizer,
|
||||
model_output: str,
|
||||
request: Optional[ChatCompletionRequest] = None
|
||||
@ -144,9 +161,9 @@ def stream_delta_message_generator(
|
||||
read_offset = new_read_offset
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(qwen3_tool_parser):
|
||||
def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized):
|
||||
model_output = "This is a test response without any tool calls"
|
||||
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
@ -294,12 +311,13 @@ circle
|
||||
], "Let me calculate that area for you."),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
def test_extract_tool_calls(qwen3_tool_parser_parametrized, sample_tools,
|
||||
model_output, expected_tool_calls,
|
||||
expected_content):
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request)
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
@ -308,7 +326,8 @@ def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output,
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser, sample_tools):
|
||||
def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser_parametrized,
|
||||
sample_tools):
|
||||
"""Test fallback parsing when XML tags are missing"""
|
||||
model_output = '''<function=get_current_weather>
|
||||
<parameter=city>
|
||||
@ -322,7 +341,7 @@ TX
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request)
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
@ -331,7 +350,7 @@ TX
|
||||
"get_current_weather")
|
||||
|
||||
|
||||
def test_extract_tool_calls_type_conversion(qwen3_tool_parser):
|
||||
def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized):
|
||||
"""Test parameter type conversion based on tool schema"""
|
||||
tools = [
|
||||
ChatCompletionToolsParam(type="function",
|
||||
@ -381,7 +400,7 @@ hello world
|
||||
</tool_call>'''
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
|
||||
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request)
|
||||
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
@ -536,9 +555,10 @@ circle
|
||||
], "Let me calculate that area for you."),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
|
||||
sample_tools, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized,
|
||||
qwen3_tokenizer, sample_tools,
|
||||
model_output, expected_tool_calls,
|
||||
expected_content):
|
||||
"""Test incremental streaming behavior including typed parameters"""
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
@ -548,7 +568,8 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
|
||||
tool_states = {} # Track state per tool index
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
qwen3_tool_parser, qwen3_tokenizer, model_output, request):
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
|
||||
request):
|
||||
# role should never be streamed from tool parser
|
||||
assert not delta_message.role
|
||||
|
||||
@ -609,7 +630,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
|
||||
|
||||
|
||||
def test_extract_tool_calls_missing_closing_parameter_tag(
|
||||
qwen3_tool_parser, sample_tools):
|
||||
qwen3_tool_parser_parametrized, sample_tools):
|
||||
"""Test handling of missing closing </parameter> tag"""
|
||||
# Using get_current_weather from sample_tools but with malformed XML
|
||||
model_output = '''Let me check the weather for you:
|
||||
@ -629,7 +650,7 @@ fahrenheit
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request)
|
||||
|
||||
# The parser should handle the malformed XML gracefully
|
||||
@ -652,7 +673,7 @@ fahrenheit
|
||||
|
||||
|
||||
def test_extract_tool_calls_streaming_missing_closing_tag(
|
||||
qwen3_tool_parser, qwen3_tokenizer, sample_tools):
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools):
|
||||
"""Test streaming with missing closing </parameter> tag"""
|
||||
# Using get_current_weather from sample_tools but with malformed XML
|
||||
model_output = '''Let me check the weather for you:
|
||||
@ -677,7 +698,8 @@ fahrenheit
|
||||
tool_states = {}
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
qwen3_tool_parser, qwen3_tokenizer, model_output, request):
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
|
||||
request):
|
||||
|
||||
if delta_message.content:
|
||||
other_content += delta_message.content
|
||||
@ -727,9 +749,8 @@ fahrenheit
|
||||
assert args["unit"] == "fahrenheit"
|
||||
|
||||
|
||||
def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser,
|
||||
qwen3_tokenizer,
|
||||
sample_tools):
|
||||
def test_extract_tool_calls_streaming_incremental(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools):
|
||||
"""Test that streaming is truly incremental"""
|
||||
model_output = '''I'll check the weather.<tool_call>
|
||||
<function=get_current_weather>
|
||||
@ -748,7 +769,8 @@ TX
|
||||
|
||||
chunks = []
|
||||
for delta_message in stream_delta_message_generator(
|
||||
qwen3_tool_parser, qwen3_tokenizer, model_output, request):
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
|
||||
request):
|
||||
chunks.append(delta_message)
|
||||
|
||||
# Should have multiple chunks
|
||||
@ -784,3 +806,49 @@ TX
|
||||
parsed_args = json.loads(full_args)
|
||||
assert parsed_args["city"] == "Dallas"
|
||||
assert parsed_args["state"] == "TX"
|
||||
|
||||
|
||||
def test_extract_tool_calls_complex_type_with_single_quote(
|
||||
qwen3_tool_parser_parametrized):
|
||||
"""Test parameter type conversion based on tool schema"""
|
||||
tools = [
|
||||
ChatCompletionToolsParam(type="function",
|
||||
function={
|
||||
"name": "test_types",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_param": {
|
||||
"type": "integer"
|
||||
},
|
||||
"float_param": {
|
||||
"type": "float"
|
||||
},
|
||||
"bool_param": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"str_param": {
|
||||
"type": "string"
|
||||
},
|
||||
"obj_param": {
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
]
|
||||
|
||||
model_output = '''<tool_call>
|
||||
<function=test_types>
|
||||
<parameter=obj_param>
|
||||
{'key': 'value'}
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>'''
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request)
|
||||
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
assert args["obj_param"] == {"key": "value"}
|
||||
|
||||
@ -6,6 +6,7 @@ Run `pytest tests/kernels/moe/test_moe_pallas.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
import torch_xla
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -77,7 +78,7 @@ def test_pallas_moe(
|
||||
expert_map=e_map,
|
||||
renormalize=False,
|
||||
)
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
# Compare outputs
|
||||
torch.testing.assert_close(
|
||||
|
||||
@ -47,7 +47,10 @@ backend_configs = {
|
||||
# FA3 on Hopper
|
||||
"FA3":
|
||||
BackendConfig(name="FA3",
|
||||
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
|
||||
env_vars={
|
||||
"VLLM_FLASH_ATTN_VERSION": "3",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
},
|
||||
@ -67,6 +70,7 @@ backend_configs = {
|
||||
BackendConfig(name="FlashAttentionMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
@ -75,7 +79,10 @@ backend_configs = {
|
||||
# FA2
|
||||
"FA2":
|
||||
BackendConfig(name="FA2",
|
||||
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
|
||||
env_vars={
|
||||
"VLLM_FLASH_ATTN_VERSION": "2",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
|
||||
@ -85,7 +85,10 @@ run_tests_for_model() {
|
||||
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
|
||||
|
||||
# Build the command with or without model-specific args
|
||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
|
||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||
vllm serve $model_name \
|
||||
--port $PORT \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.2 \
|
||||
@ -117,7 +120,10 @@ run_tests_for_model() {
|
||||
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
|
||||
|
||||
# Build the command with or without model-specific args
|
||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
|
||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||
vllm serve $model_name \
|
||||
--port $PORT \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.2 \
|
||||
|
||||
59
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
Normal file
59
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
Normal file
@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501
|
||||
SharedStorageConnectorMetadata)
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
ensure_kv_transfer_initialized, get_kv_transfer_group)
|
||||
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorModelRunnerMixin)
|
||||
|
||||
# Importing utils registers TestSharedStorageConnector with the factory
|
||||
from .utils import create_vllm_config
|
||||
|
||||
|
||||
def _make_empty_scheduler_output():
|
||||
return SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
kv_connector_metadata=SharedStorageConnectorMetadata(),
|
||||
)
|
||||
|
||||
|
||||
def test_kv_connector_mixin_clears_metadata():
|
||||
vllm_config = create_vllm_config()
|
||||
vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector"
|
||||
vllm_config.kv_transfer_config.kv_role = "kv_both"
|
||||
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = ("unit")
|
||||
|
||||
# Initialize the global connector instance
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
try:
|
||||
# Minimal scheduler output with empty metadata; mixin should still
|
||||
# bind/clear metadata even if no loads happen
|
||||
scheduler_output = _make_empty_scheduler_output()
|
||||
|
||||
# Invoke the no-forward path which uses the mixin context manager
|
||||
KVConnectorModelRunnerMixin.kv_connector_no_forward(
|
||||
scheduler_output, vllm_config)
|
||||
|
||||
# Verify clear_connector_metadata was called on the connector
|
||||
connector = get_kv_transfer_group()
|
||||
assert connector._connector_metadata is None
|
||||
# Test connector wrapper records method calls
|
||||
assert connector.call_record.get("bind_connector_metadata", 0) == 1
|
||||
assert connector.call_record.get("clear_connector_metadata", 0) == 1
|
||||
finally:
|
||||
# Ensure we clean up the global connector between tests
|
||||
KVConnectorModelRunnerMixin.ensure_kv_transfer_shutdown()
|
||||
@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
|
||||
NixlConnectorWorker, NixlKVConnectorStats)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.platforms.interface import Platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
@ -56,7 +57,10 @@ class FakeNixlWrapper:
|
||||
def get_reg_descs(self, caches_data, memory_type: str) -> list:
|
||||
return [str(uuid.uuid4()) for _ in caches_data]
|
||||
|
||||
def register_memory(self, descs) -> None:
|
||||
def register_memory(self, descs, backends) -> None:
|
||||
pass
|
||||
|
||||
def deregister_memory(self, descs) -> None:
|
||||
pass
|
||||
|
||||
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
|
||||
@ -85,6 +89,12 @@ class FakeNixlWrapper:
|
||||
def release_xfer_handle(self, handle: int) -> None:
|
||||
pass
|
||||
|
||||
def release_dlist_handle(self, handle: int) -> None:
|
||||
pass
|
||||
|
||||
def remove_remote_agent(self, agent: str) -> None:
|
||||
pass
|
||||
|
||||
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
|
||||
pass
|
||||
|
||||
@ -855,3 +865,95 @@ def test_register_kv_caches(dist_init):
|
||||
assert block_len == expected_block_len, \
|
||||
f"Block entry {i}: Expected block len {expected_block_len}, " \
|
||||
f"got {block_len}"
|
||||
|
||||
|
||||
class FakePlatform(Platform):
|
||||
device_type: str = "oot"
|
||||
|
||||
@classmethod
|
||||
def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
|
||||
"""
|
||||
Returns a mapping from device_type to a tuple of supported
|
||||
kv_buffer_device for nixl.
|
||||
"""
|
||||
return {'oot': ('oot', )}
|
||||
|
||||
@classmethod
|
||||
def get_nixl_memory_type(cls) -> Optional[str]:
|
||||
"""
|
||||
Returns the nixl memory type for the current platform.
|
||||
"""
|
||||
return 'VRAM'
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [
|
||||
("oot", "VRAM"),
|
||||
])
|
||||
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
|
||||
nixl_memory_type):
|
||||
"""
|
||||
Test that register_kv_caches() passes the correct memory types from the
|
||||
config to the nixl_wrapper.
|
||||
"""
|
||||
vllm_config = create_vllm_config()
|
||||
# Override the default memory types in the config
|
||||
vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
_NIXL_SUPPORTED_DEVICE)
|
||||
_NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices())
|
||||
|
||||
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE): # noqa: E501
|
||||
|
||||
# Create connector and replace its worker with a fake one for isolation
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
|
||||
# Verify get_reg_descs was called with the correct memory_type
|
||||
assert connector.connector_worker.kv_buffer_device == kv_buffer_device
|
||||
assert connector.connector_worker.nixl_memory_type == nixl_memory_type
|
||||
|
||||
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper)
|
||||
def test_shutdown_cleans_up_resources(dist_init):
|
||||
"""Test that shutdown() properly cleans up all resources."""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
worker = NixlConnectorWorker(vllm_config,
|
||||
vllm_config.kv_transfer_config.engine_id)
|
||||
nixl_wrapper = worker.nixl_wrapper
|
||||
|
||||
with patch.object(worker, '_handshake_initiation_executor') as mock_exec, \
|
||||
patch.object(worker, '_nixl_handshake_listener_t') as mock_listener, \
|
||||
patch.object(nixl_wrapper, 'release_xfer_handle') as mock_rel_xfer, \
|
||||
patch.object(nixl_wrapper, 'release_dlist_handle') as mock_rel_dlist, \
|
||||
patch.object(nixl_wrapper, 'remove_remote_agent') as mock_rem_agent, \
|
||||
patch.object(nixl_wrapper, 'deregister_memory') as mock_dereg:
|
||||
|
||||
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
|
||||
worker.src_xfer_side_handle = 456
|
||||
worker.dst_xfer_side_handles = {"engine1": 789}
|
||||
worker._remote_agents = {"engine1": {0: "agent1"}}
|
||||
worker._registered_descs = ["desc1", "desc2"]
|
||||
|
||||
worker.shutdown()
|
||||
|
||||
# Test idempotency
|
||||
worker.shutdown()
|
||||
worker.shutdown()
|
||||
|
||||
mock_exec.shutdown.assert_called_with(wait=False)
|
||||
mock_listener.join.assert_called_once_with(timeout=0)
|
||||
|
||||
mock_rel_xfer.assert_called_once_with(123)
|
||||
assert mock_rel_dlist.call_count == 2
|
||||
mock_rel_dlist.assert_any_call(456) # src handle
|
||||
mock_rel_dlist.assert_any_call(789) # dst handle
|
||||
mock_rem_agent.assert_called_once_with("agent1")
|
||||
assert mock_dereg.call_count == 2
|
||||
mock_dereg.assert_any_call("desc1")
|
||||
mock_dereg.assert_any_call("desc2")
|
||||
|
||||
62
tests/v1/kv_offload/test_cpu_offloading.py
Normal file
62
tests/v1/kv_offload/test_cpu_offloading.py
Normal file
@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
CPU_BLOCK_SIZES = [16, 48]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
|
||||
def test_cpu_offloading(cpu_block_size: int) -> None:
|
||||
"""
|
||||
Tests OffloadingConnector with CPUOffloadingSpec.
|
||||
"""
|
||||
|
||||
# configure OffloadingConnector (spec_name=CPUOffloadingSpec by default)
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="OffloadingConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={
|
||||
"num_cpu_blocks": 100,
|
||||
"block_size": cpu_block_size
|
||||
},
|
||||
)
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
gpu_memory_utilization=0.5,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
)
|
||||
|
||||
prompts = ["Hi " * 100]
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=20)
|
||||
|
||||
# run generation - this should trigger saving KV cache
|
||||
start_time = time.time()
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
cold_time = time.time() - start_time
|
||||
|
||||
# run generation again - should hit the GPU prefix cache
|
||||
start_time = time.time()
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
gpu_hit_time = time.time() - start_time
|
||||
|
||||
# reset prefix cache to avoid GPU hit.
|
||||
llm.reset_prefix_cache()
|
||||
|
||||
# sleep for a sec to make sure CPU finished storing
|
||||
time.sleep(1)
|
||||
|
||||
# run generation again - this should trigger loading from CPU
|
||||
start_time = time.time()
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
cpu_hit_time = time.time() - start_time
|
||||
|
||||
print("Generation times:")
|
||||
print(f" Cold: {cold_time * 1000:.2f}ms")
|
||||
print(f" GPU hit: {gpu_hit_time * 1000:.2f}ms")
|
||||
print(f" CPU hit: {cpu_hit_time * 1000:.2f}ms")
|
||||
@ -13,7 +13,6 @@ from vllm import SamplingParams
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.engine.core_client import DPAsyncMPClient
|
||||
@ -29,10 +28,6 @@ engine_args = AsyncEngineArgs(
|
||||
data_parallel_size=DP_SIZE,
|
||||
)
|
||||
|
||||
if not current_platform.supports_v1(engine_args.create_model_config()):
|
||||
pytest.skip(reason="Requires V1-supporting platform.",
|
||||
allow_module_level=True)
|
||||
|
||||
|
||||
async def generate(
|
||||
engine: AsyncLLM,
|
||||
|
||||
@ -4,6 +4,7 @@ import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_xla
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
@ -63,7 +64,7 @@ def test_topp_result_sums_past_p():
|
||||
probs.masked_fill_(logits_masked.isinf(), 0)
|
||||
masked_prob_sum = probs.sum(dim=-1)
|
||||
|
||||
xm.mark_step()
|
||||
torch_xla.sync()
|
||||
|
||||
# Perform assertion on CPU.
|
||||
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
|
||||
@ -82,7 +83,7 @@ def test_topp_basic():
|
||||
k=torch.tensor([3, 3]),
|
||||
p=torch.tensor([0.79, 0.79]))
|
||||
|
||||
xm.mark_step()
|
||||
torch_xla.sync()
|
||||
|
||||
# Expect the smallest elements to be dropped.
|
||||
expected_result = logits.clone().cpu()
|
||||
@ -104,7 +105,7 @@ def test_topp_select_all():
|
||||
k=torch.tensor([3, 3]),
|
||||
p=torch.tensor([1.0, 1.0]))
|
||||
|
||||
xm.mark_step()
|
||||
torch_xla.sync()
|
||||
|
||||
assert torch.allclose(logits.cpu(), result.cpu())
|
||||
|
||||
@ -122,7 +123,7 @@ def test_topp_with_ties():
|
||||
k=torch.tensor([4]),
|
||||
p=torch.tensor([0.2]))
|
||||
|
||||
xm.mark_step()
|
||||
torch_xla.sync()
|
||||
|
||||
# All tie values are included in the top-p set. Tie breaking is left
|
||||
# to be done during final sampling (all tie tokens have equal
|
||||
@ -146,7 +147,7 @@ def test_both_topk_topp():
|
||||
k=torch.tensor([1, 3]),
|
||||
p=torch.tensor([0.79, 0.79]))
|
||||
|
||||
xm.mark_step()
|
||||
torch_xla.sync()
|
||||
|
||||
# Since for the first batch k=1, expect only the largest element gets
|
||||
# selected.
|
||||
|
||||
@ -1,35 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
CI=${1:-0}
|
||||
PYTHON_VERSION=${2:-local}
|
||||
|
||||
if [ "$CI" -eq 1 ]; then
|
||||
set -e
|
||||
fi
|
||||
|
||||
if [ $PYTHON_VERSION == "local" ]; then
|
||||
PYTHON_VERSION=$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
||||
fi
|
||||
|
||||
run_mypy() {
|
||||
echo "Running mypy on $1"
|
||||
if [ "$CI" -eq 1 ] && [ -z "$1" ]; then
|
||||
mypy --python-version "${PYTHON_VERSION}" "$@"
|
||||
return
|
||||
fi
|
||||
mypy --follow-imports skip --python-version "${PYTHON_VERSION}" "$@"
|
||||
}
|
||||
|
||||
run_mypy # Note that this is less strict than CI
|
||||
run_mypy tests
|
||||
run_mypy vllm/attention
|
||||
run_mypy vllm/compilation
|
||||
run_mypy vllm/distributed
|
||||
run_mypy vllm/engine
|
||||
run_mypy vllm/executor
|
||||
run_mypy vllm/inputs
|
||||
run_mypy vllm/lora
|
||||
run_mypy --exclude 'vllm/model_executor/layers/fla/ops' vllm/model_executor
|
||||
run_mypy vllm/plugins
|
||||
run_mypy vllm/worker
|
||||
run_mypy vllm/v1
|
||||
@ -1,20 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
import sys
|
||||
|
||||
import regex as re
|
||||
|
||||
try:
|
||||
import pathspec
|
||||
except ImportError:
|
||||
print(
|
||||
"ERROR: The 'pathspec' library is required. "
|
||||
"Install it with 'pip install pathspec'.",
|
||||
file=sys.stderr)
|
||||
sys.exit(2)
|
||||
|
||||
# List of files (relative to repo root) that are allowed to import pickle or
|
||||
# cloudpickle
|
||||
#
|
||||
@ -25,7 +15,7 @@ except ImportError:
|
||||
# Before adding new uses of pickle/cloudpickle, please consider safer
|
||||
# alternatives like msgpack or pydantic that are already in use in vLLM. Only
|
||||
# add to this list if absolutely necessary and after careful security review.
|
||||
ALLOWED_FILES = set([
|
||||
ALLOWED_FILES = {
|
||||
# pickle
|
||||
'vllm/v1/serial_utils.py',
|
||||
'vllm/v1/executor/multiproc_executor.py',
|
||||
@ -36,11 +26,9 @@ ALLOWED_FILES = set([
|
||||
'tests/tokenization/test_cached_tokenizer.py',
|
||||
'vllm/distributed/utils.py',
|
||||
'vllm/distributed/parallel_state.py',
|
||||
'vllm/engine/multiprocessing/client.py',
|
||||
'vllm/distributed/device_communicators/all_reduce_utils.py',
|
||||
'vllm/distributed/device_communicators/shm_broadcast.py',
|
||||
'vllm/distributed/device_communicators/shm_object_storage.py',
|
||||
'vllm/engine/multiprocessing/engine.py',
|
||||
'benchmarks/kernels/graph_machete_bench.py',
|
||||
'benchmarks/kernels/benchmark_lora.py',
|
||||
'benchmarks/kernels/benchmark_machete.py',
|
||||
@ -55,65 +43,30 @@ ALLOWED_FILES = set([
|
||||
'tests/utils.py',
|
||||
# pickle and cloudpickle
|
||||
'vllm/utils/__init__.py',
|
||||
'vllm/v1/serial_utils.py',
|
||||
'vllm/v1/executor/multiproc_executor.py',
|
||||
'vllm/transformers_utils/config.py',
|
||||
'vllm/model_executor/models/registry.py',
|
||||
'vllm/engine/multiprocessing/client.py',
|
||||
'vllm/engine/multiprocessing/engine.py',
|
||||
])
|
||||
}
|
||||
|
||||
PICKLE_RE = re.compile(r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
|
||||
r"|from\s+(pickle|cloudpickle)\s+import\b)")
|
||||
|
||||
|
||||
def is_python_file(path):
|
||||
return path.endswith('.py')
|
||||
|
||||
|
||||
def scan_file(path):
|
||||
def scan_file(path: str) -> int:
|
||||
with open(path, encoding='utf-8') as f:
|
||||
for line in f:
|
||||
for i, line in enumerate(f, 1):
|
||||
if PICKLE_RE.match(line):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load_gitignore(repo_root):
|
||||
gitignore_path = os.path.join(repo_root, '.gitignore')
|
||||
patterns = []
|
||||
if os.path.exists(gitignore_path):
|
||||
with open(gitignore_path, encoding='utf-8') as f:
|
||||
patterns = f.read().splitlines()
|
||||
# Always ignore .git directory
|
||||
patterns.append('.git/')
|
||||
return pathspec.PathSpec.from_lines('gitwildmatch', patterns)
|
||||
print(f"{path}:{i}: "
|
||||
"\033[91merror:\033[0m " # red color
|
||||
"Found pickle/cloudpickle import")
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def main():
|
||||
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
spec = load_gitignore(repo_root)
|
||||
bad_files = []
|
||||
for dirpath, _, filenames in os.walk(repo_root):
|
||||
for filename in filenames:
|
||||
if not is_python_file(filename):
|
||||
continue
|
||||
abs_path = os.path.join(dirpath, filename)
|
||||
rel_path = os.path.relpath(abs_path, repo_root)
|
||||
# Skip ignored files
|
||||
if spec.match_file(rel_path):
|
||||
continue
|
||||
if scan_file(abs_path) and rel_path not in ALLOWED_FILES:
|
||||
bad_files.append(rel_path)
|
||||
if bad_files:
|
||||
print("\nERROR: The following files import 'pickle' or 'cloudpickle' "
|
||||
"but are not in the allowed list:")
|
||||
for f in bad_files:
|
||||
print(f" {f}")
|
||||
print("\nIf this is intentional, update the allowed list in "
|
||||
"tools/check_pickle_imports.py.")
|
||||
sys.exit(1)
|
||||
sys.exit(0)
|
||||
returncode = 0
|
||||
for filename in sys.argv[1:]:
|
||||
if filename in ALLOWED_FILES:
|
||||
continue
|
||||
returncode |= scan_file(filename)
|
||||
return returncode
|
||||
|
||||
|
||||
def test_regex():
|
||||
@ -149,4 +102,4 @@ if __name__ == '__main__':
|
||||
if '--test-regex' in sys.argv:
|
||||
test_regex()
|
||||
else:
|
||||
main()
|
||||
sys.exit(main())
|
||||
140
tools/pre_commit/mypy.py
Executable file
140
tools/pre_commit/mypy.py
Executable file
@ -0,0 +1,140 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Run mypy on changed files.
|
||||
|
||||
This script is designed to be used as a pre-commit hook. It runs mypy
|
||||
on files that have been changed. It groups files into different mypy calls
|
||||
based on their directory to avoid import following issues.
|
||||
|
||||
Usage:
|
||||
python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>
|
||||
|
||||
Args:
|
||||
ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to
|
||||
"silent" for the main group of files.
|
||||
python_version: Python version to use (e.g., "3.10") or "local" to use
|
||||
the local Python version.
|
||||
changed_files: List of changed files to check.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import regex as re
|
||||
|
||||
FILES = [
|
||||
"vllm/*.py",
|
||||
"vllm/assets",
|
||||
"vllm/entrypoints",
|
||||
"vllm/inputs",
|
||||
"vllm/logging_utils",
|
||||
"vllm/multimodal",
|
||||
"vllm/platforms",
|
||||
"vllm/transformers_utils",
|
||||
"vllm/triton_utils",
|
||||
"vllm/usage",
|
||||
]
|
||||
|
||||
# After fixing errors resulting from changing follow_imports
|
||||
# from "skip" to "silent", move the following directories to FILES
|
||||
SEPARATE_GROUPS = [
|
||||
"tests",
|
||||
"vllm/attention",
|
||||
"vllm/compilation",
|
||||
"vllm/distributed",
|
||||
"vllm/engine",
|
||||
"vllm/executor",
|
||||
"vllm/inputs",
|
||||
"vllm/lora",
|
||||
"vllm/model_executor",
|
||||
"vllm/plugins",
|
||||
"vllm/worker",
|
||||
"vllm/v1",
|
||||
]
|
||||
|
||||
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
||||
EXCLUDE = [
|
||||
"vllm/model_executor/parallel_utils",
|
||||
"vllm/model_executor/models",
|
||||
"vllm/model_executor/layers/fla/ops",
|
||||
# Ignore triton kernels in ops.
|
||||
"vllm/attention/ops",
|
||||
]
|
||||
|
||||
|
||||
def group_files(changed_files: list[str]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Group changed files into different mypy calls.
|
||||
|
||||
Args:
|
||||
changed_files: List of changed files.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping file group names to lists of changed files.
|
||||
"""
|
||||
exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*")
|
||||
files_pattern = re.compile(f"^({'|'.join(FILES)}).*")
|
||||
file_groups = {"": []}
|
||||
file_groups.update({k: [] for k in SEPARATE_GROUPS})
|
||||
for changed_file in changed_files:
|
||||
# Skip files which should be ignored completely
|
||||
if exclude_pattern.match(changed_file):
|
||||
continue
|
||||
# Group files by mypy call
|
||||
if files_pattern.match(changed_file):
|
||||
file_groups[""].append(changed_file)
|
||||
continue
|
||||
else:
|
||||
for directory in SEPARATE_GROUPS:
|
||||
if re.match(f"^{directory}.*", changed_file):
|
||||
file_groups[directory].append(changed_file)
|
||||
break
|
||||
return file_groups
|
||||
|
||||
|
||||
def mypy(targets: list[str], python_version: Optional[str],
|
||||
follow_imports: Optional[str], file_group: str) -> int:
|
||||
"""
|
||||
Run mypy on the given targets.
|
||||
|
||||
Args:
|
||||
targets: List of files or directories to check.
|
||||
python_version: Python version to use (e.g., "3.10") or None to use
|
||||
the default mypy version.
|
||||
follow_imports: Value for the --follow-imports option or None to use
|
||||
the default mypy behavior.
|
||||
file_group: The file group name for logging purposes.
|
||||
|
||||
Returns:
|
||||
The return code from mypy.
|
||||
"""
|
||||
args = ["mypy"]
|
||||
if python_version is not None:
|
||||
args += ["--python-version", python_version]
|
||||
if follow_imports is not None:
|
||||
args += ["--follow-imports", follow_imports]
|
||||
print(f"$ {' '.join(args)} {file_group}")
|
||||
return subprocess.run(args + targets, check=False).returncode
|
||||
|
||||
|
||||
def main():
|
||||
ci = sys.argv[1] == "1"
|
||||
python_version = sys.argv[2]
|
||||
file_groups = group_files(sys.argv[3:])
|
||||
|
||||
if python_version == "local":
|
||||
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
|
||||
|
||||
returncode = 0
|
||||
for file_group, changed_files in file_groups.items():
|
||||
follow_imports = None if ci and file_group == "" else "skip"
|
||||
if changed_files:
|
||||
returncode |= mypy(changed_files, python_version, follow_imports,
|
||||
file_group)
|
||||
return returncode
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@ -10,7 +10,6 @@ from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
|
||||
class AttentionType:
|
||||
@ -116,15 +115,6 @@ class AttentionMetadata:
|
||||
# in block 0, and 1st slot in block 1, respectively.
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# The index maps that relate multi-modal embeddings to the corresponding
|
||||
# placeholders.
|
||||
#
|
||||
# N.B. These aren't really related to attention and don't belong on this
|
||||
# type -- this is just a temporary solution to make them available to
|
||||
# `model_executable`.
|
||||
multi_modal_placeholder_index_maps: Optional[Dict[
|
||||
str, MultiModalPlaceholderMap.IndexMap]]
|
||||
|
||||
# Enable/disable KV scales calculation. This is so that we can disable the
|
||||
# calculation until after prefill and cuda graph capture.
|
||||
enable_kv_scales_calculation: bool
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
@ -12,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import async_tensor_h2d
|
||||
|
||||
# Placeholder attention backend for models like Mamba and pooling models that
|
||||
@ -141,8 +139,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
@ -178,7 +174,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
@ -210,9 +205,6 @@ class PlaceholderAttentionMetadataBuilder(
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
self.multimodal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
@ -232,12 +224,6 @@ class PlaceholderAttentionMetadataBuilder(
|
||||
self.context_lens.append(context_len)
|
||||
|
||||
if is_prompt:
|
||||
mm_maps = inter_data.multi_modal_placeholder_maps
|
||||
if mm_maps:
|
||||
for modality, placeholders in mm_maps.items():
|
||||
self.multimodal_placeholder_maps[modality].extend(
|
||||
placeholders)
|
||||
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
@ -295,12 +281,6 @@ class PlaceholderAttentionMetadataBuilder(
|
||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
# Placeholders
|
||||
slot_mapping_tensor = torch.empty(0)
|
||||
block_tables = torch.empty(0)
|
||||
@ -308,7 +288,6 @@ class PlaceholderAttentionMetadataBuilder(
|
||||
return PlaceholderAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention backend utils"""
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
@ -15,16 +14,10 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Error string(s) for encoder/decoder
|
||||
# unsupported attention scenarios
|
||||
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
|
||||
"with encoder/decoder models.")
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
# Switch to numpy implementation of compute_slot_mapping
|
||||
@ -135,9 +128,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
|
||||
self.context_lens: List[int] = []
|
||||
self.block_tables: List[List[int]] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
self.multimodal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
@ -154,12 +144,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
|
||||
inter_data.curr_sliding_window_blocks):
|
||||
self.context_lens.append(context_len)
|
||||
if is_prompt:
|
||||
mm_maps = inter_data.multi_modal_placeholder_maps
|
||||
if mm_maps:
|
||||
for modality, placeholders in mm_maps.items():
|
||||
self.multimodal_placeholder_maps[modality].extend(
|
||||
placeholders)
|
||||
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
@ -254,16 +238,10 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
|
||||
self.runner.pin_memory)
|
||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
return self._metadata_cls( # type: ignore
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
@ -320,7 +298,6 @@ class CommonAttentionState(AttentionState):
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size,
|
||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
||||
|
||||
@ -134,6 +134,5 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
|
||||
cp_attn_lse = cp_attn_lse.contiguous()
|
||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
||||
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
||||
assert out.is_contiguous()
|
||||
out = cp_group.reduce_scatter(out, dim=1)
|
||||
return out
|
||||
|
||||
@ -531,18 +531,22 @@ async def benchmark(
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
test_output = await wait_for_endpoint(
|
||||
request_func,
|
||||
test_input,
|
||||
session,
|
||||
timeout_seconds=ready_check_timeout_sec,
|
||||
)
|
||||
if not test_output.success:
|
||||
raise ValueError(
|
||||
"Initial test run failed - Please make sure benchmark arguments "
|
||||
f"are correctly specified. Error: {test_output.error}")
|
||||
if ready_check_timeout_sec > 0:
|
||||
test_output = await wait_for_endpoint(
|
||||
request_func,
|
||||
test_input,
|
||||
session,
|
||||
timeout_seconds=ready_check_timeout_sec,
|
||||
)
|
||||
if not test_output.success:
|
||||
raise ValueError(
|
||||
"Initial test run failed - Please make sure benchmark "
|
||||
"arguments are correctly specified. "
|
||||
f"Error: {test_output.error}")
|
||||
else:
|
||||
print("Initial test run completed. Starting main benchmark run...")
|
||||
else:
|
||||
print("Initial test run completed. Starting main benchmark run...")
|
||||
print("Skipping endpoint ready check.")
|
||||
|
||||
if lora_modules:
|
||||
# For each input request, choose a LoRA module at random.
|
||||
@ -1151,7 +1155,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
type=int,
|
||||
default=600,
|
||||
help="Maximum time to wait for the endpoint to become ready "
|
||||
"in seconds (default: 600 seconds / 10 minutes).",
|
||||
"in seconds (default: 600 seconds / 10 minutes). If set to 0, "
|
||||
"the ready check will be skipped."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -152,7 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
|
||||
|
||||
|
||||
class ActivationQuantFusionPass(VllmInductorPass):
|
||||
class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||
It uses the torch pattern matcher to find the patterns and replace them.
|
||||
@ -176,16 +176,12 @@ class ActivationQuantFusionPass(VllmInductorPass):
|
||||
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
|
||||
pattern_silu_mul_nvfp4.register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_act_quant_fusion")
|
||||
|
||||
count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
|
||||
count)
|
||||
|
||||
self.dump_graph(graph, "after_act_quant_fusion")
|
||||
self.end_and_log()
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
return VllmInductorPass.hash_source(self, ActivationQuantPattern,
|
||||
|
||||
@ -20,7 +20,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
@ -348,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AsyncTPPass(VllmInductorPass):
|
||||
class AsyncTPPass(VllmPatternMatcherPass):
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
@ -378,18 +378,17 @@ class AsyncTPPass(VllmInductorPass):
|
||||
AllGatherCutlassScaledMMPattern(
|
||||
self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
|
||||
# only do replace for specific shapes
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return shape is not None and shape % tp_size == 0
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_async_tp_pass")
|
||||
count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns with async TP pass.", count)
|
||||
self.dump_graph(graph, "after_async_tp_pass")
|
||||
self.end_and_log()
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
|
||||
if flashinfer_comm is not None:
|
||||
@ -1068,7 +1067,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AllReduceFusionPass(VllmInductorPass):
|
||||
class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
@ -1124,6 +1123,7 @@ class AllReduceFusionPass(VllmInductorPass):
|
||||
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
|
||||
|
||||
self.register_patterns()
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@enable_fake_mode
|
||||
def register_patterns(self):
|
||||
@ -1172,15 +1172,14 @@ class AllReduceFusionPass(VllmInductorPass):
|
||||
|
||||
self.disabled = False
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
if self.disabled:
|
||||
logger.debug("AllReduceFusionPass disabled")
|
||||
return
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_all_reduce_fusion_pass")
|
||||
count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", count)
|
||||
self.dump_graph(graph, "after_all_reduce_fusion_pass")
|
||||
self.end_and_log()
|
||||
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def __del__(self):
|
||||
if getattr(self, "disabled", True):
|
||||
|
||||
@ -26,6 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
# XPU does not support auto-functionalization yet.
|
||||
# Will enable this when switch to vllm-xpu-kernels.
|
||||
@ -34,9 +35,6 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
"pass currently.")
|
||||
return
|
||||
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_fix_functionalization")
|
||||
|
||||
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||
count = 0
|
||||
for node in graph.nodes:
|
||||
@ -111,7 +109,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
|
||||
count += 1
|
||||
|
||||
self.dump_graph(graph, "before_fix_functionalization_cleanup")
|
||||
self.dump_graph(graph, "before_cleanup")
|
||||
|
||||
# Remove the nodes all at once
|
||||
count_removed = len(self.nodes_to_remove)
|
||||
@ -120,8 +118,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
|
||||
logger.debug("De-functionalized %s nodes, removed %s nodes", count,
|
||||
count_removed)
|
||||
self.dump_graph(graph, "after_fix_functionalization")
|
||||
self.end_and_log()
|
||||
self.nodes_to_remove.clear()
|
||||
|
||||
def _remove(self, node_or_nodes: Union[torch.fx.Node,
|
||||
Iterable[torch.fx.Node]]):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@ -16,10 +16,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fx_utils import find_getitem_maybe
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .multi_output_match import MultiOutputMatch
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -50,8 +48,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[
|
||||
kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
|
||||
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
|
||||
class FusedRMSQuantKey(NamedTuple):
|
||||
@ -80,68 +77,6 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
|
||||
}
|
||||
|
||||
|
||||
class QuantMultiOutputMatch(MultiOutputMatch):
|
||||
|
||||
def __init__(self, match: pm.Match, quant_op, fused_op):
|
||||
super().__init__(match)
|
||||
assert isinstance(quant_op, OpOverload)
|
||||
assert isinstance(fused_op, OpOverload)
|
||||
self.QUANT_OP = quant_op # in-place quant op
|
||||
self.FUSED_OP = fused_op # in-place fused quant op
|
||||
|
||||
def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node,
|
||||
int]],
|
||||
**kwargs):
|
||||
"""
|
||||
This utility function inserts an auto-functionalized node for FUSED_OP.
|
||||
It also correctly sets its meta value and rebinds the users of the
|
||||
unfused nodes to use the fused node instead.
|
||||
|
||||
:param fused_return_mapping: A dictionary, mapping from getitem indices
|
||||
of the fused node result to a tuple of the old node and a getitem index.
|
||||
:param kwargs: kwargs that get directly forwarded to the auto_fn node
|
||||
|
||||
Example:
|
||||
If we want to replace this graph:
|
||||
_, x1, x2 = auto_fn(op1)
|
||||
_, y1, y2 = auto_fn(op2)
|
||||
|
||||
with
|
||||
_, x1, y2, x2 = auto_fn(FUSED_OP)
|
||||
|
||||
we would call:
|
||||
insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)}
|
||||
|
||||
Note that the 0th element is None for auto-functionalized in-place ops.
|
||||
Hence, others appear 1-indexed.
|
||||
"""
|
||||
fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
|
||||
indices = fused_return_mapping.keys()
|
||||
getitem_nodes = self.insert_getitems(fused_node, indices)
|
||||
|
||||
# Prepare the meta value, use a list so it's mutable
|
||||
meta_val = [None] * (max(indices) + 1)
|
||||
|
||||
# Iterate through elements of the tuple produced by fused_node
|
||||
for idx, getitem_node in zip(indices, getitem_nodes):
|
||||
old_node, old_idx = fused_return_mapping[idx]
|
||||
|
||||
# If the old value was never used, the old_getitem might not exist
|
||||
old_getitem = find_getitem_maybe(old_node, old_idx)
|
||||
if old_getitem is not None:
|
||||
# Rebind the users of match getitem nodes to use the new nodes.
|
||||
# The old nodes will be removed by DCE at the end of the pass.
|
||||
old_getitem.replace_all_uses_with(getitem_node)
|
||||
getitem_node.meta["val"] = old_getitem.meta["val"]
|
||||
|
||||
# Extract the appropriate meta value
|
||||
# It is present even if the getitem node does not exist
|
||||
meta_val[idx] = old_node.meta["val"][old_idx]
|
||||
|
||||
# Fix the meta value on the new fused node
|
||||
fused_node.meta["val"] = tuple(meta_val)
|
||||
|
||||
|
||||
class RMSNormQuantPattern:
|
||||
|
||||
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
|
||||
@ -224,8 +159,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass,
|
||||
record_match: Callable[[MultiOutputMatch], bool]):
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
@ -271,36 +205,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
inputs,
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
extra_check=lambda m: record_match(
|
||||
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
|
||||
|
||||
class Match(QuantMultiOutputMatch):
|
||||
|
||||
def process(self):
|
||||
# Find the nodes in the match that we need to rebind
|
||||
rms_node = self.find_auto_fn(RMS_ADD_OP)
|
||||
quant_node = self.find_auto_fn(self.QUANT_OP)
|
||||
|
||||
assert len(rms_node.users) == 2
|
||||
assert len(quant_node.users) == 1
|
||||
|
||||
# First, insert a new auto_functionalized node for the fused op,
|
||||
# as well as getitem nodes to extract the result and residual.
|
||||
# The auto_fn node returns a tuple of (None, result, residual).
|
||||
#
|
||||
# The resulting graph looks like this:
|
||||
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
|
||||
# result_node_new = at[1]
|
||||
# residual_node_new = at[2]
|
||||
with self.inserting_after_match():
|
||||
# Missing epsilon, scalars cannot be inputs to the pattern
|
||||
kwargs = self.match.kwargs.copy()
|
||||
|
||||
# 0 is always None
|
||||
fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
|
||||
self.insert_fused_node(fused_return_mapping,
|
||||
**kwargs,
|
||||
epsilon=rms_node.kwargs["epsilon"])
|
||||
)
|
||||
|
||||
|
||||
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
@ -317,8 +222,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass,
|
||||
record_match: Callable[[MultiOutputMatch], bool]):
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor,
|
||||
@ -366,39 +270,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
inputs,
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
extra_check=lambda m: record_match(
|
||||
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
|
||||
|
||||
class Match(QuantMultiOutputMatch):
|
||||
|
||||
def process(self):
|
||||
# Find the nodes in the match that we need to rebind
|
||||
rms_node = self.find_auto_fn(RMS_OP)
|
||||
quant_node = self.find_auto_fn(self.QUANT_OP)
|
||||
|
||||
assert len(rms_node.users) == 1
|
||||
assert len(quant_node.users) == 2
|
||||
|
||||
# First, insert a new auto_functionalized node for the fused op,
|
||||
# as well as getitem nodes to extract the result and scale.
|
||||
# The auto_fn node returns a tuple of (None, result, scale).
|
||||
#
|
||||
# The resulting graph looks like this:
|
||||
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
|
||||
# result_node_new = at[1]
|
||||
# scale_node_new = at[2]
|
||||
with self.inserting_after_match():
|
||||
# Missing epsilon, scalars cannot be inputs to the pattern
|
||||
kwargs = self.match.kwargs.copy()
|
||||
del kwargs["result_rms"] # not used in the fused op
|
||||
|
||||
fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)}
|
||||
self.insert_fused_node(
|
||||
fused_return_mapping,
|
||||
epsilon=rms_node.kwargs["epsilon"],
|
||||
scale_ub=None, # not used but required
|
||||
residual=None, # not used but required
|
||||
**kwargs)
|
||||
)
|
||||
|
||||
|
||||
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
@ -415,8 +287,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass,
|
||||
record_match: Callable[[MultiOutputMatch], bool]):
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
@ -464,137 +335,49 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
inputs,
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
extra_check=lambda m: record_match(
|
||||
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
|
||||
|
||||
class Match(QuantMultiOutputMatch):
|
||||
|
||||
def process(self):
|
||||
# Find the nodes in the match that we need to rebind
|
||||
rms_node = self.find_auto_fn(RMS_ADD_OP)
|
||||
quant_node = self.find_auto_fn(self.QUANT_OP)
|
||||
|
||||
assert len(rms_node.users) == 2
|
||||
assert len(quant_node.users) == 2
|
||||
|
||||
# First, insert a new auto_functionalized node for the fused op,
|
||||
# as well as getitem nodes to extract result, scale, and residual.
|
||||
# The auto_fn node returns a tuple (None, result, scale, residual).
|
||||
#
|
||||
# The resulting graph looks like this:
|
||||
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
|
||||
# result_node_new = at[1]
|
||||
# scale_node_new = at[2]
|
||||
# residual_node_new = at[3]
|
||||
with self.inserting_after_match():
|
||||
# Missing epsilon, scalars cannot be inputs to the pattern
|
||||
kwargs = self.match.kwargs.copy()
|
||||
|
||||
fused_return_mapping = {
|
||||
1: (quant_node, 1), # result
|
||||
2: (quant_node, 2), # scale
|
||||
3: (rms_node, 2), # residual
|
||||
}
|
||||
self.insert_fused_node(
|
||||
fused_return_mapping,
|
||||
epsilon=rms_node.kwargs["epsilon"],
|
||||
scale_ub=None, # not used but required
|
||||
**kwargs)
|
||||
)
|
||||
|
||||
|
||||
class FusionPass(VllmInductorPass):
|
||||
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||
It uses the torch pattern matcher to find the patterns and replace them.
|
||||
It also manually processes multi-output matches, as those are broken in
|
||||
the torch pattern matcher.
|
||||
|
||||
Because patterns can only be registered once, the pass is a singleton.
|
||||
This will be addressed in a future version of PyTorch:
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
|
||||
It also supports fused_add_rms_norm.
|
||||
"""
|
||||
|
||||
_instance: 'Optional[FusionPass]' = None
|
||||
|
||||
@classmethod
|
||||
def instance(cls, config: VllmConfig):
|
||||
"""
|
||||
Get the singleton instance of the FusionPass.
|
||||
If the instance exists, the config is updated but
|
||||
initialization is not repeated.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = FusionPass(config)
|
||||
else:
|
||||
cls._instance.pass_config = config.compilation_config.pass_config
|
||||
return cls._instance
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
assert self.__class__._instance is None, \
|
||||
"FusionPass singleton instance already exists"
|
||||
super().__init__(config)
|
||||
|
||||
self.matches: list[MultiOutputMatch] = []
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="fusion_pass")
|
||||
pass_name="rmsnorm_quant_fusion_pass")
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon,
|
||||
FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Matches for patterns below have 2 or more outputs,
|
||||
# so we need to process them manually (see process_matches)
|
||||
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
# Fuse fused_add_rms_norm + static fp8 quant
|
||||
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns, self.record_match)
|
||||
self.patterns)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns, self.record_match)
|
||||
RMSNormDynamicQuantPattern(epsilon,
|
||||
FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
|
||||
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns, self.record_match)
|
||||
self.patterns)
|
||||
|
||||
# WARNING: This is a hack to clear the pattern matcher cache
|
||||
# and allow multiple values of epsilon.
|
||||
torch._inductor.pattern_matcher._seen_patterns.clear()
|
||||
|
||||
def record_match(self, match: MultiOutputMatch) -> bool:
|
||||
# Hijack the extra_check to record the match and
|
||||
# save it for post-processing.
|
||||
self.matches.append(match)
|
||||
|
||||
# Return False to prevent automatic replacement.
|
||||
return False
|
||||
|
||||
def process_matches(self, graph: fx.Graph):
|
||||
"""
|
||||
Manually process multi-output matches and replace them with fused nodes.
|
||||
See MultiOutputMatch for more details.
|
||||
"""
|
||||
for match in self.matches:
|
||||
match.process()
|
||||
|
||||
# Finally, remove matched nodes
|
||||
graph.eliminate_dead_code()
|
||||
assert all(node not in graph.nodes for match in self.matches
|
||||
for node in match.match.nodes)
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_fusion")
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", count)
|
||||
self.dump_graph(graph, "after_pattern_match")
|
||||
|
||||
# Manually process multi-output matches (and run DCE)
|
||||
self.process_matches(graph)
|
||||
logger.debug("Post-processed %s matches", len(self.matches))
|
||||
self.dump_graph(graph, "after_fusion")
|
||||
self.matches.clear()
|
||||
self.end_and_log()
|
||||
def uuid(self) -> Any:
|
||||
return self.hash_source(self, RMSNormQuantPattern,
|
||||
RMSNormStaticQuantPattern,
|
||||
RMSNormDynamicQuantPattern,
|
||||
FusedAddRMSNormStaticQuantPattern,
|
||||
FusedAddRMSNormDynamicQuantPattern)
|
||||
|
||||
@ -18,7 +18,7 @@ from vllm.utils import round_up
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -245,7 +245,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
pm_pass)
|
||||
|
||||
|
||||
class AttnFusionPass(VllmInductorPass):
|
||||
class AttnFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses post-attention quantization onto attention if supported.
|
||||
|
||||
@ -282,20 +282,12 @@ class AttnFusionPass(VllmInductorPass):
|
||||
"were found in CompilationConfig.static_forward_context "
|
||||
"so no fusion patterns were registered.")
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_attn_fusion")
|
||||
|
||||
count = self.patterns.apply(graph)
|
||||
|
||||
# TODO: Move this to pass_manager.py after the fx graph broken issue
|
||||
# has been resolved.
|
||||
# see https://github.com/vllm-project/vllm/issues/23091
|
||||
graph.eliminate_dead_code()
|
||||
|
||||
logger.debug("Fused quantization onto %s attention nodes", count)
|
||||
self.dump_graph(graph, "after_attn_fusion")
|
||||
self.end_and_log()
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
return VllmInductorPass.hash_source(self, AttentionQuantPattern,
|
||||
|
||||
@ -1,109 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import abc
|
||||
import operator
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor import pattern_matcher as pm
|
||||
from torch._ops import OpOverload
|
||||
from torch.fx import Node
|
||||
|
||||
from vllm.compilation.fx_utils import find_auto_fn
|
||||
|
||||
|
||||
class MultiOutputMatch(abc.ABC):
|
||||
"""
|
||||
This class provides utilities to process multi-output matches and
|
||||
manually insert replacements.
|
||||
|
||||
This is necessary because the automatic replacement for multi-output
|
||||
matches is broken: https://github.com/pytorch/pytorch/issues/137280
|
||||
"""
|
||||
|
||||
def __init__(self, match: pm.Match):
|
||||
self.match = match
|
||||
|
||||
@abstractmethod
|
||||
def process(self):
|
||||
"""
|
||||
Process a multi-output match and manually insert the replacement.
|
||||
|
||||
This method should:
|
||||
1. Insert the replacement nodes after the last node in the match.
|
||||
2. Rebind the users of nodes in the match to use the new nodes.
|
||||
3. Set meta["val"] for de-functionalization.
|
||||
|
||||
The result of an auto-functionalized node is a tuple of tensors.
|
||||
The first element is the return value of the function, usually None.
|
||||
The remaining elements are the mutated args of the function.
|
||||
|
||||
All auto-functionalized nodes must contain a proper meta["val"],
|
||||
as it is used by de-functionalization. meta["val"] has to contain the
|
||||
value of the node (tuple of tensors) that would be returned by the
|
||||
functionalized node during tracing.
|
||||
|
||||
Existing nodes in the graph all have this property set, but we have
|
||||
to set it manually for new nodes we insert.
|
||||
|
||||
Example:
|
||||
# op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None
|
||||
at = auto_functionalized(torch.ops._C.foo.default, a, b, c)
|
||||
# at.meta["val"] = (None, a, c)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def nodes(self) -> list[fx.Node]:
|
||||
return self.match.nodes
|
||||
|
||||
@property
|
||||
def graph(self) -> fx.Graph:
|
||||
return self.match.graph
|
||||
|
||||
def find_auto_fn(self, op) -> fx.Node:
|
||||
"""
|
||||
Find the first auto_functionalized node with the given op in the match.
|
||||
"""
|
||||
return find_auto_fn(self.nodes, op)
|
||||
|
||||
def inserting_after_match(self):
|
||||
"""
|
||||
Insert nodes after the last node in the match.
|
||||
This is done to avoid use-before-definition errors after inserting
|
||||
replacement nodes.
|
||||
"""
|
||||
|
||||
# match.nodes is not guaranteed to be sorted.
|
||||
# Find the last node in the match.
|
||||
for last_node_in_match in reversed(self.graph.nodes):
|
||||
if last_node_in_match in self.match.nodes:
|
||||
break
|
||||
else:
|
||||
raise ValueError("No nodes in graph")
|
||||
|
||||
return self.graph.inserting_after(last_node_in_match)
|
||||
|
||||
def insert_getitems(self, tuple_node: fx.Node,
|
||||
indices: Iterable[int]) -> tuple[fx.Node, ...]:
|
||||
"""
|
||||
Insert operator.getitem nodes to extract elements from a tuple node.
|
||||
|
||||
:param tuple_node: The tuple node to extract elements from.
|
||||
:param indices: The indices of the elements to extract.
|
||||
:return: Tuple of the new getitem nodes, corresponding to the indices.
|
||||
"""
|
||||
with self.graph.inserting_after(tuple_node):
|
||||
return tuple(
|
||||
self.graph.call_function(operator.getitem, (tuple_node, idx))
|
||||
for idx in indices)
|
||||
|
||||
def insert_auto_fn(self, op: OpOverload, kwargs) -> Node:
|
||||
"""
|
||||
Insert an auto_functionalized node with the given op and kwargs.
|
||||
"""
|
||||
return self.graph.call_function(auto_functionalized, (op, ),
|
||||
kwargs=kwargs)
|
||||
@ -64,9 +64,8 @@ class NoOpEliminationPass(VllmInductorPass):
|
||||
out: "f16[s0, 4096]" = at[1]
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_noop_elimination")
|
||||
count = 0
|
||||
# Remove no-op reshapes/views:
|
||||
for node in graph.nodes:
|
||||
@ -121,8 +120,6 @@ class NoOpEliminationPass(VllmInductorPass):
|
||||
count += 1
|
||||
|
||||
logger.debug("Removed %s no-op reshapes and slices", count)
|
||||
self.dump_graph(graph, "after_noop_elimination")
|
||||
self.end_and_log()
|
||||
|
||||
# ---------------------- Reshape helpers ----------------------
|
||||
def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],
|
||||
|
||||
@ -1,15 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import set_env_var
|
||||
|
||||
from .post_cleanup import PostCleanupPass
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion import FusionPass
|
||||
from .fusion import RMSNormQuantFusionPass
|
||||
from .fusion_attn import AttnFusionPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
@ -19,11 +25,28 @@ from .fix_functionalization import FixFunctionalizationPass
|
||||
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .sequence_parallelism import SequenceParallelismPass
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def with_pattern_match_debug(fn):
|
||||
"""
|
||||
Function decorator that turns on inductor pattern match debug
|
||||
for the duration of the call.
|
||||
Used to avoid logging builtin Inductor pattern matching.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
|
||||
# optionally check rank here
|
||||
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
|
||||
return fn(*args, **kwargs)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class PostGradPassManager(CustomGraphPass):
|
||||
"""
|
||||
The pass manager for post-grad passes.
|
||||
@ -40,16 +63,26 @@ class PostGradPassManager(CustomGraphPass):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.passes: list[VllmInductorPass] = []
|
||||
self.passes: list[InductorPass] = []
|
||||
|
||||
@with_pattern_match_debug
|
||||
def __call__(self, graph: fx.Graph):
|
||||
VllmInductorPass.dump_prefix = 0 # reset dump index
|
||||
|
||||
shape = get_pass_context().runtime_shape
|
||||
for pass_ in self.passes:
|
||||
if pass_.is_applicable_for_shape(shape):
|
||||
pass_(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
|
||||
# post-cleanup goes before fix_functionalization
|
||||
# because it requires a functional graph
|
||||
self.post_cleanup(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
|
||||
# always run fix_functionalization last
|
||||
self.fix_functionalization(graph)
|
||||
VllmInductorPass.dump_prefix = None # Cleanup index
|
||||
|
||||
def configure(self, config: VllmConfig):
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
@ -61,14 +94,18 @@ class PostGradPassManager(CustomGraphPass):
|
||||
if self.pass_config.enable_async_tp:
|
||||
self.passes += [AsyncTPPass(config)]
|
||||
|
||||
if self.pass_config.enable_fi_allreduce_fusion:
|
||||
self.passes += [AllReduceFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_fusion:
|
||||
self.passes += [FusionPass.instance(config)]
|
||||
self.passes += [RMSNormQuantFusionPass(config)]
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_attn_fusion:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
if self.pass_config.enable_fi_allreduce_fusion:
|
||||
self.passes += [AllReduceFusionPass(config)]
|
||||
|
||||
# needs a functional graph
|
||||
self.post_cleanup = PostCleanupPass(config)
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
def add(self, pass_: InductorPass):
|
||||
|
||||
20
vllm/compilation/post_cleanup.py
Normal file
20
vllm/compilation/post_cleanup.py
Normal file
@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from torch import fx
|
||||
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
|
||||
class PostCleanupPass(VllmInductorPass):
|
||||
"""
|
||||
This pass performs cleanup after custom passes.
|
||||
It topologically sorts the graph and removes unused nodes.
|
||||
This is needed because the pattern matcher does not guarantee producing
|
||||
a topologically sorted graph, and there may be unused nodes left around.
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
from torch._inductor.pattern_matcher import stable_topological_sort
|
||||
stable_topological_sort(graph)
|
||||
graph.eliminate_dead_code()
|
||||
@ -15,7 +15,7 @@ from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -417,7 +417,7 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class SequenceParallelismPass(VllmInductorPass):
|
||||
class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass enables sequence parallelism for models.
|
||||
It identifies patterns where an AllReduce operation is followed by
|
||||
@ -466,19 +466,13 @@ class SequenceParallelismPass(VllmInductorPass):
|
||||
|
||||
LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
|
||||
# WARNING: This is a hack to clear the pattern matcher cache
|
||||
# and allow multiple values of epsilon.
|
||||
torch._inductor.pattern_matcher._seen_patterns.clear()
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return shape is not None and shape % tp_size == 0
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_sequence_parallelism_pass")
|
||||
count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns with sequence parallelism", count)
|
||||
self.dump_graph(graph, "after_sequence_parallelism_pass")
|
||||
self.end_and_log()
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
@ -1,10 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import operator
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
from torch._inductor.pattern_matcher import (PatternMatcherPass,
|
||||
PatternPrettyPrinter)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
@ -19,6 +25,8 @@ class VllmInductorPass(InductorPass):
|
||||
An inductor pass with access to vLLM PassConfig.
|
||||
It provides timing, logging, and dumping utilities.
|
||||
"""
|
||||
dump_prefix: ClassVar[Optional[int]] = None
|
||||
"""Keep track of pass index for debug dump ordering."""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
@ -28,8 +36,24 @@ class VllmInductorPass(InductorPass):
|
||||
else None
|
||||
self.pass_name = self.__class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def time_and_log(call_fn):
|
||||
|
||||
@functools.wraps(call_fn)
|
||||
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before")
|
||||
call_fn(self, graph)
|
||||
self.dump_graph(graph, "after")
|
||||
self.end_and_log()
|
||||
|
||||
return wrapped
|
||||
|
||||
def dump_graph(self, graph: torch.fx.Graph, stage: str):
|
||||
lazy_format_graph_code(stage, graph.owning_module)
|
||||
i = VllmInductorPass.dump_prefix
|
||||
i_str = "" if i is None else f".{i}"
|
||||
lazy_format_graph_code(f"post_grad{i_str}.{self.pass_name}.{stage}",
|
||||
graph.owning_module)
|
||||
|
||||
def begin(self):
|
||||
self._start_time = time.perf_counter_ns()
|
||||
@ -40,6 +64,88 @@ class VllmInductorPass(InductorPass):
|
||||
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
|
||||
|
||||
|
||||
class VllmPatternMatcherPass(VllmInductorPass):
|
||||
"""
|
||||
A VllmInductorPass that uses the Inductor pattern matcher.
|
||||
Its main use is providing the dump_patterns utility that dumps the
|
||||
Inductor pattern matcher patterns into a file, which greatly aids debugging.
|
||||
|
||||
TODO(luka) move more utilities to this pass.
|
||||
"""
|
||||
matched_count: int = 0
|
||||
"""The number of matched patterns in the pass."""
|
||||
|
||||
_OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
|
||||
r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")
|
||||
|
||||
def _replace_op_overloads(self, string: str) -> str:
|
||||
"""Replace <OpOverload(..., ...)> with nicer formulations"""
|
||||
return self._OP_OVERLOAD_PATTERN.sub(
|
||||
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
|
||||
string,
|
||||
)
|
||||
|
||||
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
|
||||
"""
|
||||
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
|
||||
into the debug_dump_path folder next to the dumped fx graphs.
|
||||
|
||||
This method does its best to print something that looks like Python code
|
||||
for easier debugging and potentially navigation. If any errors appear in
|
||||
the output, please add to this method.
|
||||
|
||||
TODO(luka): use pattern object to manually produce pattern graph
|
||||
"""
|
||||
debug_dump_path = config.compilation_config.debug_dump_path
|
||||
if not debug_dump_path:
|
||||
return
|
||||
|
||||
rank = config.parallel_config.rank
|
||||
debug_dump_path = Path(debug_dump_path) / f"rank_{rank}"
|
||||
debug_dump_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from vllm.utils import unique_filepath
|
||||
file_path = unique_filepath(
|
||||
lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py")
|
||||
|
||||
with file_path.open("w") as f:
|
||||
print(
|
||||
f'# This file was produced by VllmPatternMatcherPass.'
|
||||
f'dump_patterns for {self.pass_name}.\n'
|
||||
f'# It does its best to produce valid-Python-looking code but'
|
||||
f' please add to dump_patterns if there are any errors.\n\n'
|
||||
f'from torch._higher_order_ops.auto_functionalize import '
|
||||
f'auto_functionalized as auto_functionalized\n'
|
||||
f'from torch._inductor.pattern_matcher import *',
|
||||
file=f)
|
||||
|
||||
for node, patterns in pm_pass.patterns.items():
|
||||
# fix the operator.getitem repr
|
||||
if node[1] == operator.getitem:
|
||||
node_repr = f"({repr(node[0])}, operator.getitem)"
|
||||
else:
|
||||
node_repr = repr(node)
|
||||
|
||||
node_repr = self._replace_op_overloads(node_repr)
|
||||
|
||||
print(f"\n\n# Patterns for op: {node_repr}", file=f)
|
||||
for i, pattern in enumerate(patterns):
|
||||
# reserve auto_functionalized ahead of time
|
||||
pp = PatternPrettyPrinter()
|
||||
pp.namespace.create_name("auto_functionalized", None)
|
||||
|
||||
# Assemble pattern
|
||||
out_node = pp.pretty_print(pattern.pattern)
|
||||
pattern_repr = "\n".join([f"def pattern_{i}():"] + [
|
||||
f"{pp.memoized_objs_names[key]} = "
|
||||
f"{pp.memoized_objs_pp[key]}"
|
||||
for key in pp.memoized_objs_names
|
||||
] + [f"return {out_node}"]).replace("\n", "\n ")
|
||||
|
||||
pattern_repr = self._replace_op_overloads(pattern_repr)
|
||||
print(f"{pattern_repr}\n", file=f)
|
||||
|
||||
|
||||
class PrinterInductorPass(VllmInductorPass):
|
||||
|
||||
def __init__(self, name: str, config: VllmConfig):
|
||||
|
||||
@ -503,7 +503,7 @@ class VllmConfig:
|
||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
self.compilation_config.custom_ops.append("+rms_norm")
|
||||
|
||||
if current_platform.is_cuda_alike() or current_platform.is_xpu():
|
||||
if current_platform.support_static_graph_mode():
|
||||
# if cudagraph_mode is not explicitly set by users, set default
|
||||
# value
|
||||
if self.compilation_config.cudagraph_mode is None:
|
||||
@ -905,10 +905,9 @@ def set_current_vllm_config(vllm_config: VllmConfig,
|
||||
except Exception:
|
||||
raise
|
||||
else:
|
||||
logger.debug("enabled custom ops: %s",
|
||||
vllm_config.compilation_config.enabled_custom_ops)
|
||||
logger.debug("disabled custom ops: %s",
|
||||
vllm_config.compilation_config.disabled_custom_ops)
|
||||
if check_compile:
|
||||
vllm_config.compilation_config.custom_op_log_check()
|
||||
|
||||
if check_compile and \
|
||||
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
|
||||
and compilation_counter.num_models_seen == num_models_seen:
|
||||
|
||||
@ -487,6 +487,12 @@ class CompilationConfig:
|
||||
"supported with torch>=2.9.0.dev. Set "
|
||||
"use_inductor_graph_partition=False instead.")
|
||||
|
||||
for op in self.custom_ops:
|
||||
if op[0] not in {'+', '-'} and op not in {'all', 'none'}:
|
||||
raise ValueError(f"Invalid syntax '{op}' for custom op, "
|
||||
"must be 'all', 'none', '+op' or '-op' "
|
||||
"(where 'op' is the registered op name)")
|
||||
|
||||
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
|
||||
if self.level == CompilationLevel.NO_COMPILATION:
|
||||
raise ValueError("No compilation level is set.")
|
||||
@ -532,8 +538,8 @@ class CompilationConfig:
|
||||
for x in self.compile_sizes:
|
||||
if isinstance(x, str):
|
||||
assert x == "cudagraph_capture_sizes", \
|
||||
"Unrecognized size type in compile_sizes, " \
|
||||
f"expect 'cudagraph_capture_sizes', got {x}"
|
||||
"Unrecognized size type in compile_sizes, " \
|
||||
f"expect 'cudagraph_capture_sizes', got {x}"
|
||||
computed_compile_sizes.extend(self.cudagraph_capture_sizes)
|
||||
else:
|
||||
assert isinstance(x, int)
|
||||
@ -628,3 +634,41 @@ class CompilationConfig:
|
||||
|
||||
return use_fx_graph_piecewise_compilation or \
|
||||
use_inductor_piecewise_compilation
|
||||
|
||||
def custom_op_log_check(self):
|
||||
"""
|
||||
This method logs the enabled/disabled custom ops and checks that the
|
||||
passed custom_ops field only contains relevant ops.
|
||||
It is called at the end of set_current_vllm_config,
|
||||
after the custom ops have been instantiated.
|
||||
"""
|
||||
|
||||
if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0:
|
||||
logger.debug("No custom ops found in model.")
|
||||
return
|
||||
|
||||
logger.debug("enabled custom ops: %s", self.enabled_custom_ops)
|
||||
logger.debug("disabled custom ops: %s", self.disabled_custom_ops)
|
||||
|
||||
all_ops_in_model = (self.enabled_custom_ops | self.disabled_custom_ops)
|
||||
for op in self.custom_ops:
|
||||
if op in {"all", "none"}:
|
||||
continue
|
||||
|
||||
assert op[0] in {'+', '-'}, "Invalid custom op syntax " \
|
||||
"(should be checked during init)"
|
||||
|
||||
# check if op name exists in model
|
||||
op_name = op[1:]
|
||||
if op_name not in all_ops_in_model:
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
# Does op exist at all or is it just not present in this model?
|
||||
# Note: Only imported op classes appear in the registry.
|
||||
missing_str = "doesn't exist (or wasn't imported/registered)" \
|
||||
if op_name not in CustomOp.op_registry \
|
||||
else "not present in model"
|
||||
|
||||
enable_str = "enabling" if op[0] == '+' else "disabling"
|
||||
logger.warning_once("Op '%s' %s, %s with '%s' has no effect",
|
||||
op_name, missing_str, enable_str, op)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||
|
||||
@ -351,6 +352,10 @@ class ParallelConfig:
|
||||
self.world_size = self.pipeline_parallel_size * \
|
||||
self.tensor_parallel_size
|
||||
|
||||
if self.distributed_executor_backend == "external_launcher":
|
||||
logger.info("Using external launcher for distributed inference.")
|
||||
self.world_size *= self.data_parallel_size
|
||||
|
||||
if self.data_parallel_size_local > self.data_parallel_size:
|
||||
raise ValueError(
|
||||
f"data_parallel_size_local ({self.data_parallel_size_local}) "
|
||||
@ -358,6 +363,13 @@ class ParallelConfig:
|
||||
|
||||
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
|
||||
# Data parallel was specified in the engine args.
|
||||
if self.distributed_executor_backend == "external_launcher":
|
||||
# For external launcher,
|
||||
# we need to set the data parallel rank automatically
|
||||
self.data_parallel_rank = int(os.environ["RANK"]) \
|
||||
// (self.world_size // self.data_parallel_size)
|
||||
logger.info("Set data_parallel_rank to %d automatically.",
|
||||
self.data_parallel_rank)
|
||||
if not self._data_parallel_master_port_list:
|
||||
self._data_parallel_master_port_list = get_open_ports_list(5)
|
||||
self.data_parallel_master_port = \
|
||||
@ -380,7 +392,6 @@ class ParallelConfig:
|
||||
"be set when data_parallel_size > 1")
|
||||
|
||||
if self.distributed_executor_backend == "external_launcher":
|
||||
import os
|
||||
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
|
||||
logger.info("Disabling V1 multiprocessing for external launcher.")
|
||||
|
||||
|
||||
@ -527,7 +527,7 @@ class SpeculativeConfig:
|
||||
"speculative decoding is > 1, but got "
|
||||
f"{self.disable_by_batch_size=}")
|
||||
|
||||
eagle3_target_supported = ["llama", "qwen"]
|
||||
eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
|
||||
if self.method == "eagle3" and self.target_model_config and not any(
|
||||
supported_model in
|
||||
self.target_model_config.hf_text_config.model_type
|
||||
|
||||
@ -25,6 +25,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
if self.use_all2all:
|
||||
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
if all2all_backend != "naive":
|
||||
logger.warning(
|
||||
"`%s` all2all manager is not supported on XPU."
|
||||
"Falling back to `naive` all2all manager for XPU.",
|
||||
all2all_backend)
|
||||
all2all_backend = "naive"
|
||||
if all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
@ -67,3 +73,16 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
|
||||
dist.broadcast(input_, src=src, group=self.device_group)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@ -58,6 +58,12 @@ except ImportError:
|
||||
logger.warning("NIXL is not available")
|
||||
NixlWrapper = None
|
||||
|
||||
try:
|
||||
from nixl._api import nixl_agent_config
|
||||
except ImportError:
|
||||
nixl_agent_config = None
|
||||
logger.warning("NIXL agent config is not available")
|
||||
|
||||
# Supported platforms and types of kv transfer buffer.
|
||||
# {device: tuple of supported kv buffer types}
|
||||
_NIXL_SUPPORTED_DEVICE = {
|
||||
@ -65,6 +71,8 @@ _NIXL_SUPPORTED_DEVICE = {
|
||||
"tpu": ("cpu", ),
|
||||
"xpu": ("cpu", ),
|
||||
}
|
||||
# support for oot platform by providing mapping in current_platform
|
||||
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
|
||||
|
||||
|
||||
class NixlAgentMetadata(
|
||||
@ -242,6 +250,10 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
self.connector_worker.copy_blocks:
|
||||
self.connector_worker.save_kv_to_host(self._connector_metadata)
|
||||
|
||||
def shutdown(self):
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.shutdown()
|
||||
|
||||
|
||||
class NixlConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
@ -448,8 +460,15 @@ class NixlConnectorWorker:
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
self.nixl_backends = \
|
||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"backends", ["UCX"])
|
||||
# Agent.
|
||||
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
|
||||
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
|
||||
config = nixl_agent_config(backends=self.nixl_backends) if len(
|
||||
non_ucx_backends) > 0 and nixl_agent_config is not None else None
|
||||
|
||||
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
|
||||
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
||||
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
|
||||
|
||||
@ -486,11 +505,15 @@ class NixlConnectorWorker:
|
||||
# used when device memory can not be registered under nixl
|
||||
self.host_xfer_buffers: dict[str, torch.Tensor] = {}
|
||||
self.use_host_buffer = self.kv_buffer_device == "cpu"
|
||||
if self.kv_buffer_device == "cuda":
|
||||
self.nixl_memory_type = "VRAM"
|
||||
elif self.kv_buffer_device == "cpu":
|
||||
self.nixl_memory_type = "DRAM"
|
||||
else:
|
||||
# support for oot platform which can't register nixl memory
|
||||
# type based on kv_buffer_device
|
||||
self.nixl_memory_type = current_platform.get_nixl_memory_type()
|
||||
if self.nixl_memory_type is None:
|
||||
if self.kv_buffer_device == "cuda":
|
||||
self.nixl_memory_type = "VRAM"
|
||||
elif self.kv_buffer_device == "cpu":
|
||||
self.nixl_memory_type = "DRAM"
|
||||
if self.nixl_memory_type is None:
|
||||
raise RuntimeError(
|
||||
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
|
||||
"is not supported.")
|
||||
@ -567,13 +590,6 @@ class NixlConnectorWorker:
|
||||
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
|
||||
self.xfer_stats = NixlKVConnectorStats()
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup background threads on destruction."""
|
||||
if executor := getattr(self, "_handshake_initiation_executor", None):
|
||||
executor.shutdown(wait=False)
|
||||
if listener_t := getattr(self, "_nixl_handshake_listener_t", None):
|
||||
listener_t.join(timeout=0)
|
||||
|
||||
@staticmethod
|
||||
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
||||
ready_event: threading.Event, base_port: int,
|
||||
@ -766,7 +782,7 @@ class NixlConnectorWorker:
|
||||
descs = self.nixl_wrapper.get_reg_descs(caches_data,
|
||||
self.nixl_memory_type)
|
||||
logger.debug("Registering descs: %s", caches_data)
|
||||
self.nixl_wrapper.register_memory(descs)
|
||||
self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends)
|
||||
logger.debug("Done registering descs")
|
||||
self._registered_descs.append(descs)
|
||||
|
||||
@ -1327,6 +1343,30 @@ class NixlConnectorWorker:
|
||||
return self.xfer_stats.clone_and_reset()
|
||||
return None
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown the connector worker."""
|
||||
self._handshake_initiation_executor.shutdown(wait=False)
|
||||
if self._nixl_handshake_listener_t is not None:
|
||||
self._nixl_handshake_listener_t.join(timeout=0)
|
||||
self._nixl_handshake_listener_t = None
|
||||
for handles in self._recving_transfers.values():
|
||||
for handle, _ in handles:
|
||||
self.nixl_wrapper.release_xfer_handle(handle)
|
||||
self._recving_transfers.clear()
|
||||
if self.src_xfer_side_handle:
|
||||
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
|
||||
self.src_xfer_side_handle = 0
|
||||
for dst_xfer_side_handle in self.dst_xfer_side_handles.values():
|
||||
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
|
||||
self.dst_xfer_side_handles.clear()
|
||||
for remote_agents in self._remote_agents.values():
|
||||
for agent_name in remote_agents.values():
|
||||
self.nixl_wrapper.remove_remote_agent(agent_name)
|
||||
self._remote_agents.clear()
|
||||
for desc in self._registered_descs:
|
||||
self.nixl_wrapper.deregister_memory(desc)
|
||||
self._registered_descs.clear()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
|
||||
|
||||
@ -178,6 +178,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
|
||||
|
||||
# Load the KV for each request each layer
|
||||
for request in metadata.requests:
|
||||
request_id = request.request_id
|
||||
ip, port = self.parse_request_id(request_id, False)
|
||||
remote_address = ip + ":" + str(port + self._rank)
|
||||
for layer_name in forward_context.no_compile_layers:
|
||||
layer = forward_context.no_compile_layers[layer_name]
|
||||
|
||||
@ -191,7 +194,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
|
||||
layer = kv_cache[forward_context.virtual_engine]
|
||||
|
||||
kv_cache = self.p2p_nccl_engine.recv_tensor(
|
||||
request.request_id + "#" + layer_name)
|
||||
request.request_id + "#" + layer_name, remote_address)
|
||||
|
||||
if kv_cache is None:
|
||||
logger.warning("🚧kv_cache is None, %s", request.request_id)
|
||||
|
||||
@ -134,7 +134,6 @@ class P2pNcclEngine:
|
||||
# PUT or PUT_ASYNC
|
||||
# tensor_id: torch.Tensor
|
||||
self.send_queue: deque[SendQueueItem] = deque()
|
||||
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
self._send_thread = threading.Thread(target=self.send_async,
|
||||
daemon=True)
|
||||
@ -143,6 +142,7 @@ class P2pNcclEngine:
|
||||
# tensor_id: torch.Tensor/(addr, dtype, shape)
|
||||
self.recv_store: dict[str, Any] = {}
|
||||
self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
|
||||
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
|
||||
self.socks: dict[str, Any] = {} # remote_address: client socket
|
||||
self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank)
|
||||
|
||||
@ -223,18 +223,26 @@ class P2pNcclEngine:
|
||||
# GET
|
||||
with self.send_store_cv:
|
||||
tensor_size = tensor.element_size() * tensor.numel()
|
||||
if tensor_size > self.buffer_size_threshold:
|
||||
logger.warning(
|
||||
"❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
|
||||
"buffer size threshold :%d, skip send to %s, rank:%d",
|
||||
tensor_id, tensor_size, self.buffer_size_threshold,
|
||||
remote_address, self.rank)
|
||||
return False
|
||||
while (self.buffer_size + tensor_size
|
||||
> self.buffer_size_threshold):
|
||||
oldest_tenser_id = next(iter(self.send_store))
|
||||
oldest_tenser = self.send_store.pop(oldest_tenser_id)
|
||||
oldest_tenser_size = oldest_tenser.element_size(
|
||||
) * oldest_tenser.numel()
|
||||
self.buffer_size -= oldest_tenser_size
|
||||
logger.info(
|
||||
assert len(self.send_store) > 0
|
||||
oldest_tensor_id = next(iter(self.send_store))
|
||||
oldest_tensor = self.send_store.pop(oldest_tensor_id)
|
||||
oldest_tensor_size = oldest_tensor.element_size(
|
||||
) * oldest_tensor.numel()
|
||||
self.buffer_size -= oldest_tensor_size
|
||||
logger.debug(
|
||||
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
|
||||
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
|
||||
" buffer_size:%d, oldest_tensor_size:%d, rank:%d",
|
||||
remote_address, tensor_id, tensor_size, self.buffer_size,
|
||||
oldest_tenser_size, self.rank)
|
||||
oldest_tensor_size, self.rank)
|
||||
|
||||
self.send_store[tensor_id] = tensor
|
||||
self.buffer_size += tensor_size
|
||||
|
||||
@ -1032,7 +1032,9 @@ def init_distributed_environment(world_size: int = -1,
|
||||
distributed_init_method, backend)
|
||||
from vllm.config import get_current_vllm_config
|
||||
config = get_current_vllm_config()
|
||||
if config is not None and config.parallel_config.data_parallel_size > 1:
|
||||
if config is not None and config.parallel_config.data_parallel_size > 1 \
|
||||
and config.parallel_config.distributed_executor_backend \
|
||||
!= "external_launcher":
|
||||
parallel_config = config.parallel_config
|
||||
# adjust to take into account data parallelism
|
||||
# offset the rank by the data parallel rank
|
||||
|
||||
@ -1147,20 +1147,15 @@ class EngineArgs:
|
||||
else:
|
||||
envs.set_vllm_use_v1(use_v1)
|
||||
|
||||
# Set default arguments for V0 or V1 Engine.
|
||||
if use_v1:
|
||||
self._set_default_args_v1(usage_context, model_config)
|
||||
# Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1
|
||||
if current_platform.is_cpu(
|
||||
) and current_platform.get_cpu_architecture() in (
|
||||
CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM):
|
||||
logger.info(
|
||||
"Chunked prefill is not supported for ARM and POWER "
|
||||
"and S390X CPUs; "
|
||||
"disabling it for V1 backend.")
|
||||
self.enable_chunked_prefill = False
|
||||
else:
|
||||
self._set_default_args_v0(model_config)
|
||||
# Set default arguments for V1 Engine.
|
||||
self._set_default_args(usage_context, model_config)
|
||||
# Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1
|
||||
if current_platform.is_cpu() and current_platform.get_cpu_architecture(
|
||||
) in (CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM):
|
||||
logger.info("Chunked prefill is not supported for ARM and POWER "
|
||||
"and S390X CPUs; "
|
||||
"disabling it for V1 backend.")
|
||||
self.enable_chunked_prefill = False
|
||||
assert self.enable_chunked_prefill is not None
|
||||
|
||||
sliding_window: Optional[int] = None
|
||||
@ -1494,6 +1489,7 @@ class EngineArgs:
|
||||
"FLEX_ATTENTION",
|
||||
"TREE_ATTN",
|
||||
"XFORMERS_VLLM_V1",
|
||||
"ROCM_ATTN_VLLM_V1",
|
||||
]
|
||||
if (envs.is_set("VLLM_ATTENTION_BACKEND")
|
||||
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
|
||||
@ -1501,12 +1497,6 @@ class EngineArgs:
|
||||
_raise_or_fallback(feature_name=name, recommend_to_remove=True)
|
||||
return False
|
||||
|
||||
# Platforms must decide if they can support v1 for this model
|
||||
if not current_platform.supports_v1(model_config=model_config):
|
||||
_raise_or_fallback(
|
||||
feature_name=f"device type={current_platform.device_type}",
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
#############################################################
|
||||
# Experimental Features - allow users to opt in.
|
||||
|
||||
@ -1523,12 +1513,6 @@ class EngineArgs:
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
|
||||
# The platform may be supported on V1, but off by default for now.
|
||||
if not current_platform.default_v1( # noqa: SIM103
|
||||
model_config=model_config) and _warn_or_fallback(
|
||||
current_platform.device_name):
|
||||
return False
|
||||
|
||||
if (current_platform.is_cpu()
|
||||
and model_config.get_sliding_window() is not None):
|
||||
_raise_or_fallback(feature_name="sliding window (CPU backend)",
|
||||
@ -1539,64 +1523,8 @@ class EngineArgs:
|
||||
|
||||
return True
|
||||
|
||||
def _set_default_args_v0(self, model_config: ModelConfig) -> None:
|
||||
"""Set Default Arguments for V0 Engine."""
|
||||
|
||||
max_model_len = model_config.max_model_len
|
||||
use_long_context = max_model_len > 32768
|
||||
if self.enable_chunked_prefill is None:
|
||||
# Chunked prefill not supported for Multimodal or MLA in V0.
|
||||
if model_config.is_multimodal_model or model_config.use_mla:
|
||||
self.enable_chunked_prefill = False
|
||||
|
||||
# Enable chunked prefill by default for long context (> 32K)
|
||||
# models to avoid OOM errors in initial memory profiling phase.
|
||||
elif use_long_context:
|
||||
is_gpu = current_platform.is_cuda()
|
||||
use_sliding_window = (model_config.get_sliding_window()
|
||||
is not None)
|
||||
use_spec_decode = self.speculative_config is not None
|
||||
|
||||
if (is_gpu and not use_sliding_window and not use_spec_decode
|
||||
and not self.enable_lora):
|
||||
self.enable_chunked_prefill = True
|
||||
logger.warning(
|
||||
"Chunked prefill is enabled by default for models "
|
||||
"with max_model_len > 32K. Chunked prefill might "
|
||||
"not work with some features or models. If you "
|
||||
"encounter any issues, please disable by launching "
|
||||
"with --enable-chunked-prefill=False.")
|
||||
|
||||
if self.enable_chunked_prefill is None:
|
||||
self.enable_chunked_prefill = False
|
||||
|
||||
if not self.enable_chunked_prefill and use_long_context:
|
||||
logger.warning(
|
||||
"The model has a long context length (%s). This may cause"
|
||||
"OOM during the initial memory profiling phase, or result "
|
||||
"in low performance due to small KV cache size. Consider "
|
||||
"setting --max-model-len to a smaller value.", max_model_len)
|
||||
|
||||
# Disable prefix caching for multimodal models for VLLM_V0.
|
||||
if self.enable_prefix_caching and model_config.is_multimodal_model:
|
||||
logger.warning(
|
||||
"--enable-prefix-caching is not supported for multimodal "
|
||||
"models in V0 and has been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
if self.enable_prompt_embeds:
|
||||
logger.warning(
|
||||
"--enable-prompt-embeds and --enable-prefix-caching "
|
||||
"are not supported together in V0. Prefix caching has "
|
||||
"been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
# Set max_num_seqs to 256 for VLLM_V0.
|
||||
if self.max_num_seqs is None:
|
||||
self.max_num_seqs = 256
|
||||
|
||||
def _set_default_args_v1(self, usage_context: UsageContext,
|
||||
model_config: ModelConfig) -> None:
|
||||
def _set_default_args(self, usage_context: UsageContext,
|
||||
model_config: ModelConfig) -> None:
|
||||
"""Set Default Arguments for V1 Engine."""
|
||||
|
||||
# V1 always uses chunked prefills and prefix caching
|
||||
@ -1795,21 +1723,6 @@ def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
|
||||
logger.warning(msg)
|
||||
|
||||
|
||||
def _warn_or_fallback(feature_name: str) -> bool:
|
||||
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
|
||||
logger.warning(
|
||||
"Detected VLLM_USE_V1=1 with %s. Usage should "
|
||||
"be considered experimental. Please report any "
|
||||
"issues on Github.", feature_name)
|
||||
should_exit = False
|
||||
else:
|
||||
logger.info(
|
||||
"%s is experimental on VLLM_USE_V1=1. "
|
||||
"Falling back to V0 Engine.", feature_name)
|
||||
should_exit = True
|
||||
return should_exit
|
||||
|
||||
|
||||
def human_readable_int(value):
|
||||
"""Parse human-readable integers like '1k', '2M', etc.
|
||||
Including decimal values with decimal multipliers.
|
||||
|
||||
@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from openai.types.responses.tool import Mcp
|
||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||
|
||||
from vllm.entrypoints.harmony_utils import (
|
||||
@ -21,6 +22,24 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This is currently needed as the tool type doesn't 1:1 match the
|
||||
# tool namespace, which is what is used to look up the
|
||||
# connection to the tool server
|
||||
_TOOL_NAME_TO_TYPE_MAP = {
|
||||
"browser": "web_search_preview",
|
||||
"python": "code_interpreter",
|
||||
"container": "container",
|
||||
}
|
||||
|
||||
|
||||
def _map_tool_name_to_tool_type(tool_name: str) -> str:
|
||||
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
|
||||
available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys())
|
||||
raise ValueError(
|
||||
f"Built-in tool name '{tool_name}' not defined in mapping. "
|
||||
f"Available tools: {available_tools}")
|
||||
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
|
||||
|
||||
|
||||
class TurnTokens:
|
||||
"""Tracks token counts for a single conversation turn."""
|
||||
@ -59,8 +78,8 @@ class ConversationContext(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -96,8 +115,8 @@ class SimpleContext(ConversationContext):
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]) -> None:
|
||||
pass
|
||||
|
||||
async def cleanup_session(self) -> None:
|
||||
@ -318,13 +337,17 @@ class HarmonyContext(ConversationContext):
|
||||
]
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]):
|
||||
if tool_server:
|
||||
for tool_name in self.available_tools:
|
||||
if tool_name not in self._tool_sessions:
|
||||
tool_type = _map_tool_name_to_tool_type(tool_name)
|
||||
headers = mcp_tools[
|
||||
tool_type].headers if tool_type in mcp_tools else None
|
||||
tool_session = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name, request_id))
|
||||
tool_server.new_session(tool_name, request_id,
|
||||
headers))
|
||||
self._tool_sessions[tool_name] = tool_session
|
||||
exit_stack.push_async_exit(self.cleanup_session)
|
||||
|
||||
|
||||
@ -126,8 +126,10 @@ def get_developer_message(
|
||||
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
|
||||
for tool in tools:
|
||||
if tool.type in ("web_search_preview", "code_interpreter",
|
||||
"container"):
|
||||
"container", "mcp"):
|
||||
# These are built-in tools that are added to the system message.
|
||||
# Adding in MCP for now until we support MCP tools executed
|
||||
# server side
|
||||
pass
|
||||
|
||||
elif tool.type == "function":
|
||||
|
||||
@ -1468,7 +1468,7 @@ class LLM:
|
||||
|
||||
def _validate_and_add_requests(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
|
||||
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
|
||||
Sequence[PoolingParams]],
|
||||
*,
|
||||
@ -1478,7 +1478,7 @@ class LLM:
|
||||
) -> None:
|
||||
if isinstance(prompts, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
prompts = [prompts]
|
||||
prompts = [prompts] # type: ignore[list-item]
|
||||
|
||||
num_requests = len(prompts)
|
||||
if isinstance(params, Sequence) and len(params) != num_requests:
|
||||
|
||||
@ -460,8 +460,12 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
|
||||
async with AsyncExitStack() as exit_stack:
|
||||
try:
|
||||
mcp_tools = {
|
||||
tool.server_label: tool
|
||||
for tool in request.tools if tool.type == "mcp"
|
||||
}
|
||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
||||
request.request_id)
|
||||
request.request_id, mcp_tools)
|
||||
async for _ in result_generator:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
@ -748,11 +752,16 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# New conversation.
|
||||
reasoning_effort = (request.reasoning.effort
|
||||
if request.reasoning else None)
|
||||
# Temporary: OpenAI types doesn't have container tool
|
||||
# so we used MCP to cover that, up for change
|
||||
tool_types = [tool.type for tool in request.tools]
|
||||
if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
|
||||
tool_types.append("container")
|
||||
|
||||
# Allow the MCP Tool type to enable built in tools if the
|
||||
# server_label is allowlisted in
|
||||
# envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS
|
||||
if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
|
||||
for tool in request.tools:
|
||||
if (tool.type == "mcp" and tool.server_label
|
||||
in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS):
|
||||
tool_types.append(tool.server_label)
|
||||
enable_browser = ("web_search_preview" in tool_types
|
||||
and self.tool_server is not None
|
||||
and self.tool_server.has_tool("browser"))
|
||||
@ -1653,8 +1662,12 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
async with AsyncExitStack() as exit_stack:
|
||||
processer = None
|
||||
if self.use_harmony:
|
||||
mcp_tools = {
|
||||
tool.server_label: tool
|
||||
for tool in request.tools if tool.type == "mcp"
|
||||
}
|
||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
||||
request.request_id)
|
||||
request.request_id, mcp_tools)
|
||||
processer = self._process_harmony_streaming_events
|
||||
else:
|
||||
processer = self._process_simple_streaming_events
|
||||
|
||||
@ -20,6 +20,7 @@ from .openai_tool_parser import OpenAIToolParser
|
||||
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
||||
from .pythonic_tool_parser import PythonicToolParser
|
||||
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
||||
from .qwen3xml_tool_parser import Qwen3XMLToolParser
|
||||
from .seed_oss_tool_parser import SeedOssToolParser
|
||||
from .step3_tool_parser import Step3ToolParser
|
||||
from .xlam_tool_parser import xLAMToolParser
|
||||
@ -45,6 +46,7 @@ __all__ = [
|
||||
"HunyuanA13BToolParser",
|
||||
"Glm4MoeModelToolParser",
|
||||
"Qwen3CoderToolParser",
|
||||
"Qwen3XMLToolParser",
|
||||
"SeedOssToolParser",
|
||||
"Step3ToolParser",
|
||||
"OpenAIToolParser",
|
||||
|
||||
@ -368,16 +368,32 @@ class Hermes2ProToolParser(ToolParser):
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
# extract the content after {"name": ..., "arguments":
|
||||
# directly from tool_call_portion as cur_arguments_json,
|
||||
# since cur_arguments may differ from the original text
|
||||
# due to partial JSON parsing
|
||||
# for example, tool_call_portion =
|
||||
# {"name": "search", "arguments": {"search_request": {"
|
||||
# but cur_arguments =
|
||||
# {"search_request": {}}
|
||||
function_name = current_tool_call.get("name")
|
||||
match = re.search(
|
||||
r'\{"name":\s*"' +
|
||||
re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
|
||||
tool_call_portion.strip(), re.DOTALL)
|
||||
if match:
|
||||
cur_arguments_json = match.group(1)
|
||||
else:
|
||||
cur_arguments_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
|
||||
cur_arguments_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
logger.debug("finding %s in %s", delta_text,
|
||||
cur_arguments_json)
|
||||
|
||||
# get the location where previous args differ from current
|
||||
if (delta_text not in cur_arguments_json[:-2]):
|
||||
# get the location where previous args differ from current.
|
||||
if (delta_text not in cur_arguments_json):
|
||||
return None
|
||||
args_delta_start_loc = cur_arguments_json[:-2]. \
|
||||
args_delta_start_loc = cur_arguments_json. \
|
||||
rindex(delta_text) + \
|
||||
len(delta_text)
|
||||
|
||||
@ -397,8 +413,20 @@ class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
if isinstance(delta_text, str) and len(delta_text.rstrip(
|
||||
)) >= 1 and delta_text.rstrip()[-1] == '}':
|
||||
# judge whether the tool_call_portion is a complete JSON
|
||||
try:
|
||||
json.loads(tool_call_portion)
|
||||
is_complete_json = True
|
||||
except Exception:
|
||||
is_complete_json = False
|
||||
|
||||
# if the delta_text ends with a '}' and tool_call_portion is a
|
||||
# complete JSON, then the last '}' does not belong to the
|
||||
# arguments, so we should trim it off
|
||||
if isinstance(delta_text, str) \
|
||||
and len(delta_text.rstrip()) >= 1 \
|
||||
and delta_text.rstrip()[-1] == '}' \
|
||||
and is_complete_json:
|
||||
delta_text = delta_text.rstrip()[:-1]
|
||||
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
1137
vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py
Normal file
1137
vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -280,7 +280,7 @@ class CompletionRenderer(BaseRenderer):
|
||||
if truncate_prompt_tokens < 0:
|
||||
truncate_prompt_tokens = self.model_config.max_model_len
|
||||
|
||||
if max_length is not None and truncate_prompt_tokens > max_length:
|
||||
if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator]
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
|
||||
f"cannot be greater than max_length ({max_length}). "
|
||||
|
||||
@ -18,7 +18,6 @@ if TYPE_CHECKING:
|
||||
async def list_server_and_tools(server_url: str):
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
async with sse_client(url=server_url) as streams, ClientSession(
|
||||
*streams) as session:
|
||||
initialize_response = await session.initialize()
|
||||
@ -86,8 +85,12 @@ class ToolServer(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def new_session(self, tool_name: str,
|
||||
session_id: str) -> AbstractAsyncContextManager[Any]:
|
||||
def new_session(
|
||||
self,
|
||||
tool_name: str,
|
||||
session_id: str,
|
||||
headers: Optional[dict[str, str]] = None
|
||||
) -> AbstractAsyncContextManager[Any]:
|
||||
"""
|
||||
Create a session for the tool.
|
||||
"""
|
||||
@ -144,16 +147,21 @@ class MCPToolServer(ToolServer):
|
||||
return self.harmony_tool_descriptions.get(tool_name)
|
||||
|
||||
@asynccontextmanager
|
||||
async def new_session(self, tool_name: str, session_id: str):
|
||||
async def new_session(self,
|
||||
tool_name: str,
|
||||
session_id: str,
|
||||
headers: Optional[dict[str, str]] = None):
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
url = self.urls.get(tool_name)
|
||||
headers = {"x-session-id": session_id}
|
||||
request_headers = {"x-session-id": session_id}
|
||||
if headers is not None:
|
||||
request_headers.update(headers)
|
||||
if not url:
|
||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||
async with sse_client(url=url,
|
||||
headers=headers) as streams, ClientSession(
|
||||
*streams) as session:
|
||||
async with sse_client(
|
||||
url=url, headers=request_headers) as streams, ClientSession(
|
||||
*streams) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
|
||||
@ -189,7 +197,10 @@ class DemoToolServer(ToolServer):
|
||||
raise ValueError(f"Unknown tool {tool_name}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def new_session(self, tool_name: str, session_id: str):
|
||||
async def new_session(self,
|
||||
tool_name: str,
|
||||
session_id: str,
|
||||
headers: Optional[dict[str, str]] = None):
|
||||
if tool_name not in self.tools:
|
||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||
yield self.tools[tool_name]
|
||||
|
||||
85
vllm/envs.py
85
vllm/envs.py
@ -119,12 +119,14 @@ if TYPE_CHECKING:
|
||||
VLLM_SERVER_DEV_MODE: bool = False
|
||||
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
|
||||
VLLM_MLA_DISABLE: bool = False
|
||||
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
|
||||
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
|
||||
VLLM_RAY_BUNDLE_INDICES: str = ""
|
||||
VLLM_CUDART_SO_PATH: Optional[str] = None
|
||||
VLLM_DP_RANK: int = 0
|
||||
VLLM_DP_RANK_LOCAL: int = -1
|
||||
VLLM_DP_SIZE: int = 1
|
||||
VLLM_USE_STANDALONE_COMPILE: bool = False
|
||||
VLLM_DP_MASTER_IP: str = ""
|
||||
VLLM_DP_MASTER_PORT: int = 0
|
||||
VLLM_MOE_DP_CHUNK_SIZE: int = 256
|
||||
@ -183,11 +185,12 @@ if TYPE_CHECKING:
|
||||
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
|
||||
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
|
||||
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
||||
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
|
||||
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
||||
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
||||
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
||||
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
|
||||
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
|
||||
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -258,6 +261,58 @@ def env_with_choices(
|
||||
return _get_validated_env
|
||||
|
||||
|
||||
def env_list_with_choices(
|
||||
env_name: str,
|
||||
default: list[str],
|
||||
choices: Union[list[str], Callable[[], list[str]]],
|
||||
case_sensitive: bool = True) -> Callable[[], list[str]]:
|
||||
"""
|
||||
Create a lambda that validates environment variable
|
||||
containing comma-separated values against allowed choices
|
||||
|
||||
Args:
|
||||
env_name: Name of the environment variable
|
||||
default: Default list of values if not set
|
||||
choices: List of valid string options or callable that returns list
|
||||
case_sensitive: Whether validation should be case sensitive
|
||||
|
||||
Returns:
|
||||
Lambda function for environment_variables
|
||||
dict that returns list of strings
|
||||
"""
|
||||
|
||||
def _get_validated_env_list() -> list[str]:
|
||||
value = os.getenv(env_name)
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
# Split comma-separated values and strip whitespace
|
||||
values = [v.strip() for v in value.split(",") if v.strip()]
|
||||
|
||||
if not values:
|
||||
return default
|
||||
|
||||
# Resolve choices if it's a callable (for lazy loading)
|
||||
actual_choices = choices() if callable(choices) else choices
|
||||
|
||||
# Validate each value
|
||||
for val in values:
|
||||
if not case_sensitive:
|
||||
check_value = val.lower()
|
||||
check_choices = [choice.lower() for choice in actual_choices]
|
||||
else:
|
||||
check_value = val
|
||||
check_choices = actual_choices
|
||||
|
||||
if check_value not in check_choices:
|
||||
raise ValueError(f"Invalid value '{val}' in {env_name}. "
|
||||
f"Valid options: {actual_choices}.")
|
||||
|
||||
return values
|
||||
|
||||
return _get_validated_env_list
|
||||
|
||||
|
||||
def get_vllm_port() -> Optional[int]:
|
||||
"""Get the port from VLLM_PORT environment variable.
|
||||
|
||||
@ -436,9 +491,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
|
||||
# Feature flag to enable/disable Inductor standalone compile.
|
||||
# In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is
|
||||
# enabled by default.
|
||||
# disabled by default.
|
||||
"VLLM_USE_STANDALONE_COMPILE":
|
||||
lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1",
|
||||
lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "0") == "1",
|
||||
|
||||
# Debug pattern matching inside custom passes.
|
||||
# Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3').
|
||||
"VLLM_PATTERN_MATCH_DEBUG":
|
||||
lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None),
|
||||
|
||||
# local rank of the process in the distributed setting, used to determine
|
||||
# the GPU device id
|
||||
@ -946,6 +1006,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_MLA_DISABLE":
|
||||
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
|
||||
|
||||
# If set, vLLM will pick up the provided Flash Attention MLA
|
||||
# max number splits for cuda graph decode
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
|
||||
lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
|
||||
"16")),
|
||||
|
||||
# Number of GPUs per worker in Ray, if it is set to be a fraction,
|
||||
# it allows ray to schedule multiple actors on a single GPU,
|
||||
# so that users can colocate other actors on the same GPUs as vLLM.
|
||||
@ -1306,10 +1372,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_TUNED_CONFIG_FOLDER":
|
||||
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
|
||||
|
||||
# Allows vllm use container tool
|
||||
"VLLM_GPT_OSS_USE_CONTAINER_TOOL":
|
||||
lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))),
|
||||
|
||||
# Allows harmony instructions to be injected on system messages
|
||||
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
|
||||
lambda: bool(
|
||||
@ -1329,6 +1391,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME":
|
||||
lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
|
||||
"VLLM_OBJECT_STORAGE_SHM_BUFFER"),
|
||||
|
||||
# Valid values are container,code_interpreter,web_search_preview
|
||||
# ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
|
||||
"GPT_OSS_SYSTEM_TOOL_MCP_LABELS":
|
||||
env_list_with_choices("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", [],
|
||||
["container",
|
||||
"code_interpreter",
|
||||
"web_search_preview"]),
|
||||
}
|
||||
|
||||
# --8<-- [end:env-vars-definition]
|
||||
@ -1379,6 +1449,7 @@ def compute_hash() -> str:
|
||||
environment_variables_to_hash = [
|
||||
"VLLM_PP_LAYER_PARTITION",
|
||||
"VLLM_MLA_DISABLE",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
|
||||
"VLLM_USE_TRITON_FLASH_ATTN",
|
||||
"VLLM_USE_TRITON_AWQ",
|
||||
"VLLM_DP_RANK",
|
||||
|
||||
@ -121,18 +121,18 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
|
||||
lora_bias = self.slice_bias(lora_bias)
|
||||
|
||||
self.lora_a_stacked[0][index,
|
||||
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||
lora_a.T, non_blocking=True)
|
||||
0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
|
||||
lora_a, non_blocking=True)
|
||||
self.lora_b_stacked[0][index,
|
||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
|
||||
lora_b, non_blocking=True)
|
||||
if lora_bias is not None:
|
||||
|
||||
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
|
||||
self.lora_bias_stacked)
|
||||
assert len(self.lora_bias_stacked)
|
||||
self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
|
||||
lora_bias.T, non_blocking=True)
|
||||
lora_bias, non_blocking=True)
|
||||
|
||||
def apply(self,
|
||||
x: torch.Tensor,
|
||||
|
||||
@ -99,13 +99,13 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
if self.is_merged_col_linear:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.output_size // 2
|
||||
offset = lora_b.shape[-1] // 2
|
||||
offset = lora_b.shape[0] // 2
|
||||
|
||||
left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
|
||||
shard_size]
|
||||
right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
|
||||
(tp_rank + 1) * shard_size]
|
||||
lora_b = torch.cat([left_weight, right_weight], dim=1)
|
||||
left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
|
||||
shard_size, :]
|
||||
right_weight = lora_b[offset + tp_rank * shard_size:offset +
|
||||
(tp_rank + 1) * shard_size, :]
|
||||
lora_b = torch.cat([left_weight, right_weight], dim=0)
|
||||
# Applicable to cases where the base_layer is
|
||||
# ColumnParallelLinear.
|
||||
else:
|
||||
@ -113,7 +113,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
shard_size = self.output_size
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
lora_b = lora_b[:, start_idx:end_idx]
|
||||
lora_b = lora_b[start_idx:end_idx, :]
|
||||
return lora_b
|
||||
|
||||
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
|
||||
@ -251,9 +251,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
for i, (shard_id, shard_size) in enumerate(
|
||||
zip(self.output_ids, self.output_slices)):
|
||||
if (lora_b_i := lora_b[i]) is not None:
|
||||
sliced_lora_b[i] = lora_b_i[:,
|
||||
shard_size * shard_id:shard_size *
|
||||
(shard_id + 1)]
|
||||
sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size *
|
||||
(shard_id + 1), :]
|
||||
return sliced_lora_b
|
||||
|
||||
def slice_bias(
|
||||
@ -285,12 +284,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
for i in range(self.n_slices):
|
||||
if (lora_a_i := lora_a[i]) is not None:
|
||||
self.lora_a_stacked[i][
|
||||
index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
|
||||
lora_a_i.T, non_blocking=True)
|
||||
index, 0, :lora_a_i.shape[0], :lora_a_i.shape[1]].copy_(
|
||||
lora_a_i, non_blocking=True)
|
||||
if (lora_b_i := lora_b[i]) is not None:
|
||||
self.lora_b_stacked[i][
|
||||
index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
|
||||
lora_b_i.T, non_blocking=True)
|
||||
index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_(
|
||||
lora_b_i, non_blocking=True)
|
||||
|
||||
if lora_bias is not None:
|
||||
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
|
||||
@ -299,7 +298,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
if (lora_bias_i := lora_bias[i]) is not None:
|
||||
self.lora_bias_stacked[i][index,
|
||||
0, :lora_bias_i.shape[0]].copy_(
|
||||
lora_bias_i.T,
|
||||
lora_bias_i,
|
||||
non_blocking=True)
|
||||
|
||||
@classmethod
|
||||
@ -345,18 +344,18 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
self.q_shard_id = tp_rank
|
||||
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
|
||||
lora_b_q = lora_b[:, self.q_proj_shard_size *
|
||||
lora_b_q = lora_b[self.q_proj_shard_size *
|
||||
self.q_shard_id:self.q_proj_shard_size *
|
||||
(self.q_shard_id + 1)]
|
||||
(self.q_shard_id + 1), :]
|
||||
k_offset = self.q_proj_total_size
|
||||
lora_b_k = lora_b[:, k_offset +
|
||||
lora_b_k = lora_b[k_offset +
|
||||
self.kv_proj_shard_size * self.kv_shard_id:k_offset +
|
||||
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
|
||||
self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
|
||||
v_offset = k_offset + self.kv_proj_total_size
|
||||
lora_b_v = lora_b[:, v_offset +
|
||||
lora_b_v = lora_b[v_offset +
|
||||
self.kv_proj_shard_size * self.kv_shard_id:v_offset +
|
||||
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
|
||||
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
|
||||
self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
|
||||
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
|
||||
return lora_b
|
||||
|
||||
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
|
||||
@ -465,7 +464,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.lora_a_stacked[0].shape[2]
|
||||
start_idx = tp_rank * shard_size
|
||||
lora_a = lora_a[:, start_idx:start_idx + shard_size]
|
||||
lora_a = lora_a[start_idx:start_idx + shard_size, :]
|
||||
return lora_a
|
||||
|
||||
def apply(self,
|
||||
@ -508,10 +507,10 @@ class MergedColumnParallelLinearWithShardedLoRA(
|
||||
output_shard_size = self.lora_a_stacked[0].shape[2]
|
||||
output_start_idx = self.tp_rank * output_shard_size
|
||||
lora_a = [
|
||||
lora_a[0][:, output_start_idx:output_start_idx +
|
||||
output_shard_size] if lora_a[0] is not None else None,
|
||||
lora_a[1][:, output_start_idx:output_start_idx +
|
||||
output_shard_size] if lora_a[1] is not None else None,
|
||||
lora_a[0][output_start_idx:output_start_idx +
|
||||
output_shard_size, :] if lora_a[0] is not None else None,
|
||||
lora_a[1][output_start_idx:output_start_idx +
|
||||
output_shard_size, :] if lora_a[1] is not None else None,
|
||||
]
|
||||
return lora_a
|
||||
|
||||
@ -551,7 +550,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.lora_a_stacked[0].shape[2]
|
||||
start_idx = tp_rank * shard_size
|
||||
lora_a = lora_a[:, start_idx:start_idx + shard_size]
|
||||
lora_a = lora_a[start_idx:start_idx + shard_size, :]
|
||||
return lora_a
|
||||
|
||||
def apply(self,
|
||||
@ -589,12 +588,12 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
|
||||
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
|
||||
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
|
||||
lora_a = [
|
||||
lora_a[0][:, start_idx[0]:start_idx[0] +
|
||||
shard_size[0]] if lora_a[0] is not None else None,
|
||||
lora_a[1][:, start_idx[1]:start_idx[1] +
|
||||
shard_size[1]] if lora_a[1] is not None else None,
|
||||
lora_a[2][:, start_idx[2]:start_idx[2] +
|
||||
shard_size[2]] if lora_a[2] is not None else None,
|
||||
lora_a[0][start_idx[0]:start_idx[0] +
|
||||
shard_size[0], :] if lora_a[0] is not None else None,
|
||||
lora_a[1][start_idx[1]:start_idx[1] +
|
||||
shard_size[1], :] if lora_a[1] is not None else None,
|
||||
lora_a[2][start_idx[2]:start_idx[2] +
|
||||
shard_size[2], :] if lora_a[2] is not None else None,
|
||||
]
|
||||
return lora_a
|
||||
|
||||
|
||||
@ -140,11 +140,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
):
|
||||
self.reset_lora(index)
|
||||
self.lora_a_stacked[index,
|
||||
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||
lora_a.T, non_blocking=True)
|
||||
0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
|
||||
lora_a, non_blocking=True)
|
||||
self.lora_b_stacked[index,
|
||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
|
||||
lora_b, non_blocking=True)
|
||||
if embeddings_tensor is not None:
|
||||
self.embeddings_tensors[
|
||||
index,
|
||||
|
||||
@ -39,7 +39,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
shard_size = self.input_size
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
lora_a = lora_a[start_idx:end_idx, :]
|
||||
lora_a = lora_a[:,start_idx:end_idx]
|
||||
return lora_a
|
||||
|
||||
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
|
||||
@ -122,7 +122,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
|
||||
shard_size = self.lora_b_stacked[0].shape[2]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
lora_b = lora_b[:, start_idx:end_idx]
|
||||
lora_b = lora_b[ start_idx:end_idx,:]
|
||||
return lora_b
|
||||
|
||||
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -95,11 +95,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
):
|
||||
self.reset_lora(index)
|
||||
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
|
||||
lora_a, non_blocking=True)
|
||||
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
|
||||
# so we need transpose here
|
||||
self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||
lora_a.T, non_blocking=True)
|
||||
self.lora_b_stacked[index,
|
||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
|
||||
lora_b, non_blocking=True)
|
||||
if embeddings_tensor is not None:
|
||||
self.embeddings_tensors[
|
||||
index,
|
||||
|
||||
@ -86,11 +86,11 @@ class LoRALayerWeights:
|
||||
embeddings_tensor_dim: Optional[int] = None,
|
||||
bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
|
||||
pin_memory = str(device) == "cpu" and is_pin_memory_available()
|
||||
lora_a = torch.zeros([input_dim, rank],
|
||||
lora_a = torch.zeros([rank, input_dim],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
pin_memory=pin_memory)
|
||||
lora_b = torch.zeros([rank, output_dim],
|
||||
lora_b = torch.zeros([output_dim, rank],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
pin_memory=pin_memory)
|
||||
|
||||
@ -152,30 +152,29 @@ class LoRAModel:
|
||||
module_name, peft_helper, lora_embeddings_tensor)
|
||||
|
||||
if is_bias:
|
||||
loras[module_name].bias = tensor.to(device=device,
|
||||
dtype=dtype).t()
|
||||
bias = tensor.to(device=device, dtype=dtype).t()
|
||||
loras[module_name].bias = tensor.to(device=device, dtype=dtype)
|
||||
bias = tensor.to(device=device, dtype=dtype)
|
||||
if pin_memory:
|
||||
bias = bias.pin_memory()
|
||||
loras[module_name].bias = bias
|
||||
elif is_lora_a:
|
||||
loras[module_name].lora_a = tensor.to(device=device,
|
||||
dtype=dtype).t()
|
||||
dtype=dtype)
|
||||
if pin_memory:
|
||||
loras[module_name].lora_a = loras[
|
||||
module_name].lora_a.pin_memory()
|
||||
else:
|
||||
loras[module_name].lora_b = tensor.to(device=device,
|
||||
dtype=dtype).t()
|
||||
dtype=dtype)
|
||||
assert embedding_padding_modules is not None
|
||||
if any(name in module_name
|
||||
for name in embedding_padding_modules
|
||||
) and target_embedding_padding is not None:
|
||||
lora_b = loras[module_name].lora_b
|
||||
assert target_embedding_padding >= lora_b.shape[1]
|
||||
addition = target_embedding_padding - lora_b.shape[1]
|
||||
assert target_embedding_padding >= lora_b.shape[0]
|
||||
addition = target_embedding_padding - lora_b.shape[0]
|
||||
loras[module_name].lora_b = torch.nn.functional.pad(
|
||||
lora_b, (0, addition))
|
||||
lora_b, (0, 0, 0, addition))
|
||||
if pin_memory:
|
||||
loras[module_name].lora_b = loras[
|
||||
module_name].lora_b.pin_memory()
|
||||
@ -585,7 +584,6 @@ class LoRAModelManager:
|
||||
"cpu",
|
||||
bias_enabled=bias_enabled,
|
||||
)
|
||||
lora.optimize()
|
||||
else:
|
||||
parts = module_name.split(".")
|
||||
replacements = self.packed_modules_mapping[parts[-1]]
|
||||
@ -600,7 +598,6 @@ class LoRAModelManager:
|
||||
"cpu",
|
||||
bias_enabled=bias_enabled,
|
||||
)
|
||||
lora.optimize()
|
||||
subloras.append(lora)
|
||||
lora = PackedLoRALayerWeights.pack(subloras)
|
||||
model.loras[module_name] = lora
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user