Merge branch 'main' into woosuk/model-runner-v2

Woosuk Kwon 2025-09-23 09:22:58 -07:00
commit 17c2c106b1
192 changed files with 7964 additions and 4624 deletions

View File

@ -62,7 +62,7 @@ echo "--- Installing Python dependencies ---"
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
&& python3 -m pip install --progress-bar off hf-transfer
&& python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1
export VLLM_XLA_CHECK_RECOMPILATION=1

View File

@ -62,7 +62,7 @@ echo "--- Installing Python dependencies ---"
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
&& python3 -m pip install --progress-bar off hf-transfer
&& python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1
export VLLM_XLA_CHECK_RECOMPILATION=1

View File

@ -165,10 +165,18 @@ steps:
- tests/v1/test_hybrid_lb_dp.py
- tests/v1/engine/test_engine_core_client.py
commands:
# test with tp=2 and external_dp=2
# test with torchrun tp=2 and external_dp=2
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with tp=2 and pp=2
# test with torchrun tp=2 and pp=2
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with torchrun tp=4 and dp=1
- TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
# test with torchrun tp=2, pp=2 and dp=1
- PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
# test with torchrun tp=1 and dp=4 with ep
- DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
# test with torchrun tp=2 and dp=2 with ep
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py

.github/CODEOWNERS
View File

@ -72,6 +72,7 @@ mkdocs.yaml @hmellor
# Linting
.markdownlint.yaml @hmellor
.pre-commit-config.yaml @hmellor
/tools/pre_commit @hmellor
# CPU
/vllm/v1/worker/cpu* @bigPYJ1151

View File

@ -43,10 +43,6 @@ body:
Any other things you would like to mention.
validations:
required: false
- type: markdown
attributes:
value: >
Thanks for contributing 🎉! The vLLM core team hosts a biweekly RFC review session at 9:30AM Pacific Time, while most RFCs can be discussed online, you can optionally sign up for a slot to discuss your RFC online [here](https://docs.google.com/document/d/1CiLVBZeIVfR7_PNAKVSusxpceywkoOOB78qoWqHvSZc/edit).
- type: checkboxes
id: askllm
attributes:

View File

@ -60,38 +60,32 @@ repos:
files: ^requirements/test\.(in|txt)$
- id: mypy-local
name: Run mypy for local Python installation
entry: tools/mypy.sh 0 "local"
language: python
types: [python]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic]
entry: python tools/pre_commit/mypy.py 0 "local"
stages: [pre-commit] # Don't run in CI
<<: &mypy_common
language: python
types_or: [python, pyi]
require_serial: true
additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.9
entry: tools/mypy.sh 1 "3.9"
language: python
types: [python]
additional_dependencies: *mypy_deps
entry: python tools/pre_commit/mypy.py 1 "3.9"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.10
entry: tools/mypy.sh 1 "3.10"
language: python
types: [python]
additional_dependencies: *mypy_deps
entry: python tools/pre_commit/mypy.py 1 "3.10"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.11
entry: tools/mypy.sh 1 "3.11"
language: python
types: [python]
additional_dependencies: *mypy_deps
entry: python tools/pre_commit/mypy.py 1 "3.11"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.12
entry: tools/mypy.sh 1 "3.12"
language: python
types: [python]
additional_dependencies: *mypy_deps
entry: python tools/pre_commit/mypy.py 1 "3.12"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: shellcheck
name: Lint shell scripts
@ -155,11 +149,10 @@ repos:
additional_dependencies: [regex]
- id: check-pickle-imports
name: Prevent new pickle/cloudpickle imports
entry: python tools/check_pickle_imports.py
entry: python tools/pre_commit/check_pickle_imports.py
language: python
types: [python]
pass_filenames: false
additional_dependencies: [pathspec, regex]
additional_dependencies: [regex]
- id: validate-config
name: Validate configuration has default values and that each field has a docstring
entry: python tools/validate_config.py

View File

@ -680,7 +680,7 @@ vllm bench serve \
--save-result \
--result-dir ~/vllm_benchmark_results \
--save-detailed \
--endpoint /v1/chat/completion
--endpoint /v1/chat/completions
```
##### Videos (ShareGPT4Video)
@ -707,7 +707,7 @@ vllm bench serve \
--save-result \
--result-dir ~/vllm_benchmark_results \
--save-detailed \
--endpoint /v1/chat/completion
--endpoint /v1/chat/completions
```
##### Synthetic Random Images (random-mm)

View File

@ -23,7 +23,7 @@ Now supports 5 types of connectors:
- **SharedStorageConnector**: refer to <gh-file:examples/offline_inference/disaggregated-prefill-v1/run.sh> for the example usage of SharedStorageConnector disaggregated prefilling.
- **LMCacheConnectorV1**: refer to <gh-file:examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh> for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
- **NixlConnector**: refer to <gh-file:tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh> for the example usage of NixlConnector disaggregated prefilling, which supports fully async send/recv.
- **NixlConnector**: refer to <gh-file:tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh> for the example usage of NixlConnector disaggregated prefilling, which supports fully async send/recv. For a detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md).
- **P2pNcclConnector**: refer to <gh-file:examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh> for the example usage of P2pNcclConnector disaggregated prefilling.
- **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs, such as:
@ -31,6 +31,18 @@ Now supports 5 types of connectors:
--kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
```
For NixlConnector, you may also specify one or more NIXL backends, such as:
```bash
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'
```
- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):
```bash
--kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 64, "num_cpu_blocks": 1000}}'
```
## Benchmarks
Please refer to <gh-file:benchmarks/disagg_benchmarks> for disaggregated prefilling benchmarks.

View File

@ -0,0 +1,159 @@
# NixlConnector Usage Guide
NixlConnector is a high-performance KV cache transfer connector for vLLM's disaggregated prefilling feature. It provides fully asynchronous send/receive operations using the NIXL library for efficient cross-process KV cache transfer.
## Prerequisites
### Installation
As a quick start, install the NIXL library with `uv pip install nixl`.
- Refer to the [NIXL official repository](https://github.com/ai-dynamo/nixl) for more installation instructions
- The required NIXL version is specified in [requirements/kv_connectors.txt](../../requirements/kv_connectors.txt) and other relevant config files
### Transport Configuration
NixlConnector uses the NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the default transport library used by NIXL. Configure the transport environment variables:
```bash
# Example UCX configuration; adjust according to your environment
export UCX_TLS=all # or specify particular transports like "rc,ud,sm,^cuda_ipc", etc.
export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1"
```
!!! tip
When using UCX as the transport backend, NCCL environment variables (like `NCCL_IB_HCA`, `NCCL_SOCKET_IFNAME`) are not applicable to NixlConnector, so configure UCX-specific environment variables instead of NCCL variables.
## Basic Usage (on the same host)
### Producer (Prefiller) Configuration
Start a prefiller instance that produces KV caches:
```bash
# 1st GPU as prefiller
CUDA_VISIBLE_DEVICES=0 \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
vllm serve Qwen/Qwen3-0.6B \
--port 8100 \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
```
### Consumer (Decoder) Configuration
Start a decoder instance that consumes KV caches:
```bash
# 2nd GPU as decoder
CUDA_VISIBLE_DEVICES=1 \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \
vllm serve Qwen/Qwen3-0.6B \
--port 8200 \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
```
### Proxy Server
Use a proxy server to route requests between prefiller and decoder:
```bash
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
--port 8192 \
--prefiller-hosts localhost \
--prefiller-ports 8100 \
--decoder-hosts localhost \
--decoder-ports 8200
```
## Environment Variables
- `VLLM_NIXL_SIDE_CHANNEL_PORT`: Port for NIXL handshake communication
- Default: 5600
- **Required for both prefiller and decoder instances**
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node); see the sketch after this list.
- Used for the initial NIXL handshake between the prefiller and the decoder
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
- Default: "localhost"
- Set when prefiller and decoder are on different machines
- Connection info is passed via KVTransferParams from prefiller to decoder for handshake
- `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller's KV cache for a particular request. (Optional)
- Default: 120
- If a request is aborted and the decoder has not yet read the KV-cache blocks through the nixl channel, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
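For illustration, here is a minimal sketch of the per-worker port assignment described above (the helper name is hypothetical, not part of vLLM's API):
```python
# Hypothetical helper illustrating the side-channel port formula:
# port = base_port + dp_rank * tp_size + tp_rank
def side_channel_port(base_port: int, dp_rank: int, tp_size: int, tp_rank: int) -> int:
    return base_port + dp_rank * tp_size + tp_rank

# e.g. base_port=5600, tp_size=4, dp_rank=0 -> tp_rank 0..3 map to 5600..5603
print([side_channel_port(5600, 0, 4, tp_rank) for tp_rank in range(4)])
```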
## Multi-Instance Setup
### Multiple Prefiller Instances on Different Machines
```bash
# Prefiller 1 on Machine A (example IP: ${IP1})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP1} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
--tensor-parallel-size 8 \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}'
# Prefiller 2 on Machine B (example IP: ${IP2})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP2} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
--tensor-parallel-size 8 \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}'
```
### Multiple Decoder Instances on Different Machines
```bash
# Decoder 1 on Machine C (example IP: ${IP3})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP3} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
--tensor-parallel-size 8 \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}'
# Decoder 2 on Machine D (example IP: ${IP4})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP4} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
--tensor-parallel-size 8 \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}'
```
### Proxy for Multiple Instances
```bash
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
--port 8192 \
--prefiller-hosts ${IP1} ${IP2} \
--prefiller-ports 8000 8000 \
--decoder-hosts ${IP3} ${IP4} \
--decoder-ports 8000 8000
```
### KV Role Options
- **kv_producer**: For prefiller instances that generate KV caches
- **kv_consumer**: For decoder instances that consume KV caches from the prefiller
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
!!! tip
NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`).
Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
## Example Scripts/Code
Refer to these example scripts in the vLLM repository:
- [run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh)
- [toy_proxy_server.py](../../tests/v1/kv_connector/nixl_integration/toy_proxy_server.py)
- [test_accuracy.py](../../tests/v1/kv_connector/nixl_integration/test_accuracy.py)

View File

@ -319,6 +319,15 @@ Supported models:
Flags: `--tool-call-parser glm45`
### Qwen3-Coder Models (`qwen3_xml`)
Supported models:
* `Qwen/Qwen3-Coder-480B-A35B-Instruct`
* `Qwen/Qwen3-Coder-30B-A3B-Instruct`
Flags: `--tool-call-parser qwen3_xml`
### Models with Pythonic Tool Calls (`pythonic`)
A growing number of models output a Python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
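For instance, a hypothetical raw model output in this pythonic format might look like the sketch below (the tool names and arguments are made up for illustration; the parser converts the Python-list syntax into structured tool calls):
```python
# Hypothetical raw output from a model using the pythonic tool-call format;
# tool names and arguments are illustrative only.
model_output = '[get_weather(city="Boston", unit="celsius"), get_time(timezone="US/Eastern")]'
```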

View File

@ -352,6 +352,7 @@ th {
| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ |
| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | | ✅︎ | ✅︎ |
| `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ |
| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |

View File

@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok
1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`
2. **Configure Both Instances**: Add this flag to both prefill and decode instances: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`. Note that you may also specify one or more NIXL backends, such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}}'`
3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.

View File

@ -101,6 +101,13 @@ def parse_args():
"--quantization",
type=str,
)
parser.add_argument(
"--disable-expert-parallel",
dest="enable_expert_parallel",
action="store_false",
help="Disable expert parallel (default: enabled).",
)
parser.set_defaults(enable_expert_parallel=True)
return parser.parse_args()
@ -113,6 +120,7 @@ def main(
dp_master_port,
GPUs_per_dp_rank,
enforce_eager,
enable_expert_parallel,
trust_remote_code,
max_num_seqs,
max_model_len,
@ -168,7 +176,7 @@ def main(
model=model,
tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=enforce_eager,
enable_expert_parallel=True,
enable_expert_parallel=enable_expert_parallel,
trust_remote_code=trust_remote_code,
max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
@ -229,6 +237,7 @@ if __name__ == "__main__":
dp_master_port,
tp_size,
args.enforce_eager,
args.enable_expert_parallel,
args.trust_remote_code,
args.max_num_seqs,
args.max_model_len,

View File

@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Experimental support for data-parallel inference with torchrun.
Note that data load balancing and distribution are handled outside the vLLM engine;
no internal load balancing is supported in external_launcher mode.
"""
from vllm import LLM, SamplingParams
# Create prompts, the same across all ranks
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
] * 50
# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
# it is important to set an explicit seed to make sure that
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm = LLM(
model="microsoft/Phi-mini-MoE-instruct",
tensor_parallel_size=1,
data_parallel_size=2,
pipeline_parallel_size=1,
enable_expert_parallel=False,
distributed_executor_backend="external_launcher",
max_model_len=4096,
gpu_memory_utilization=0.6,
seed=1,
)
dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size
prompts = [
f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank
]
outputs = llm.generate(prompts, sampling_params)
# all ranks will have the same outputs
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
print("-" * 50)
"""
Further tips:
1. to communicate control messages across all ranks, use the cpu group,
a PyTorch ProcessGroup with GLOO backend.
```python
import torch.distributed as dist
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
# do something for rank 0, e.g. saving the results to disk.
```
2. to communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with NCCL backend.
```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```
3. to access the model directly in every rank, use the following code:
```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""

View File

@ -126,6 +126,23 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
)
# Dots-OCR
def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions]
engine_args = EngineArgs(
model="rednote-hilab/dots.ocr",
limit_mm_per_prompt={modality: 1},
trust_remote_code=True,
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@ -1676,6 +1693,7 @@ model_example_map = {
"aya_vision": run_aya_vision,
"blip-2": run_blip2,
"chameleon": run_chameleon,
"dots_ocr": run_dots_ocr,
"command_a_vision": run_command_a_vision,
"deepseek_vl_v2": run_deepseek_vl2,
"ernie45_vl": run_ernie45_vl,

View File

@ -110,27 +110,6 @@ ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "silent"
# After fixing type errors resulting from follow_imports: "skip" -> "silent",
# move the directory here and remove it from tools/mypy.sh
files = [
"vllm/*.py",
"vllm/assets",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging_utils",
"vllm/multimodal",
"vllm/platforms",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
# Ignore triton kernels in ops.
'vllm/attention/ops/.*\.py$'
]
[tool.isort]
skip_glob = [
".buildkite/*",

View File

@ -14,14 +14,4 @@ nixl==0.3.0
tpu_info==0.4.0
# Install torch_xla
--pre
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.9.0.dev20250730
torchvision==0.24.0.dev20250730
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp312-cp312-linux_x86_64.whl ; python_version == "3.12"
torch_xla[tpu, pallas]==2.8.0

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import weakref
from collections.abc import Sequence
from copy import deepcopy
from typing import Callable, Union
@ -10,7 +11,26 @@ from torch._ops import OpOverload
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config
from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config
class LazyInitPass(InductorPass):
"""
If there's a pass that we want to initialize lazily in a test,
we can wrap it in LazyInitPass, which will initialize the pass when invoked
and then immediately invoke it.
"""
def __init__(self, pass_cls: type[VllmInductorPass],
vllm_config: VllmConfig):
self.pass_cls = pass_cls
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
def __call__(self, graph: fx.Graph) -> None:
self.pass_ = self.pass_cls(self.vllm_config)
self.pass_(graph)
class TestBackend:
@ -40,10 +60,16 @@ class TestBackend:
example_inputs,
config_patches=self.inductor_config)
@with_pattern_match_debug
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
VllmInductorPass.dump_prefix = 0
for pass_ in self.custom_passes:
pass_(graph)
VllmInductorPass.dump_prefix += 1
VllmInductorPass.dump_prefix = None
self.graph_post_pass = deepcopy(graph)
# assign by reference, will reflect the final state of the graph

View File

@ -46,7 +46,10 @@ backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
env_vars={
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
},
@ -66,6 +69,7 @@ backend_configs = {
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
@ -89,7 +93,10 @@ backend_configs = {
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
env_vars={
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
}),

View File

@ -294,6 +294,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
assert async_tp_pass.matched_count == 1
# In pre-nodes, all gather or reduce scatter should exist,
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)

View File

@ -4,7 +4,7 @@ import pytest
import vllm
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config import CompilationConfig, VllmConfig
from vllm.utils import _is_torch_equal_or_newer
@ -26,6 +26,14 @@ def test_use_cudagraphs_dynamic(monkeypatch):
assert not vllm_config.compilation_config.use_cudagraph
def test_custom_op():
# proper syntax
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
with pytest.raises(ValueError, match="Invalid syntax '"):
_ = CompilationConfig(custom_ops=["quant_fp8"])
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends

View File

@ -8,9 +8,10 @@ import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FUSED_OPS, FusionPass
from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
@ -58,11 +59,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = [noop_pass, fusion_pass, act_quant_fusion_pass
] if do_fusion else [noop_pass]
passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
] if do_fusion else [noop_pass, cleanup_pass]
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)

View File

@ -4,11 +4,11 @@
import pytest
import torch
import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass)
RMSNormQuantFusionPass)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
@ -79,15 +79,15 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
[True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cuda_force_torch):
@ -104,9 +104,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch)
# First dimension dynamic
@ -128,6 +129,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 2
# In pre-nodes, fp8 quant should be there and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())

View File

@ -9,6 +9,7 @@ import vllm.envs as envs
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
@ -215,8 +216,10 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
cleanup_pass)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
@ -227,6 +230,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)
assert all_reduce_fusion_pass.matched_count == 1
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass

View File

@ -6,18 +6,19 @@ from typing import Optional
import pytest
import torch._dynamo
from tests.compile.backend import TestBackend
from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata)
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config)
@ -104,7 +105,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(model,
enforce_eager=True,
@ -197,7 +198,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
device=self.device,
)
def build_attn_metadata(self, batch_size: int, use_hnd: bool):
def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
-> AttentionMetadata:
"""Initialize attention metadata."""
# Create common attn metadata
@ -447,9 +449,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw
)
test_backend = TestBackend(noop_pass, attn_pass)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
# Compile model with fusion enabled
model_compiled = torch.compile(model_fused,
@ -485,6 +488,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
test_backend.check_before_ops([QUANT_OPS[quant_key]],
fully_replaced=True)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
# Check attention ops in the graph before and after fusion
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP,

View File

@ -6,10 +6,12 @@ import torch
import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FusionPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
@ -104,7 +106,7 @@ class TestQuantModel(torch.nn.Module):
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,
@ -137,8 +139,7 @@ class TestQuantModel(torch.nn.Module):
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)
# for static input quantization
# self.fp8_linear is initialized with use_per_token_if_dynamic=False
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(norm_output,
self.w,
self.wscale,
@ -253,16 +254,20 @@ def sequence_parallelism_pass_on_test_model(
dtype=dtype,
seed=42)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
passes_for_backend = [noop_pass, sequence_parallelism_pass]
passes_for_backend: list[VllmInductorPass] = \
[noop_pass, sequence_parallelism_pass]
if enable_fusion:
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)
passes_for_backend.append(cleanup_pass)
backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
@ -279,6 +284,8 @@ def sequence_parallelism_pass_on_test_model(
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
assert sequence_parallelism_pass.matched_count == 1
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())

View File

@ -15,6 +15,7 @@ from vllm.compilation.activation_quant_fusion import (
# yapf: enable
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@ -69,6 +70,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
super().__init__()
from vllm.compilation.activation_quant_fusion import (
silu_and_mul_nvfp4_quant_supported)
assert silu_and_mul_nvfp4_quant_supported
self.silu_and_mul = SiluAndMul()
# create nvfp4 weight
@ -127,7 +132,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
passes = [
NoOpEliminationPass(config), fusion_pass,
PostCleanupPass(config)
]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size,
cuda_force_torch=cuda_force_torch,
x=x)
@ -151,6 +160,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
atol=atol,
rtol=rtol)
assert fusion_pass.matched_count == 1
# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())

View File

@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# unit test for `examples/offline_inference/torchrun_example.py`
import os
import random
import torch.distributed as dist
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_tp_group, get_world_group
dist.init_process_group(backend="gloo")
# Create prompts
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
] * 10
dp_size = int(os.getenv("DP_SIZE", "1"))
dp_rank = int(os.getenv("DP_RANK", "0"))
if dp_size > 1:
# distribute the prompts across the data parallel ranks
prompts = [
prompt for idx, prompt in enumerate(prompts)
if idx % dp_size == dp_rank
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
# to test if all ranks agree on the same kv cache configuration.
llm = LLM(model="microsoft/Phi-mini-MoE-instruct",
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
distributed_executor_backend="external_launcher",
gpu_memory_utilization=random.uniform(0.7, 0.9),
swap_space=random.randint(1, 4),
seed=0)
outputs = llm.generate(prompts, sampling_params)
group = get_world_group() if dp_size == 1 else get_tp_group()
cpu_group = group.cpu_group
group_rank = dist.get_rank(group=cpu_group)
def test_consistent_across_ranks(obj):
if group_rank == 0:
dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group)
else:
container = [None]
dist.broadcast_object_list(container,
src=group.ranks[0],
group=cpu_group)
assert container[0] == obj
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
# make sure we can access the model parameters from the calling process
# of the `LLM` instance.
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
model.parameters())
test_consistent_across_ranks(len(params))
# all ranks should have the same outputs
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
test_consistent_across_ranks(prompt)
test_consistent_across_ranks(generated_text)
print(f"Rank {group_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")

View File

@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import pytest_asyncio
from openai import OpenAI
from ...utils import RemoteOpenAIServer
MODEL_NAME = "openai/gpt-oss-20b"
@pytest.fixture(scope="module")
def monkeypatch_module():
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="module")
def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_module.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="function")
def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_module.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS",
"code_interpreter,container")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def mcp_disabled_client(mcp_disabled_server):
async with mcp_disabled_server.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture
async def mcp_enabled_client(mcp_enabled_server):
async with mcp_enabled_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI,
model_name: str):
response = await mcp_enabled_client.responses.create(
model=model_name,
# TODO: Ideally should be able to set max tool calls
# to prevent multi-turn, but it is not currently supported
# would speed up the test
input=("What's the first 4 digits after the decimal point of "
"cube root of `19910212 * 20250910`? "
"Show only the digits. The python interpreter is not stateful "
"and you must print to see the output."),
tools=[{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888"
}],
)
assert response is not None
assert response.status == "completed"
assert response.usage.output_tokens_details.tool_output_tokens > 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI,
model_name: str):
response = await mcp_disabled_client.responses.create(
model=model_name,
# TODO: Ideally should be able to set max tool calls
# to prevent multi-turn, but it is not currently supported
# would speed up the test
input=("What's the first 4 digits after the decimal point of "
"cube root of `19910212 * 20250910`? "
"Show only the digits. The python interpreter is not stateful "
"and you must print to see the output."),
tools=[{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888"
}],
)
assert response is not None
assert response.status == "completed"
assert response.usage.output_tokens_details.tool_output_tokens == 0

View File

@ -454,7 +454,13 @@ async def test_web_search(client: OpenAI, model_name: str):
async def test_code_interpreter(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input="Multiply 64548*15151 using builtin python interpreter.",
# TODO: Ideally should be able to set max tool calls
# to prevent multi-turn, but it is not currently supported
# would speed up the test
input=("What's the first 4 digits after the decimal point of "
"cube root of `19910212 * 20250910`? "
"Show only the digits. The python interpreter is not stateful "
"and you must print to see the output."),
tools=[{
"type": "code_interpreter",
"container": {
@ -464,6 +470,7 @@ async def test_code_interpreter(client: OpenAI, model_name: str):
)
assert response is not None
assert response.status == "completed"
assert response.usage.output_tokens_details.tool_output_tokens > 0
def get_weather(latitude, longitude):

View File

@ -5,6 +5,11 @@ import json
import pytest
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
Hermes2ProToolParser)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from ....utils import RemoteOpenAIServer
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@ -37,7 +42,7 @@ TOOLS = [{
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
@ -45,8 +50,39 @@ TOOLS = [{
},
}]
PRODUCT_TOOLS = [{
"type": "function",
"function": {
"name": "get_product_info",
"description": "Get detailed information of a product based on its "
"product ID.",
"parameters": {
"type": "object",
"properties": {
"inserted": {
"type": "boolean",
"description": "inserted.",
},
"product_id": {
"type": "integer",
"description": "The product ID of the product.",
},
},
"required": ["product_id", "inserted"],
},
},
}]
MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
PRODUCT_MESSAGES = [{
"role":
"user",
"content":
"Hi! Do you have any detailed information about the product id "
"7355608 and inserted true?",
}]
@pytest.mark.asyncio
async def test_non_streaming_tool_call():
@ -113,8 +149,8 @@ async def test_streaming_tool_call():
if tool_chunk.function.name:
tool_call_chunks[index]["name"] += tool_chunk.function.name
if tool_chunk.function.arguments:
tool_call_chunks[index][
"arguments"] += tool_chunk.function.arguments
tool_call_chunks[index]["arguments"] += (
tool_chunk.function.arguments)
assert len(tool_call_chunks) == 1
reconstructed_tool_call = tool_call_chunks[0]
@ -127,3 +163,295 @@ async def test_streaming_tool_call():
print("\n[Streaming Test Passed]")
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
print(f"Reconstructed Arguments: {arguments}")
@pytest.mark.asyncio
async def test_non_streaming_product_tool_call():
"""Test tool call integer and boolean parameters in non-streaming mode."""
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
client = server.get_async_client()
response = await client.chat.completions.create(
model=LORA_MODEL,
messages=PRODUCT_MESSAGES,
tools=PRODUCT_TOOLS,
tool_choice="auto",
temperature=0.66,
)
assert response.choices
choice = response.choices[0]
message = choice.message
assert choice.finish_reason == "tool_calls"
assert message.tool_calls is not None
tool_call = message.tool_calls[0]
assert tool_call.type == "function"
assert tool_call.function.name == "get_product_info"
arguments = json.loads(tool_call.function.arguments)
assert "product_id" in arguments
assert "inserted" in arguments
product_id = arguments.get("product_id")
inserted = arguments.get("inserted")
assert isinstance(product_id, int)
assert product_id == 7355608
assert isinstance(inserted, bool)
assert inserted is True
print("\n[Non-Streaming Product Test Passed]")
print(f"Tool Call: {tool_call.function.name}")
print(f"Arguments: {arguments}")
@pytest.mark.asyncio
async def test_streaming_product_tool_call():
"""Test tool call integer and boolean parameters in streaming mode."""
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
client = server.get_async_client()
stream = await client.chat.completions.create(
model=LORA_MODEL,
messages=PRODUCT_MESSAGES,
tools=PRODUCT_TOOLS,
tool_choice="auto",
temperature=0.66,
stream=True,
)
tool_call_chunks = {}
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if not delta or not delta.tool_calls:
continue
for tool_chunk in delta.tool_calls:
index = tool_chunk.index
if index not in tool_call_chunks:
tool_call_chunks[index] = {"name": "", "arguments": ""}
if tool_chunk.function.name:
tool_call_chunks[index]["name"] += tool_chunk.function.name
if tool_chunk.function.arguments:
tool_call_chunks[index]["arguments"] += (
tool_chunk.function.arguments)
assert len(tool_call_chunks) == 1
reconstructed_tool_call = tool_call_chunks[0]
assert reconstructed_tool_call["name"] == "get_product_info"
arguments = json.loads(reconstructed_tool_call["arguments"])
assert "product_id" in arguments
assert "inserted" in arguments
# Handle type coercion for streaming test as well
product_id = arguments.get("product_id")
inserted = arguments.get("inserted")
assert isinstance(product_id, int)
assert product_id == 7355608
assert isinstance(inserted, bool)
assert inserted is True
print("\n[Streaming Product Test Passed]")
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
print(f"Reconstructed Arguments: {arguments}")
@pytest.fixture
def qwen_tokenizer() -> AnyTokenizer:
from vllm.transformers_utils.tokenizer import get_tokenizer
return get_tokenizer("Qwen/Qwen3-32B")
@pytest.fixture
def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser:
return Hermes2ProToolParser(qwen_tokenizer)
@pytest.fixture
def any_chat_request() -> ChatCompletionRequest:
return ChatCompletionRequest(
seed=42,
model="Qwen/Qwen3-32B",
messages=[],
)
def test_hermes_parser_streaming_just_forward_text(
qwen_tokenizer: AnyTokenizer,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = (
"""This is some prior text that has nothing to do with tool calling."""
)
tokens = qwen_tokenizer.encode(text)
previous_text = ""
delta_messages = []
for token in tokens:
delta_text = qwen_tokenizer.decode([token])
current_text = previous_text + delta_text
delta = hermes_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=any_chat_request,
)
previous_text = current_text
delta_messages.append(delta)
for delta in delta_messages:
assert delta is not None
assert not delta.tool_calls
print(delta_messages)
assert "".join([delta.content for delta in delta_messages]) == text
def test_hermes_parser_streaming_failure_case_bug_19056(
qwen_tokenizer: AnyTokenizer,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}
</tool_call>"""
tokens = qwen_tokenizer.encode(text)
previous_text = ""
delta_messages = []
for token in tokens:
text = qwen_tokenizer.decode([token])
current_text = previous_text + text
delta = hermes_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=text,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=any_chat_request,
)
previous_text = current_text
if delta is not None:
delta_messages.append(delta)
assert delta_messages[0].tool_calls[0].function.name == "final_answer"
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
for delta in delta_messages)
assert tool_call_args == '{"trigger": true}'
def test_hermes_parser_streaming(
qwen_tokenizer: AnyTokenizer,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = '<tool_call>\
{"name": "get_current_temperature",\
"arguments": {"location":\
"San Francisco, California, United States", "unit": "celsius"}}\
</tool_call>'
tokens = qwen_tokenizer.encode(text)
previous_text = ""
delta_messages = []
for token in tokens:
text = qwen_tokenizer.decode([token])
current_text = previous_text + text
delta = hermes_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=text,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=any_chat_request,
)
previous_text = current_text
if delta is not None:
delta_messages.append(delta)
print(delta_messages)
assert (delta_messages[0].tool_calls[0].function.name ==
"get_current_temperature")
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
for delta in delta_messages)
assert tool_call_args == (
'{"location":"San Francisco, California, United States", '
'"unit": "celsius"}')
def test_hermes_parser_non_streaming_no_tool_call(
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """This is not a tool call."""
tool_call = hermes_parser.extract_tool_calls(
model_output=text,
request=any_chat_request,
)
assert tool_call is not None
assert not tool_call.tools_called
def test_hermes_parser_non_streaming_tool_call_between_tags(
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}
</tool_call>"""
tool_call = hermes_parser.extract_tool_calls(
model_output=text,
request=any_chat_request,
)
assert tool_call is not None
assert tool_call.tools_called
assert tool_call.tool_calls[0].function.name == "final_answer"
assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
def test_hermes_parser_non_streaming_tool_call_until_eos(
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}"""
tool_call = hermes_parser.extract_tool_calls(
model_output=text,
request=any_chat_request,
)
assert tool_call is not None
assert tool_call.tools_called
assert tool_call.tool_calls[0].function.name == "final_answer"
assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
def test_hermes_parser_non_streaming_tool_call_invalid_json(
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
# Missing closing brace to trigger exception
text = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}"""
tool_call = hermes_parser.extract_tool_calls(
model_output=text,
request=any_chat_request,
)
assert tool_call is not None
assert not tool_call.tools_called

View File

@ -19,7 +19,7 @@ pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \
vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000
# Run evaluation
python tests/gsm8k/gsm8k_eval.py --port 8000
python tests/evals/gsm8k/gsm8k_eval.py --port 8000
```
## Configuration Format

View File

@ -67,7 +67,6 @@ def generate_params():
return params
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
@pytest.mark.parametrize("device, name, use_mla, block_size",
generate_params())
def test_env(
@ -189,7 +188,7 @@ def test_env(
# FlashMLA only supports block_size == 64
pytest.skip("FlashMLA only supports block_size 64")
else:
from vllm.attention.backends.flashmla import (
from vllm.v1.attention.backends.mla.flashmla import ( # noqa: E501
is_flashmla_supported)
is_supported, _ = is_flashmla_supported()
if not is_supported:

View File

@ -959,7 +959,6 @@ def make_test_metadata(
return attn_backend_obj.make_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
@ -1009,7 +1008,6 @@ def make_test_metadata(
return attn_backend_obj.make_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,

View File

@ -164,8 +164,8 @@ def populate_loras(
weight=layer_weights,
generate_embeddings_tensor=generate_embeddings_tensor,
)
sublora.lora_b = sublora.lora_b[:, (sublora_len *
i):(sublora_len * (i + 1))]
sublora.lora_b = sublora.lora_b[(sublora_len *
i):(sublora_len * (i + 1)), :]
sublora.optimize()
subloras.append(sublora)
@ -304,9 +304,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
result = embedding(input_)
after_a = F.embedding(
input_,
lora.lora_a,
lora.lora_a.T,
)
result += (after_a @ lora.lora_b)
result += (after_a @ lora.lora_b.T)
expected_results.append(result)
expected_result = torch.cat(expected_results)
@ -445,9 +445,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
result = expanded_embedding(input_)
after_a = F.embedding(
original_input_,
lora.lora_a,
lora.lora_a.T,
)
result += (after_a @ lora.lora_b)
result += (after_a @ lora.lora_b.T)
expected_results.append(result)
expected_result = torch.cat(expected_results)
@ -575,7 +575,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
lm_head=linear,
embedding_bias=None)
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
logits_processor.org_vocab_size = vocab_size
@ -692,9 +692,10 @@ def test_linear_replicated(
expected_results: list[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
@ -817,7 +818,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
@ -965,9 +966,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
result = linear(input_)[0]
subloras = sublora_dict[lora_id]
for i, sublora in enumerate(subloras):
result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
(i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
sublora.scaling)
result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
(i + 1)] += (
input_ @ sublora.lora_a.T @ sublora.lora_b.T *
sublora.scaling)
expected_results.append(result)
expected_result = torch.cat(expected_results)
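All of the edits in this file follow one layout change: `lora_a` is stored as (rank, in_features) and `lora_b` as (out_features, rank), so the expected delta becomes `x @ lora_a.T @ lora_b.T * scaling`. A self-contained shape check under those assumptions (dimensions are illustrative):

```python
import torch

rank, in_features, out_features, scaling = 8, 16, 32, 0.5
x = torch.randn(4, in_features)
lora_a = torch.randn(rank, in_features)      # (rank, in_features)
lora_b = torch.randn(out_features, rank)     # (out_features, rank)

delta = x @ lora_a.T @ lora_b.T * scaling    # (4, out_features)
assert delta.shape == (4, out_features)
```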

View File

@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device):
assert lora.lora_b is not None
assert lora.lora_a.device == torch.device(device)
assert lora.lora_b.device == torch.device(device)
assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
assert (lora.lora_a.shape[0] == lora.lora_b.shape[1]
), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
assert lora.lora_a.shape[1] == 8
assert lora.lora_a.shape[0] == 8
embeddings_module = next(
(k for k in EMBEDDING_MODULES if k in module_name), None)
if embeddings_module:
@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
name,
8,
16,
torch.rand([w.shape[1], 8], device=device),
torch.rand([8, w.shape[0]], device=device),
torch.rand([8, w.shape[1]], device=device),
torch.rand([w.shape[0], 8], device=device),
)
return LoRAModel(lora_id, 8, loras)
@ -109,8 +109,8 @@ def create_packed_lora(
replaced_module_name,
8,
16,
torch.rand([w.shape[1], 8], device=device),
torch.rand([8, w.shape[0] // len(replaced_module_names)],
torch.rand([8, w.shape[1]], device=device),
torch.rand([w.shape[0] // len(replaced_module_names), 8],
device=device),
)
return LoRAModel(lora_id, 8, loras)

View File

@ -36,10 +36,10 @@ class DummyLoRAManager:
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([weight.shape[1], rank],
lora_a=torch.rand([rank, weight.shape[1]],
dtype=weight.dtype,
device=self._device),
lora_b=torch.rand([rank, weight.shape[0]],
lora_b=torch.rand([weight.shape[0], rank],
dtype=weight.dtype,
device=self._device),
)
@ -67,8 +67,8 @@ class DummyLoRAManager:
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([input_dim, rank], device="cuda"),
lora_b=torch.rand([rank, output_dim], device="cuda"),
lora_a=torch.rand([rank, input_dim], device="cuda"),
lora_b=torch.rand([output_dim, input_dim], device="cuda"),
embeddings_tensor=embeddings_tensor,
)
self.set_module_lora(module_name, lora)

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
@ -34,15 +35,15 @@ class Relu3(ReLUSquaredActivation):
[
# Default values based on compile level
# - All by default (no Inductor compilation)
("", 0, False, [True] * 4, True),
("", 1, True, [True] * 4, True),
("", 2, False, [True] * 4, True),
(None, 0, False, [True] * 4, True),
(None, 1, True, [True] * 4, True),
(None, 2, False, [True] * 4, True),
# - None by default (with Inductor)
("", 3, True, [False] * 4, False),
("", 4, True, [False] * 4, False),
(None, 3, True, [False] * 4, False),
(None, 4, True, [False] * 4, False),
# - All by default (without Inductor)
("", 3, False, [True] * 4, True),
("", 4, False, [True] * 4, True),
(None, 3, False, [True] * 4, True),
(None, 4, False, [True] * 4, True),
# Explicitly enabling/disabling
#
# Default: all
@ -54,7 +55,7 @@ class Relu3(ReLUSquaredActivation):
# All but SiluAndMul
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
# All but ReLU3 (even if ReLU2 is on)
("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
# RMSNorm and SiluAndMul
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
# All but RMSNorm
@ -67,12 +68,13 @@ class Relu3(ReLUSquaredActivation):
# All but RMSNorm
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
])
def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool,
ops_enabled: list[int], default_on: bool):
custom_ops = env.split(',') if env else []
vllm_config = VllmConfig(
compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
level=torch_level,
custom_ops=env.split(",")))
custom_ops=custom_ops))
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on
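The table reads as follows: `None` means "fall back to the level/Inductor default", while a spec such as `"all,-rms_norm"` overrides individual ops. A toy interpreter of that syntax (an illustrative reading of the cases above, not the vLLM implementation):

```python
def op_enabled(op: str, spec: list[str], default_on: bool) -> bool:
    enabled = default_on
    for token in spec:
        if token == "all":
            enabled = True
        elif token == "none":
            enabled = False
        elif token in (op, f"+{op}"):
            enabled = True
        elif token == f"-{op}":
            enabled = False
    return enabled


assert op_enabled("rms_norm", "all,-rms_norm".split(","), default_on=True) is False
assert op_enabled("silu_and_mul", "none,+silu_and_mul".split(","), default_on=False) is True
```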

View File

@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model
SSM_MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
"yujiepan/mamba2-codestral-v0.1-tiny-random",
# mamba2-codestral in transformers is broken pending:
# https://github.com/huggingface/transformers/pull/40861
#"yujiepan/mamba2-codestral-v0.1-tiny-random",
]
HYBRID_MODELS = [
@ -31,18 +33,7 @@ HYBRID_MODELS = [
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B",
]
V1_SUPPORTED_MODELS = [
"state-spaces/mamba-130m-hf",
"ai21labs/Jamba-tiny-dev",
"pfnet/plamo-2-1b",
"yujiepan/mamba2-codestral-v0.1-tiny-random",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B",
"tiny-random/qwen3-next-moe",
]
FULL_CUDA_GRAPH_MODELS = [
@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [
"Zyphra/Zamba2-1.2B-instruct",
]
V0_UNSUPPORTED_MODELS = [
"LiquidAI/LFM2-1.2B",
]
FP32_STATE_MODELS = [
"state-spaces/mamba-130m-hf",
"Zyphra/Zamba2-1.2B-instruct",
@ -88,20 +75,16 @@ def test_models(
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
if model in V1_SUPPORTED_MODELS:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:
vllm_v1_outputs = None
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if model in V1_SUPPORTED_MODELS:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v1_outputs,
name_0="hf",
name_1="vllm-v1",
)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@ -299,14 +282,14 @@ def test_full_cuda_graph(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v1_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm-v1",
name_1="vllm",
)
@ -340,12 +323,12 @@ def test_fp32_cache_state(
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
**{cache_dtype_param: "float32"}) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v1_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm-v1",
name_1="vllm",
)

View File

@ -209,7 +209,6 @@ def batch_make_video_embeddings(
return visual(pixel_values_on_device,
grid_thw=video_grid_thw_on_device).cpu()
# V1 Test: this calls a V0 internal.
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
# split into original batches

View File

@ -312,14 +312,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
trust_remote_code=True,
v0_only=True,
max_model_len=10240),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
trust_remote_code=True),
max_transformers_version="4.55.4",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"),
extras={"tiny-random": "tiny-random/qwen3-next-moe"}, # noqa: E501
min_transformers_version="4.56.3"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
trust_remote_code=True,
@ -448,6 +447,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"DotsOCRForCausalLM": _HfExamplesInfo("rednote-hilab/dots.ocr",
trust_remote_code=True),
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501
trust_remote_code=True),
@ -560,10 +561,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
"Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501
max_model_len=4096,
min_transformers_version="4.57"), # noqa: E501
min_transformers_version="4.57",
is_available_online=False),
"Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501
max_model_len=4096,
min_transformers_version="4.57"),
max_model_len=4096,
min_transformers_version="4.57",
is_available_online=False),
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
@ -640,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"),
min_transformers_version="4.56.3"),
}
_TRANSFORMERS_BACKEND_MODELS = {

View File

@ -1,10 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest
import torch
import torch.multiprocessing as mp
from vllm.model_executor.models.vision import resolve_visual_encoder_outputs
from tests.utils import multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.models.vision import (
get_load_balance_assignment, resolve_visual_encoder_outputs,
run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables
@pytest.mark.parametrize(
@ -33,3 +43,415 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers,
post_layer_norm=None,
max_possible_layers=max_possible_layers)
assert torch.equal(torch.tensor(expected_features), output_tensor)
class SimpleLinearModel(torch.nn.Module):
"""A simple linear vision model for testing."""
def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
super().__init__()
self.flatten = torch.nn.Flatten()
self.linear = torch.nn.Linear(input_dim, output_dim)
def forward(self, x: torch.Tensor):
# Flatten the input and apply linear transformation
x = self.flatten(x)
return self.linear(x)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"batch_size",
[
1, # Single image
4, # Small batch
5, # Odd batch size (for testing padding)
],
)
def test_run_dp_sharded_vision_model(batch_size: int):
world_size = 2
# Launch processes
mp.spawn(
run_dp_sharded_vision_model_vs_direct,
args=(
world_size,
batch_size,
get_open_port(),
),
nprocs=world_size,
)
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
batch_size: int, master_port: int):
"""
Test that run_dp_sharded_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform.seed_everything(0)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create a test input tensor
image_input = torch.randn(batch_size, 3, 224, 224)
# Create a simple linear model
vision_model = SimpleLinearModel()
# Run the model directly on the full input
with torch.inference_mode():
direct_output = vision_model(image_input)
# Run the model through the sharded function
with torch.inference_mode():
sharded_output = run_dp_sharded_vision_model(image_input, vision_model)
# Check that the world size is set up correctly
assert get_tensor_model_parallel_world_size() == world_size
# Check that the outputs have the same shape
assert direct_output.shape == sharded_output.shape
# Check that the outputs are close (they should be identical)
assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
@pytest.mark.parametrize(
"sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
"expected_grouped_sizes_per_gpu,test_description",
[
# Empty input
([], 2, [], [0, 0], [0, 0], "empty input"),
# Fewer samples than GPUs
([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0
], "fewer samples than GPUs"),
# Single GPU
([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
# Balanced assignment
([100, 100, 100, 100
], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"),
# Unbalanced sizes - this one is trickier since the algorithm is greedy
([1000, 100, 200, 50], 2, [0, 2, 1, 3
], [1, 3], [1000, 350], "unbalanced sizes"),
],
)
def test_get_load_balance_assignment_cases(sizes, num_gpus,
expected_shuffle_indices,
expected_gpu_sample_counts,
expected_grouped_sizes_per_gpu,
test_description):
"""Test get_load_balance_assignment with various input cases."""
result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
(shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
# Common assertions for all cases
assert len(shuffle_indices) == len(sizes)
assert len(gpu_sample_counts) == num_gpus
assert len(grouped_sizes_per_gpu) == num_gpus
assert sum(gpu_sample_counts) == len(sizes)
assert shuffle_indices == expected_shuffle_indices
assert gpu_sample_counts == expected_gpu_sample_counts
assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
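As the "unbalanced sizes" comment notes, the assignment is greedy. A minimal greedy reference (not the vLLM function) reproduces every expected tuple above: sort samples by size descending and always give the next one to the least-loaded GPU.

```python
def greedy_assignment(sizes: list[int], num_gpus: int):
    order = sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True)
    per_gpu_indices = [[] for _ in range(num_gpus)]
    per_gpu_load = [0] * num_gpus
    for idx in order:
        gpu = per_gpu_load.index(min(per_gpu_load))   # least-loaded GPU
        per_gpu_indices[gpu].append(idx)
        per_gpu_load[gpu] += sizes[idx]
    shuffle = [i for indices in per_gpu_indices for i in indices]
    counts = [len(indices) for indices in per_gpu_indices]
    return shuffle, counts, per_gpu_load


assert greedy_assignment([1000, 100, 200, 50], 2) == ([0, 2, 1, 3], [1, 3], [1000, 350])
```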
class SimpleMRopeVisionModel(torch.nn.Module):
"""A simple vision model for testing mrope functionality."""
def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
super().__init__()
self.spatial_merge_size = spatial_merge_size
self.out_hidden_size = out_hidden_size
self.linear = torch.nn.Linear(768, out_hidden_size)
def forward(self, pixel_values: torch.Tensor,
grid_thw_list: list[list[int]]):
"""Simple forward pass that simulates spatial merging."""
# Apply linear transformation
embeddings = self.linear(pixel_values)
# Simulate spatial merging by reducing the number of patches
merge_factor = self.spatial_merge_size * self.spatial_merge_size
# Group patches and merge spatially
merged_embeddings = []
start_idx = 0
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
end_idx = start_idx + num_patches
# Get patches for this image
image_patches = embeddings[start_idx:end_idx]
# Simulate spatial merging by averaging groups of patches
merged_patches = num_patches // merge_factor
if merged_patches > 0:
# Reshape and average to simulate merging
reshaped = image_patches[:merged_patches * merge_factor].view(
merged_patches, merge_factor, -1)
merged = reshaped.mean(dim=1)
merged_embeddings.append(merged)
start_idx = end_idx
if merged_embeddings:
return torch.cat(merged_embeddings, dim=0)
else:
return torch.empty((0, self.out_hidden_size),
device=pixel_values.device,
dtype=pixel_values.dtype)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"batch_size",
[
1, # Single image
3, # Small batch
5, # Odd batch size (for testing padding)
],
)
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
world_size = 2
# Launch processes
mp.spawn(
run_dp_sharded_mrope_vision_model_vs_direct,
args=(
world_size,
batch_size,
get_open_port(),
),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
world_size: int,
batch_size: int,
master_port: int):
"""
Test that run_dp_sharded_mrope_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform.seed_everything(0)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create test data
grid_thw_list = []
pixel_values_list = []
for i in range(batch_size):
# Varying image sizes for better testing
t, h, w = 1, 4 + i, 4 + i
grid_thw_list.append([t, h, w])
num_patches = t * h * w
# Create random pixel values for this image
image_pixels = torch.randn(num_patches, 768)
pixel_values_list.append(image_pixels)
# Concatenate all pixel values
pixel_values = torch.cat(pixel_values_list, dim=0)
# Create a simple mrope vision model
vision_model = SimpleMRopeVisionModel()
# Run the model directly on the full input (only on rank 0)
if local_rank == 0:
with torch.inference_mode():
direct_output = vision_model(pixel_values, grid_thw_list)
# Run the model through the sharded function
with torch.inference_mode():
sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
sharded_output = torch.cat(sharded_output, dim=0)
# Check that the world size is set up correctly
assert get_tensor_model_parallel_world_size() == world_size
# Compare outputs (only on rank 0)
if local_rank == 0:
# Check that the outputs have the same shape
assert direct_output.shape == sharded_output.shape
# Check that the outputs are close (they should be identical)
assert torch.allclose(direct_output,
sharded_output,
rtol=1e-5,
atol=1e-5)
@multi_gpu_test(num_gpus=2)
def test_run_dp_sharded_mrope_vision_model_empty_input():
world_size = 2
mp.spawn(
run_dp_sharded_mrope_vision_model_empty_input_worker,
args=(world_size, get_open_port()),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_empty_input_worker(
local_rank: int, world_size: int, master_port: int):
"""Test run_dp_sharded_mrope_vision_model with empty input."""
# Set up distributed environment
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create empty inputs
pixel_values = torch.empty((0, 768))
grid_thw_list: list[list[int]] = []
vision_model = SimpleMRopeVisionModel()
# Should handle empty input gracefully
with torch.inference_mode():
output = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
assert len(output) == 0
@multi_gpu_test(num_gpus=4)
def test_run_dp_sharded_mrope_vision_model_uneven_load():
world_size = 4
mp.spawn(
run_dp_sharded_mrope_vision_model_uneven_load_worker,
args=(world_size, get_open_port()),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_uneven_load_worker(
local_rank: int, world_size: int, master_port: int):
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
# Set up distributed environment
current_platform.seed_everything(123)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create images with very different sizes
grid_thw_list = [
[1, 2, 2], # Small: 4 patches
[1, 8, 8], # Large: 64 patches
[1, 3, 3], # Medium: 9 patches
]
pixel_values_list = []
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
image_pixels = torch.randn(num_patches, 768)
pixel_values_list.append(image_pixels)
pixel_values = torch.cat(pixel_values_list, dim=0)
vision_model = SimpleMRopeVisionModel()
# Should handle uneven distribution without errors
with torch.inference_mode():
output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
# Verify output shape is reasonable
merge_factor = vision_model.spatial_merge_size**2
expected_output_patches = list(
math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
for i, output in enumerate(output_tuple):
assert output.shape[0] == expected_output_patches[i]
assert output.shape[1] == vision_model.out_hidden_size
@pytest.mark.parametrize("spatial_merge_size", [2, 4])
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
"""Test SimpleMRopeVisionModel with different spatial merge sizes."""
device = current_platform.device_type
grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images
pixel_values_list = []
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
image_pixels = torch.randn(num_patches, 768, device=device)
pixel_values_list.append(image_pixels)
pixel_values = torch.cat(pixel_values_list, dim=0)
vision_model = SimpleMRopeVisionModel(
spatial_merge_size=spatial_merge_size).to(device)
with torch.inference_mode():
output = vision_model(pixel_values, grid_thw_list)
# Verify output dimensions based on spatial merging
total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
merge_factor = spatial_merge_size**2
expected_output_patches = total_patches // merge_factor
assert output.shape[0] == expected_output_patches
assert output.shape[1] == vision_model.out_hidden_size
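The expected sizes in these mrope tests come from one piece of arithmetic: an image with grid (t, h, w) contributes t*h*w patches, and spatial merging divides the count by spatial_merge_size squared. For the two-image case in the last test:

```python
import math

grid_thw_list = [[1, 4, 4], [1, 6, 6]]
spatial_merge_size = 2

total_patches = sum(math.prod(g) for g in grid_thw_list)   # 16 + 36 = 52
merged = total_patches // spatial_merge_size ** 2          # 52 // 4 = 13
assert merged == 13
```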

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import math
import mimetypes
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
@ -10,22 +9,11 @@ from typing import TYPE_CHECKING, NamedTuple
import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
from PIL import Image, ImageChops
from tests.utils import multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
get_load_balance_assignment,
run_dp_sharded_mrope_vision_model,
run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables
from vllm.multimodal.utils import MediaConnector, argsort_mm_positions
if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalPlaceholderDict
@ -404,415 +392,3 @@ def test_argsort_mm_positions():
modality_idxs = argsort_mm_positions(mm_positions)
assert modality_idxs == expected_modality_idxs
class SimpleLinearModel(torch.nn.Module):
"""A simple linear vision model for testing."""
def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
super().__init__()
self.flatten = torch.nn.Flatten()
self.linear = torch.nn.Linear(input_dim, output_dim)
def forward(self, x: torch.Tensor):
# Flatten the input and apply linear transformation
x = self.flatten(x)
return self.linear(x)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"batch_size",
[
1, # Single image
4, # Small batch
5, # Odd batch size (for testing padding)
],
)
def test_run_dp_sharded_vision_model(batch_size: int):
world_size = 2
# Launch processes
mp.spawn(
run_dp_sharded_vision_model_vs_direct,
args=(
world_size,
batch_size,
get_open_port(),
),
nprocs=world_size,
)
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
batch_size: int, master_port: int):
"""
Test that run_dp_sharded_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform.seed_everything(0)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create a test input tensor
image_input = torch.randn(batch_size, 3, 224, 224)
# Create a simple linear model
vision_model = SimpleLinearModel()
# Run the model directly on the full input
with torch.inference_mode():
direct_output = vision_model(image_input)
# Run the model through the sharded function
with torch.inference_mode():
sharded_output = run_dp_sharded_vision_model(image_input, vision_model)
# Check that the world size is set up correctly
assert get_tensor_model_parallel_world_size() == world_size
# Check that the outputs have the same shape
assert direct_output.shape == sharded_output.shape
# Check that the outputs are close (they should be identical)
assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
@pytest.mark.parametrize(
"sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
"expected_grouped_sizes_per_gpu,test_description",
[
# Empty input
([], 2, [], [0, 0], [0, 0], "empty input"),
# Fewer samples than GPUs
([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0
], "fewer samples than GPUs"),
# Single GPU
([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
# Balanced assignment
([100, 100, 100, 100
], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"),
# Unbalanced sizes - this one is trickier since the algorithm is greedy
([1000, 100, 200, 50], 2, [0, 2, 1, 3
], [1, 3], [1000, 350], "unbalanced sizes"),
],
)
def test_get_load_balance_assignment_cases(sizes, num_gpus,
expected_shuffle_indices,
expected_gpu_sample_counts,
expected_grouped_sizes_per_gpu,
test_description):
"""Test get_load_balance_assignment with various input cases."""
result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
(shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
# Common assertions for all cases
assert len(shuffle_indices) == len(sizes)
assert len(gpu_sample_counts) == num_gpus
assert len(grouped_sizes_per_gpu) == num_gpus
assert sum(gpu_sample_counts) == len(sizes)
assert shuffle_indices == expected_shuffle_indices
assert gpu_sample_counts == expected_gpu_sample_counts
assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
class SimpleMRopeVisionModel(torch.nn.Module):
"""A simple vision model for testing mrope functionality."""
def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
super().__init__()
self.spatial_merge_size = spatial_merge_size
self.out_hidden_size = out_hidden_size
self.linear = torch.nn.Linear(768, out_hidden_size)
def forward(self, pixel_values: torch.Tensor,
grid_thw_list: list[list[int]]):
"""Simple forward pass that simulates spatial merging."""
# Apply linear transformation
embeddings = self.linear(pixel_values)
# Simulate spatial merging by reducing the number of patches
merge_factor = self.spatial_merge_size * self.spatial_merge_size
# Group patches and merge spatially
merged_embeddings = []
start_idx = 0
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
end_idx = start_idx + num_patches
# Get patches for this image
image_patches = embeddings[start_idx:end_idx]
# Simulate spatial merging by averaging groups of patches
merged_patches = num_patches // merge_factor
if merged_patches > 0:
# Reshape and average to simulate merging
reshaped = image_patches[:merged_patches * merge_factor].view(
merged_patches, merge_factor, -1)
merged = reshaped.mean(dim=1)
merged_embeddings.append(merged)
start_idx = end_idx
if merged_embeddings:
return torch.cat(merged_embeddings, dim=0)
else:
return torch.empty((0, self.out_hidden_size),
device=pixel_values.device,
dtype=pixel_values.dtype)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"batch_size",
[
1, # Single image
3, # Small batch
5, # Odd batch size (for testing padding)
],
)
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
world_size = 2
# Launch processes
mp.spawn(
run_dp_sharded_mrope_vision_model_vs_direct,
args=(
world_size,
batch_size,
get_open_port(),
),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
world_size: int,
batch_size: int,
master_port: int):
"""
Test that run_dp_sharded_mrope_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform.seed_everything(0)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create test data
grid_thw_list = []
pixel_values_list = []
for i in range(batch_size):
# Varying image sizes for better testing
t, h, w = 1, 4 + i, 4 + i
grid_thw_list.append([t, h, w])
num_patches = t * h * w
# Create random pixel values for this image
image_pixels = torch.randn(num_patches, 768)
pixel_values_list.append(image_pixels)
# Concatenate all pixel values
pixel_values = torch.cat(pixel_values_list, dim=0)
# Create a simple mrope vision model
vision_model = SimpleMRopeVisionModel()
# Run the model directly on the full input (only on rank 0)
if local_rank == 0:
with torch.inference_mode():
direct_output = vision_model(pixel_values, grid_thw_list)
# Run the model through the sharded function
with torch.inference_mode():
sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
sharded_output = torch.cat(sharded_output, dim=0)
# Check that the world size is set up correctly
assert get_tensor_model_parallel_world_size() == world_size
# Compare outputs (only on rank 0)
if local_rank == 0:
# Check that the outputs have the same shape
assert direct_output.shape == sharded_output.shape
# Check that the outputs are close (they should be identical)
assert torch.allclose(direct_output,
sharded_output,
rtol=1e-5,
atol=1e-5)
@multi_gpu_test(num_gpus=2)
def test_run_dp_sharded_mrope_vision_model_empty_input():
world_size = 2
mp.spawn(
run_dp_sharded_mrope_vision_model_empty_input_worker,
args=(world_size, get_open_port()),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_empty_input_worker(
local_rank: int, world_size: int, master_port: int):
"""Test run_dp_sharded_mrope_vision_model with empty input."""
# Set up distributed environment
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create empty inputs
pixel_values = torch.empty((0, 768))
grid_thw_list: list[list[int]] = []
vision_model = SimpleMRopeVisionModel()
# Should handle empty input gracefully
with torch.inference_mode():
output = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
assert len(output) == 0
@multi_gpu_test(num_gpus=4)
def test_run_dp_sharded_mrope_vision_model_uneven_load():
world_size = 4
mp.spawn(
run_dp_sharded_mrope_vision_model_uneven_load_worker,
args=(world_size, get_open_port()),
nprocs=world_size,
)
def run_dp_sharded_mrope_vision_model_uneven_load_worker(
local_rank: int, world_size: int, master_port: int):
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
# Set up distributed environment
current_platform.seed_everything(123)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create images with very different sizes
grid_thw_list = [
[1, 2, 2], # Small: 4 patches
[1, 8, 8], # Large: 64 patches
[1, 3, 3], # Medium: 9 patches
]
pixel_values_list = []
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
image_pixels = torch.randn(num_patches, 768)
pixel_values_list.append(image_pixels)
pixel_values = torch.cat(pixel_values_list, dim=0)
vision_model = SimpleMRopeVisionModel()
# Should handle uneven distribution without errors
with torch.inference_mode():
output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
# Verify output shape is reasonable
merge_factor = vision_model.spatial_merge_size**2
expected_output_patches = list(
math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
for i, output in enumerate(output_tuple):
assert output.shape[0] == expected_output_patches[i]
assert output.shape[1] == vision_model.out_hidden_size
@pytest.mark.parametrize("spatial_merge_size", [2, 4])
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
"""Test SimpleMRopeVisionModel with different spatial merge sizes."""
device = current_platform.device_type
grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images
pixel_values_list = []
for grid_thw in grid_thw_list:
num_patches = math.prod(grid_thw)
image_pixels = torch.randn(num_patches, 768, device=device)
pixel_values_list.append(image_pixels)
pixel_values = torch.cat(pixel_values_list, dim=0)
vision_model = SimpleMRopeVisionModel(
spatial_merge_size=spatial_merge_size).to(device)
with torch.inference_mode():
output = vision_model(pixel_values, grid_thw_list)
# Verify output dimensions based on spatial merging
total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
merge_factor = spatial_merge_size**2
expected_output_patches = total_patches // merge_factor
assert output.shape[0] == expected_output_patches
assert output.shape[1] == vision_model.out_hidden_size

216
tests/test_envs.py Normal file
View File

@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from unittest.mock import patch
import pytest
from vllm.envs import env_list_with_choices, env_with_choices
class TestEnvWithChoices:
"""Test cases for env_with_choices function."""
def test_default_value_returned_when_env_not_set(self):
"""Test default is returned when env var is not set."""
env_func = env_with_choices("NONEXISTENT_ENV", "default",
["option1", "option2"])
assert env_func() == "default"
def test_none_default_returned_when_env_not_set(self):
"""Test that None is returned when env not set and default is None."""
env_func = env_with_choices("NONEXISTENT_ENV", None,
["option1", "option2"])
assert env_func() is None
def test_valid_value_returned_case_sensitive(self):
"""Test that valid value is returned in case sensitive mode."""
with patch.dict(os.environ, {"TEST_ENV": "option1"}):
env_func = env_with_choices("TEST_ENV",
"default", ["option1", "option2"],
case_sensitive=True)
assert env_func() == "option1"
def test_valid_lowercase_value_returned_case_insensitive(self):
"""Test that lowercase value is accepted in case insensitive mode."""
with patch.dict(os.environ, {"TEST_ENV": "option1"}):
env_func = env_with_choices("TEST_ENV",
"default", ["OPTION1", "OPTION2"],
case_sensitive=False)
assert env_func() == "option1"
def test_valid_uppercase_value_returned_case_insensitive(self):
"""Test that uppercase value is accepted in case insensitive mode."""
with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}):
env_func = env_with_choices("TEST_ENV",
"default", ["option1", "option2"],
case_sensitive=False)
assert env_func() == "OPTION1"
def test_invalid_value_raises_error_case_sensitive(self):
"""Test that invalid value raises ValueError in case sensitive mode."""
with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
env_func = env_with_choices("TEST_ENV",
"default", ["option1", "option2"],
case_sensitive=True)
with pytest.raises(ValueError,
match="Invalid value 'invalid' for TEST_ENV"):
env_func()
def test_case_mismatch_raises_error_case_sensitive(self):
"""Test that case mismatch raises ValueError in case sensitive mode."""
with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}):
env_func = env_with_choices("TEST_ENV",
"default", ["option1", "option2"],
case_sensitive=True)
with pytest.raises(ValueError,
match="Invalid value 'OPTION1' for TEST_ENV"):
env_func()
def test_invalid_value_raises_error_case_insensitive(self):
"""Test that invalid value raises ValueError when case insensitive."""
with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
env_func = env_with_choices("TEST_ENV",
"default", ["option1", "option2"],
case_sensitive=False)
with pytest.raises(ValueError,
match="Invalid value 'invalid' for TEST_ENV"):
env_func()
def test_callable_choices_resolved_correctly(self):
"""Test that callable choices are resolved correctly."""
def get_choices():
return ["dynamic1", "dynamic2"]
with patch.dict(os.environ, {"TEST_ENV": "dynamic1"}):
env_func = env_with_choices("TEST_ENV", "default", get_choices)
assert env_func() == "dynamic1"
def test_callable_choices_with_invalid_value(self):
"""Test that callable choices raise error for invalid values."""
def get_choices():
return ["dynamic1", "dynamic2"]
with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
env_func = env_with_choices("TEST_ENV", "default", get_choices)
with pytest.raises(ValueError,
match="Invalid value 'invalid' for TEST_ENV"):
env_func()
class TestEnvListWithChoices:
"""Test cases for env_list_with_choices function."""
def test_default_list_returned_when_env_not_set(self):
"""Test that default list is returned when env var is not set."""
env_func = env_list_with_choices("NONEXISTENT_ENV",
["default1", "default2"],
["option1", "option2"])
assert env_func() == ["default1", "default2"]
def test_empty_default_list_returned_when_env_not_set(self):
"""Test that empty default list is returned when env not set."""
env_func = env_list_with_choices("NONEXISTENT_ENV", [],
["option1", "option2"])
assert env_func() == []
def test_single_valid_value_parsed_correctly(self):
"""Test that single valid value is parsed correctly."""
with patch.dict(os.environ, {"TEST_ENV": "option1"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
assert env_func() == ["option1"]
def test_multiple_valid_values_parsed_correctly(self):
"""Test that multiple valid values are parsed correctly."""
with patch.dict(os.environ, {"TEST_ENV": "option1,option2"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
assert env_func() == ["option1", "option2"]
def test_values_with_whitespace_trimmed(self):
"""Test that values with whitespace are trimmed correctly."""
with patch.dict(os.environ, {"TEST_ENV": " option1 , option2 "}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
assert env_func() == ["option1", "option2"]
def test_empty_values_filtered_out(self):
"""Test that empty values are filtered out."""
with patch.dict(os.environ, {"TEST_ENV": "option1,,option2,"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
assert env_func() == ["option1", "option2"]
def test_empty_string_returns_default(self):
"""Test that empty string returns default."""
with patch.dict(os.environ, {"TEST_ENV": ""}):
env_func = env_list_with_choices("TEST_ENV", ["default"],
["option1", "option2"])
assert env_func() == ["default"]
def test_only_commas_returns_default(self):
"""Test that string with only commas returns default."""
with patch.dict(os.environ, {"TEST_ENV": ",,,"}):
env_func = env_list_with_choices("TEST_ENV", ["default"],
["option1", "option2"])
assert env_func() == ["default"]
def test_case_sensitive_validation(self):
"""Test case sensitive validation."""
with patch.dict(os.environ, {"TEST_ENV": "option1,OPTION2"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"],
case_sensitive=True)
with pytest.raises(ValueError,
match="Invalid value 'OPTION2' in TEST_ENV"):
env_func()
def test_case_insensitive_validation(self):
"""Test case insensitive validation."""
with patch.dict(os.environ, {"TEST_ENV": "OPTION1,option2"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"],
case_sensitive=False)
assert env_func() == ["OPTION1", "option2"]
def test_invalid_value_in_list_raises_error(self):
"""Test that invalid value in list raises ValueError."""
with patch.dict(os.environ, {"TEST_ENV": "option1,invalid,option2"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
with pytest.raises(ValueError,
match="Invalid value 'invalid' in TEST_ENV"):
env_func()
def test_callable_choices_resolved_correctly(self):
"""Test that callable choices are resolved correctly."""
def get_choices():
return ["dynamic1", "dynamic2"]
with patch.dict(os.environ, {"TEST_ENV": "dynamic1,dynamic2"}):
env_func = env_list_with_choices("TEST_ENV", [], get_choices)
assert env_func() == ["dynamic1", "dynamic2"]
def test_callable_choices_with_invalid_value(self):
"""Test that callable choices raise error for invalid values."""
def get_choices():
return ["dynamic1", "dynamic2"]
with patch.dict(os.environ, {"TEST_ENV": "dynamic1,invalid"}):
env_func = env_list_with_choices("TEST_ENV", [], get_choices)
with pytest.raises(ValueError,
match="Invalid value 'invalid' in TEST_ENV"):
env_func()
def test_duplicate_values_preserved(self):
"""Test that duplicate values in the list are preserved."""
with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
env_func = env_list_with_choices("TEST_ENV", [],
["option1", "option2"])
assert env_func() == ["option1", "option1", "option2"]

View File

@ -13,6 +13,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ToolCall)
from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser)
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import (
Qwen3XMLToolParser)
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
@ -29,6 +31,21 @@ def qwen3_tool_parser(qwen3_tokenizer):
return Qwen3CoderToolParser(qwen3_tokenizer)
@pytest.fixture
def qwen3_xml_tool_parser(qwen3_tokenizer):
return Qwen3XMLToolParser(qwen3_tokenizer)
@pytest.fixture(params=["original", "xml"])
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser,
request):
"""Parameterized fixture that provides both parser types for testing"""
if request.param == "original":
return qwen3_tool_parser
else:
return qwen3_xml_tool_parser
@pytest.fixture
def sample_tools():
return [
@ -95,7 +112,7 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
def stream_delta_message_generator(
qwen3_tool_parser: Qwen3CoderToolParser,
qwen3_tool_parser,
qwen3_tokenizer: AnyTokenizer,
model_output: str,
request: Optional[ChatCompletionRequest] = None
@ -144,9 +161,9 @@ def stream_delta_message_generator(
read_offset = new_read_offset
def test_extract_tool_calls_no_tools(qwen3_tool_parser):
def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized):
model_output = "This is a test response without any tool calls"
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
@ -294,12 +311,13 @@ circle
], "Let me calculate that area for you."),
],
)
def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output,
expected_tool_calls, expected_content):
def test_extract_tool_calls(qwen3_tool_parser_parametrized, sample_tools,
model_output, expected_tool_calls,
expected_content):
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request)
assert extracted_tool_calls.tools_called
@ -308,7 +326,8 @@ def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output,
assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser, sample_tools):
def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser_parametrized,
sample_tools):
"""Test fallback parsing when XML tags are missing"""
model_output = '''<function=get_current_weather>
<parameter=city>
@ -322,7 +341,7 @@ TX
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request)
assert extracted_tool_calls.tools_called
@ -331,7 +350,7 @@ TX
"get_current_weather")
def test_extract_tool_calls_type_conversion(qwen3_tool_parser):
def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized):
"""Test parameter type conversion based on tool schema"""
tools = [
ChatCompletionToolsParam(type="function",
@ -381,7 +400,7 @@ hello world
</tool_call>'''
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request)
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
@ -536,9 +555,10 @@ circle
], "Let me calculate that area for you."),
],
)
def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
sample_tools, model_output,
expected_tool_calls, expected_content):
def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized,
qwen3_tokenizer, sample_tools,
model_output, expected_tool_calls,
expected_content):
"""Test incremental streaming behavior including typed parameters"""
request = ChatCompletionRequest(model=MODEL,
messages=[],
@ -548,7 +568,8 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
tool_states = {} # Track state per tool index
for delta_message in stream_delta_message_generator(
qwen3_tool_parser, qwen3_tokenizer, model_output, request):
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
request):
# role should never be streamed from tool parser
assert not delta_message.role
@ -609,7 +630,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
def test_extract_tool_calls_missing_closing_parameter_tag(
qwen3_tool_parser, sample_tools):
qwen3_tool_parser_parametrized, sample_tools):
"""Test handling of missing closing </parameter> tag"""
# Using get_current_weather from sample_tools but with malformed XML
model_output = '''Let me check the weather for you:
@ -629,7 +650,7 @@ fahrenheit
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request)
# The parser should handle the malformed XML gracefully
@ -652,7 +673,7 @@ fahrenheit
def test_extract_tool_calls_streaming_missing_closing_tag(
qwen3_tool_parser, qwen3_tokenizer, sample_tools):
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools):
"""Test streaming with missing closing </parameter> tag"""
# Using get_current_weather from sample_tools but with malformed XML
model_output = '''Let me check the weather for you:
@ -677,7 +698,8 @@ fahrenheit
tool_states = {}
for delta_message in stream_delta_message_generator(
qwen3_tool_parser, qwen3_tokenizer, model_output, request):
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
request):
if delta_message.content:
other_content += delta_message.content
@ -727,9 +749,8 @@ fahrenheit
assert args["unit"] == "fahrenheit"
def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser,
qwen3_tokenizer,
sample_tools):
def test_extract_tool_calls_streaming_incremental(
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools):
"""Test that streaming is truly incremental"""
model_output = '''I'll check the weather.<tool_call>
<function=get_current_weather>
@ -748,7 +769,8 @@ TX
chunks = []
for delta_message in stream_delta_message_generator(
qwen3_tool_parser, qwen3_tokenizer, model_output, request):
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
request):
chunks.append(delta_message)
# Should have multiple chunks
@ -784,3 +806,49 @@ TX
parsed_args = json.loads(full_args)
assert parsed_args["city"] == "Dallas"
assert parsed_args["state"] == "TX"
def test_extract_tool_calls_complex_type_with_single_quote(
qwen3_tool_parser_parametrized):
"""Test parameter type conversion based on tool schema"""
tools = [
ChatCompletionToolsParam(type="function",
function={
"name": "test_types",
"parameters": {
"type": "object",
"properties": {
"int_param": {
"type": "integer"
},
"float_param": {
"type": "float"
},
"bool_param": {
"type": "boolean"
},
"str_param": {
"type": "string"
},
"obj_param": {
"type": "object"
}
}
}
})
]
model_output = '''<tool_call>
<function=test_types>
<parameter=obj_param>
{'key': 'value'}
</parameter>
</function>
</tool_call>'''
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request)
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args["obj_param"] == {"key": "value"}

View File

@ -6,6 +6,7 @@ Run `pytest tests/kernels/moe/test_moe_pallas.py`.
"""
import pytest
import torch
import torch_xla
# yapf conflicts with isort for this block
# yapf: disable
@ -77,7 +78,7 @@ def test_pallas_moe(
expert_map=e_map,
renormalize=False,
)
xm.mark_step()
torch_xla.sync(wait=False)
# Compare outputs
torch.testing.assert_close(

View File

@ -47,7 +47,10 @@ backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
env_vars={
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
},
@ -67,6 +70,7 @@ backend_configs = {
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
@ -75,7 +79,10 @@ backend_configs = {
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
env_vars={
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
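Both FlashAttention configs now also cap CUDA-graph splits. To reproduce that environment outside the test harness, a minimal sketch using the same variable names:

```python
import os

# Same knobs the FA2/FA3 backend configs set above.
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"
os.environ["VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"] = "16"
```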

View File

@ -85,7 +85,10 @@ run_tests_for_model() {
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \
--port $PORT \
--enforce-eager \
--gpu-memory-utilization 0.2 \
@ -117,7 +120,10 @@ run_tests_for_model() {
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \
--port $PORT \
--enforce-eager \
--gpu-memory-utilization 0.2 \

View File

@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501
SharedStorageConnectorMetadata)
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_initialized, get_kv_transfer_group)
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
# Importing utils registers TestSharedStorageConnector with the factory
from .utils import create_vllm_config
def _make_empty_scheduler_output():
return SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
kv_connector_metadata=SharedStorageConnectorMetadata(),
)
def test_kv_connector_mixin_clears_metadata():
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector"
vllm_config.kv_transfer_config.kv_role = "kv_both"
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"
# Initialize the global connector instance
ensure_kv_transfer_initialized(vllm_config)
try:
# Minimal scheduler output with empty metadata; mixin should still
# bind/clear metadata even if no loads happen
scheduler_output = _make_empty_scheduler_output()
# Invoke the no-forward path which uses the mixin context manager
KVConnectorModelRunnerMixin.kv_connector_no_forward(
scheduler_output, vllm_config)
# Verify clear_connector_metadata was called on the connector
connector = get_kv_transfer_group()
assert connector._connector_metadata is None
# Test connector wrapper records method calls
assert connector.call_record.get("bind_connector_metadata", 0) == 1
assert connector.call_record.get("clear_connector_metadata", 0) == 1
finally:
# Ensure we clean up the global connector between tests
KVConnectorModelRunnerMixin.ensure_kv_transfer_shutdown()

View File

@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker, NixlKVConnectorStats)
from vllm.forward_context import ForwardContext
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@ -56,7 +57,10 @@ class FakeNixlWrapper:
def get_reg_descs(self, caches_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in caches_data]
def register_memory(self, descs) -> None:
def register_memory(self, descs, backends) -> None:
pass
def deregister_memory(self, descs) -> None:
pass
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
@ -85,6 +89,12 @@ class FakeNixlWrapper:
def release_xfer_handle(self, handle: int) -> None:
pass
def release_dlist_handle(self, handle: int) -> None:
pass
def remove_remote_agent(self, agent: str) -> None:
pass
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
pass
@ -855,3 +865,95 @@ def test_register_kv_caches(dist_init):
assert block_len == expected_block_len, \
f"Block entry {i}: Expected block len {expected_block_len}, " \
f"got {block_len}"
class FakePlatform(Platform):
device_type: str = "oot"
@classmethod
def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
"""
Returns a mapping from device_type to a tuple of supported
kv_buffer_device for nixl.
"""
return {'oot': ('oot', )}
@classmethod
def get_nixl_memory_type(cls) -> Optional[str]:
"""
Returns the nixl memory type for the current platform.
"""
return 'VRAM'
@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [
("oot", "VRAM"),
])
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
nixl_memory_type):
"""
Test that register_kv_caches() passes the correct memory types from the
config to the nixl_wrapper.
"""
vllm_config = create_vllm_config()
# Override the default memory types in the config
vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
_NIXL_SUPPORTED_DEVICE)
_NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices())
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE): # noqa: E501
# Create connector and replace its worker with a fake one for isolation
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
# Verify get_reg_descs was called with the correct memory_type
assert connector.connector_worker.kv_buffer_device == kv_buffer_device
assert connector.connector_worker.nixl_memory_type == nixl_memory_type
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_shutdown_cleans_up_resources(dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()
worker = NixlConnectorWorker(vllm_config,
vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper
with patch.object(worker, '_handshake_initiation_executor') as mock_exec, \
patch.object(worker, '_nixl_handshake_listener_t') as mock_listener, \
patch.object(nixl_wrapper, 'release_xfer_handle') as mock_rel_xfer, \
patch.object(nixl_wrapper, 'release_dlist_handle') as mock_rel_dlist, \
patch.object(nixl_wrapper, 'remove_remote_agent') as mock_rem_agent, \
patch.object(nixl_wrapper, 'deregister_memory') as mock_dereg:
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
worker.src_xfer_side_handle = 456
worker.dst_xfer_side_handles = {"engine1": 789}
worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"]
worker.shutdown()
# Test idempotency
worker.shutdown()
worker.shutdown()
mock_exec.shutdown.assert_called_with(wait=False)
mock_listener.join.assert_called_once_with(timeout=0)
mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
mock_rel_dlist.assert_any_call(456) # src handle
mock_rel_dlist.assert_any_call(789) # dst handle
mock_rem_agent.assert_called_once_with("agent1")
assert mock_dereg.call_count == 2
mock_dereg.assert_any_call("desc1")
mock_dereg.assert_any_call("desc2")

View File

@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
import pytest
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
CPU_BLOCK_SIZES = [16, 48]
@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
def test_cpu_offloading(cpu_block_size: int) -> None:
"""
Tests OffloadingConnector with CPUOffloadingSpec.
"""
# configure OffloadingConnector (spec_name=CPUOffloadingSpec by default)
kv_transfer_config = KVTransferConfig(
kv_connector="OffloadingConnector",
kv_role="kv_both",
kv_connector_extra_config={
"num_cpu_blocks": 100,
"block_size": cpu_block_size
},
)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_transfer_config=kv_transfer_config,
)
prompts = ["Hi " * 100]
sampling_params = SamplingParams(temperature=0, max_tokens=20)
# run generation - this should trigger saving KV cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cold_time = time.time() - start_time
# run generation again - should hit the GPU prefix cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
gpu_hit_time = time.time() - start_time
# reset prefix cache to avoid GPU hit.
llm.reset_prefix_cache()
# sleep for a sec to make sure CPU finished storing
time.sleep(1)
# run generation again - this should trigger loading from CPU
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cpu_hit_time = time.time() - start_time
print("Generation times:")
print(f" Cold: {cold_time * 1000:.2f}ms")
print(f" GPU hit: {gpu_hit_time * 1000:.2f}ms")
print(f" CPU hit: {cpu_hit_time * 1000:.2f}ms")

View File

@ -13,7 +13,6 @@ from vllm import SamplingParams
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient
@ -29,10 +28,6 @@ engine_args = AsyncEngineArgs(
data_parallel_size=DP_SIZE,
)
if not current_platform.supports_v1(engine_args.create_model_config()):
pytest.skip(reason="Requires V1-supporting platform.",
allow_module_level=True)
async def generate(
engine: AsyncLLM,

View File

@ -4,6 +4,7 @@ import math
import pytest
import torch
import torch_xla
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
@ -63,7 +64,7 @@ def test_topp_result_sums_past_p():
probs.masked_fill_(logits_masked.isinf(), 0)
masked_prob_sum = probs.sum(dim=-1)
xm.mark_step()
torch_xla.sync()
# Perform assertion on CPU.
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
@ -82,7 +83,7 @@ def test_topp_basic():
k=torch.tensor([3, 3]),
p=torch.tensor([0.79, 0.79]))
xm.mark_step()
torch_xla.sync()
# Expect the smallest elements to be dropped.
expected_result = logits.clone().cpu()
@ -104,7 +105,7 @@ def test_topp_select_all():
k=torch.tensor([3, 3]),
p=torch.tensor([1.0, 1.0]))
xm.mark_step()
torch_xla.sync()
assert torch.allclose(logits.cpu(), result.cpu())
@ -122,7 +123,7 @@ def test_topp_with_ties():
k=torch.tensor([4]),
p=torch.tensor([0.2]))
xm.mark_step()
torch_xla.sync()
# All tie values are included in the top-p set. Tie breaking is left
# to be done during final sampling (all tie tokens have equal
@ -146,7 +147,7 @@ def test_both_topk_topp():
k=torch.tensor([1, 3]),
p=torch.tensor([0.79, 0.79]))
xm.mark_step()
torch_xla.sync()
# Since for the first batch k=1, expect only the largest element gets
# selected.
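# Hedged summary (not part of the diff): these hunks swap the legacy
# xm.mark_step() call for the top-level torch_xla.sync() API. A minimal sketch,
# assuming a torch_xla >= 2.x XLA environment as used by these tests:
#
#     import torch_xla
#     import torch_xla.core.xla_model as xm
#
#     xm.mark_step()              # old spelling: flush the pending lazy graph
#     torch_xla.sync()            # new spelling used in the sampler tests
#     torch_xla.sync(wait=False)  # non-blocking variant used in the MoE Pallas test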

View File

@ -1,35 +0,0 @@
#!/bin/bash
CI=${1:-0}
PYTHON_VERSION=${2:-local}
if [ "$CI" -eq 1 ]; then
set -e
fi
if [ $PYTHON_VERSION == "local" ]; then
PYTHON_VERSION=$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
fi
run_mypy() {
echo "Running mypy on $1"
if [ "$CI" -eq 1 ] && [ -z "$1" ]; then
mypy --python-version "${PYTHON_VERSION}" "$@"
return
fi
mypy --follow-imports skip --python-version "${PYTHON_VERSION}" "$@"
}
run_mypy # Note that this is less strict than CI
run_mypy tests
run_mypy vllm/attention
run_mypy vllm/compilation
run_mypy vllm/distributed
run_mypy vllm/engine
run_mypy vllm/executor
run_mypy vllm/inputs
run_mypy vllm/lora
run_mypy --exclude 'vllm/model_executor/layers/fla/ops' vllm/model_executor
run_mypy vllm/plugins
run_mypy vllm/worker
run_mypy vllm/v1

View File

@ -1,20 +1,10 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys
import regex as re
try:
import pathspec
except ImportError:
print(
"ERROR: The 'pathspec' library is required. "
"Install it with 'pip install pathspec'.",
file=sys.stderr)
sys.exit(2)
# List of files (relative to repo root) that are allowed to import pickle or
# cloudpickle
#
@ -25,7 +15,7 @@ except ImportError:
# Before adding new uses of pickle/cloudpickle, please consider safer
# alternatives like msgpack or pydantic that are already in use in vLLM. Only
# add to this list if absolutely necessary and after careful security review.
ALLOWED_FILES = set([
ALLOWED_FILES = {
# pickle
'vllm/v1/serial_utils.py',
'vllm/v1/executor/multiproc_executor.py',
@ -36,11 +26,9 @@ ALLOWED_FILES = set([
'tests/tokenization/test_cached_tokenizer.py',
'vllm/distributed/utils.py',
'vllm/distributed/parallel_state.py',
'vllm/engine/multiprocessing/client.py',
'vllm/distributed/device_communicators/all_reduce_utils.py',
'vllm/distributed/device_communicators/shm_broadcast.py',
'vllm/distributed/device_communicators/shm_object_storage.py',
'vllm/engine/multiprocessing/engine.py',
'benchmarks/kernels/graph_machete_bench.py',
'benchmarks/kernels/benchmark_lora.py',
'benchmarks/kernels/benchmark_machete.py',
@ -55,65 +43,30 @@ ALLOWED_FILES = set([
'tests/utils.py',
# pickle and cloudpickle
'vllm/utils/__init__.py',
'vllm/v1/serial_utils.py',
'vllm/v1/executor/multiproc_executor.py',
'vllm/transformers_utils/config.py',
'vllm/model_executor/models/registry.py',
'vllm/engine/multiprocessing/client.py',
'vllm/engine/multiprocessing/engine.py',
])
}
PICKLE_RE = re.compile(r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
r"|from\s+(pickle|cloudpickle)\s+import\b)")
def is_python_file(path):
return path.endswith('.py')
def scan_file(path):
def scan_file(path: str) -> int:
with open(path, encoding='utf-8') as f:
for line in f:
for i, line in enumerate(f, 1):
if PICKLE_RE.match(line):
return True
return False
def load_gitignore(repo_root):
gitignore_path = os.path.join(repo_root, '.gitignore')
patterns = []
if os.path.exists(gitignore_path):
with open(gitignore_path, encoding='utf-8') as f:
patterns = f.read().splitlines()
# Always ignore .git directory
patterns.append('.git/')
return pathspec.PathSpec.from_lines('gitwildmatch', patterns)
print(f"{path}:{i}: "
"\033[91merror:\033[0m " # red color
"Found pickle/cloudpickle import")
return 1
return 0
def main():
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
spec = load_gitignore(repo_root)
bad_files = []
for dirpath, _, filenames in os.walk(repo_root):
for filename in filenames:
if not is_python_file(filename):
continue
abs_path = os.path.join(dirpath, filename)
rel_path = os.path.relpath(abs_path, repo_root)
# Skip ignored files
if spec.match_file(rel_path):
continue
if scan_file(abs_path) and rel_path not in ALLOWED_FILES:
bad_files.append(rel_path)
if bad_files:
print("\nERROR: The following files import 'pickle' or 'cloudpickle' "
"but are not in the allowed list:")
for f in bad_files:
print(f" {f}")
print("\nIf this is intentional, update the allowed list in "
"tools/check_pickle_imports.py.")
sys.exit(1)
sys.exit(0)
returncode = 0
for filename in sys.argv[1:]:
if filename in ALLOWED_FILES:
continue
returncode |= scan_file(filename)
return returncode
def test_regex():
@ -149,4 +102,4 @@ if __name__ == '__main__':
if '--test-regex' in sys.argv:
test_regex()
else:
main()
sys.exit(main())
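# Hedged usage sketch (not part of the diff): after this rewrite the checker is
# driven by pre-commit and receives the changed files as arguments instead of
# walking the repository, e.g. (paths are illustrative):
#
#     python tools/check_pickle_imports.py vllm/foo.py tests/bar.py
#
# Files listed in ALLOWED_FILES are skipped; any other file with a line matching
# PICKLE_RE gets a "<path>:<line>: error: Found pickle/cloudpickle import"
# message and the script exits non-zero.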

tools/pre_commit/mypy.py (new executable file, 140 lines)
View File

@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Run mypy on changed files.
This script is designed to be used as a pre-commit hook. It runs mypy
on files that have been changed. It groups files into different mypy calls
based on their directory to avoid import following issues.
Usage:
python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>
Args:
ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to
"silent" for the main group of files.
python_version: Python version to use (e.g., "3.10") or "local" to use
the local Python version.
changed_files: List of changed files to check.
"""
import subprocess
import sys
from typing import Optional
import regex as re
FILES = [
"vllm/*.py",
"vllm/assets",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging_utils",
"vllm/multimodal",
"vllm/platforms",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
]
# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES
SEPARATE_GROUPS = [
"tests",
"vllm/attention",
"vllm/compilation",
"vllm/distributed",
"vllm/engine",
"vllm/executor",
"vllm/inputs",
"vllm/lora",
"vllm/model_executor",
"vllm/plugins",
"vllm/worker",
"vllm/v1",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
EXCLUDE = [
"vllm/model_executor/parallel_utils",
"vllm/model_executor/models",
"vllm/model_executor/layers/fla/ops",
# Ignore triton kernels in ops.
"vllm/attention/ops",
]
def group_files(changed_files: list[str]) -> dict[str, list[str]]:
"""
Group changed files into different mypy calls.
Args:
changed_files: List of changed files.
Returns:
A dictionary mapping file group names to lists of changed files.
"""
exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*")
files_pattern = re.compile(f"^({'|'.join(FILES)}).*")
file_groups = {"": []}
file_groups.update({k: [] for k in SEPARATE_GROUPS})
for changed_file in changed_files:
# Skip files which should be ignored completely
if exclude_pattern.match(changed_file):
continue
# Group files by mypy call
if files_pattern.match(changed_file):
file_groups[""].append(changed_file)
continue
else:
for directory in SEPARATE_GROUPS:
if re.match(f"^{directory}.*", changed_file):
file_groups[directory].append(changed_file)
break
return file_groups
def mypy(targets: list[str], python_version: Optional[str],
follow_imports: Optional[str], file_group: str) -> int:
"""
Run mypy on the given targets.
Args:
targets: List of files or directories to check.
python_version: Python version to use (e.g., "3.10") or None to use
the default mypy version.
follow_imports: Value for the --follow-imports option or None to use
the default mypy behavior.
file_group: The file group name for logging purposes.
Returns:
The return code from mypy.
"""
args = ["mypy"]
if python_version is not None:
args += ["--python-version", python_version]
if follow_imports is not None:
args += ["--follow-imports", follow_imports]
print(f"$ {' '.join(args)} {file_group}")
return subprocess.run(args + targets, check=False).returncode
def main():
ci = sys.argv[1] == "1"
python_version = sys.argv[2]
file_groups = group_files(sys.argv[3:])
if python_version == "local":
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
returncode = 0
for file_group, changed_files in file_groups.items():
follow_imports = None if ci and file_group == "" else "skip"
if changed_files:
returncode |= mypy(changed_files, python_version, follow_imports,
file_group)
return returncode
if __name__ == "__main__":
sys.exit(main())
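# Hedged usage sketch (not part of the diff): pre-commit is expected to invoke
# the script with the CI flag, the Python version, and the changed files,
# roughly like this (file paths are illustrative):
#
#     python tools/pre_commit/mypy.py 0 local vllm/engine/arg_utils.py tests/conftest.py
#
# "0" keeps non-CI behaviour (follow-imports stays "skip" for every group),
# "local" resolves to the running interpreter's major.minor version, and the
# remaining paths are bucketed by group_files() so each directory group gets its
# own mypy invocation.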

View File

@ -10,7 +10,6 @@ from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.multimodal import MultiModalPlaceholderMap
class AttentionType:
@ -116,15 +115,6 @@ class AttentionMetadata:
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# The index maps that relate multi-modal embeddings to the corresponding
# placeholders.
#
# N.B. These aren't really related to attention and don't belong on this
# type -- this is just a temporary solution to make them available to
# `model_executable`.
multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]]
# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool

View File

@ -1,10 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from itertools import accumulate
from typing import Dict, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type
import torch
@ -12,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d
# Placeholder attention backend for models like Mamba and pooling models that
@ -141,8 +139,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
@ -178,7 +174,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
@ -210,9 +205,6 @@ class PlaceholderAttentionMetadataBuilder(
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
@ -232,12 +224,6 @@ class PlaceholderAttentionMetadataBuilder(
self.context_lens.append(context_len)
if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
@ -295,12 +281,6 @@ class PlaceholderAttentionMetadataBuilder(
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
# Placeholders
slot_mapping_tensor = torch.empty(0)
block_tables = torch.empty(0)
@ -308,7 +288,6 @@ class PlaceholderAttentionMetadataBuilder(
return PlaceholderAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
@ -15,16 +14,10 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
from vllm.attention.backends.abstract import AttentionType
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
logger = init_logger(__name__)
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
"with encoder/decoder models.")
PAD_SLOT_ID = -1
# Switch to numpy implementation of compute_slot_mapping
@ -135,9 +128,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
@ -154,12 +144,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
@ -254,16 +238,10 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
@ -320,7 +298,6 @@ class CommonAttentionState(AttentionState):
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],

View File

@ -134,6 +134,5 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
out = cp_group.reduce_scatter(out, dim=1)
return out

View File

@ -531,18 +531,22 @@ async def benchmark(
extra_body=extra_body,
)
test_output = await wait_for_endpoint(
request_func,
test_input,
session,
timeout_seconds=ready_check_timeout_sec,
)
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}")
if ready_check_timeout_sec > 0:
test_output = await wait_for_endpoint(
request_func,
test_input,
session,
timeout_seconds=ready_check_timeout_sec,
)
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark "
"arguments are correctly specified. "
f"Error: {test_output.error}")
else:
print("Initial test run completed. Starting main benchmark run...")
else:
print("Initial test run completed. Starting main benchmark run...")
print("Skipping endpoint ready check.")
if lora_modules:
# For each input request, choose a LoRA module at random.
@ -1151,7 +1155,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
type=int,
default=600,
help="Maximum time to wait for the endpoint to become ready "
"in seconds (default: 600 seconds / 10 minutes).",
"in seconds (default: 600 seconds / 10 minutes). If set to 0, "
"the ready check will be skipped."
)

View File

@ -17,7 +17,7 @@ from vllm.platforms import current_platform
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@ -152,7 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmInductorPass):
class ActivationQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
@ -176,16 +176,12 @@ class ActivationQuantFusionPass(VllmInductorPass):
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_act_quant_fusion")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
count)
self.dump_graph(graph, "after_act_quant_fusion")
self.end_and_log()
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self):
return VllmInductorPass.hash_source(self, ActivationQuantPattern,

View File

@ -20,7 +20,7 @@ from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype()
@ -348,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
pm.fwd_only, pm_pass)
class AsyncTPPass(VllmInductorPass):
class AsyncTPPass(VllmPatternMatcherPass):
@enable_fake_mode
def __init__(self, config: VllmConfig):
@ -378,18 +378,17 @@ class AsyncTPPass(VllmInductorPass):
AllGatherCutlassScaledMMPattern(
self.model_dtype, self.device).register(self.patterns)
self.dump_patterns(config, self.patterns)
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
# only do replace for specific shapes
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_async_tp_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns with async TP pass.", count)
self.dump_graph(graph, "after_async_tp_pass")
self.end_and_log()
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
if flashinfer_comm is not None:
@ -1068,7 +1067,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
pm.fwd_only, pm_pass)
class AllReduceFusionPass(VllmInductorPass):
class AllReduceFusionPass(VllmPatternMatcherPass):
def __init__(self, config: VllmConfig):
super().__init__(config)
@ -1124,6 +1123,7 @@ class AllReduceFusionPass(VllmInductorPass):
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
self.register_patterns()
self.dump_patterns(config, self.patterns)
@enable_fake_mode
def register_patterns(self):
@ -1172,15 +1172,14 @@ class AllReduceFusionPass(VllmInductorPass):
self.disabled = False
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
if self.disabled:
logger.debug("AllReduceFusionPass disabled")
return
self.begin()
self.dump_graph(graph, "before_all_reduce_fusion_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_all_reduce_fusion_pass")
self.end_and_log()
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def __del__(self):
if getattr(self, "disabled", True):

View File

@ -26,6 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass):
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
@ -34,9 +35,6 @@ class FixFunctionalizationPass(VllmInductorPass):
"pass currently.")
return
self.begin()
self.dump_graph(graph, "before_fix_functionalization")
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
for node in graph.nodes:
@ -111,7 +109,7 @@ class FixFunctionalizationPass(VllmInductorPass):
count += 1
self.dump_graph(graph, "before_fix_functionalization_cleanup")
self.dump_graph(graph, "before_cleanup")
# Remove the nodes all at once
count_removed = len(self.nodes_to_remove)
@ -120,8 +118,7 @@ class FixFunctionalizationPass(VllmInductorPass):
logger.debug("De-functionalized %s nodes, removed %s nodes", count,
count_removed)
self.dump_graph(graph, "after_fix_functionalization")
self.end_and_log()
self.nodes_to_remove.clear()
def _remove(self, node_or_nodes: Union[torch.fx.Node,
Iterable[torch.fx.Node]]):

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, NamedTuple, Optional
from typing import Any, NamedTuple
import torch
import torch._inductor.pattern_matcher as pm
@ -16,10 +16,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe
from .inductor_pass import enable_fake_mode
from .multi_output_match import MultiOutputMatch
from .vllm_inductor_pass import VllmInductorPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
@ -50,8 +48,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[
kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
class FusedRMSQuantKey(NamedTuple):
@ -80,68 +77,6 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
}
class QuantMultiOutputMatch(MultiOutputMatch):
def __init__(self, match: pm.Match, quant_op, fused_op):
super().__init__(match)
assert isinstance(quant_op, OpOverload)
assert isinstance(fused_op, OpOverload)
self.QUANT_OP = quant_op # in-place quant op
self.FUSED_OP = fused_op # in-place fused quant op
def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node,
int]],
**kwargs):
"""
This utility function inserts an auto-functionalized node for FUSED_OP.
It also correctly sets its meta value and rebinds the users of the
unfused nodes to use the fused node instead.
:param fused_return_mapping: A dictionary, mapping from getitem indices
of the fused node result to a tuple of the old node and a getitem index.
:param kwargs: kwargs that get directly forwarded to the auto_fn node
Example:
If we want to replace this graph:
_, x1, x2 = auto_fn(op1)
_, y1, y2 = auto_fn(op2)
with
_, x1, y2, x2 = auto_fn(FUSED_OP)
we would call:
insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)}
Note that the 0th element is None for auto-functionalized in-place ops.
Hence, others appear 1-indexed.
"""
fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
indices = fused_return_mapping.keys()
getitem_nodes = self.insert_getitems(fused_node, indices)
# Prepare the meta value, use a list so it's mutable
meta_val = [None] * (max(indices) + 1)
# Iterate through elements of the tuple produced by fused_node
for idx, getitem_node in zip(indices, getitem_nodes):
old_node, old_idx = fused_return_mapping[idx]
# If the old value was never used, the old_getitem might not exist
old_getitem = find_getitem_maybe(old_node, old_idx)
if old_getitem is not None:
# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
old_getitem.replace_all_uses_with(getitem_node)
getitem_node.meta["val"] = old_getitem.meta["val"]
# Extract the appropriate meta value
# It is present even if the getitem node does not exist
meta_val[idx] = old_node.meta["val"][old_idx]
# Fix the meta value on the new fused node
fused_node.meta["val"] = tuple(meta_val)
class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
@ -224,8 +159,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def register(self, pm_pass: PatternMatcherPass):
def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
@ -271,36 +205,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 1
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and residual.
# The auto_fn node returns a tuple of (None, result, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
# result_node_new = at[1]
# residual_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
# 0 is always None
fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
self.insert_fused_node(fused_return_mapping,
**kwargs,
epsilon=rms_node.kwargs["epsilon"])
)
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
@ -317,8 +222,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def register(self, pm_pass: PatternMatcherPass):
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
@ -366,39 +270,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 1
assert len(quant_node.users) == 2
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and scale.
# The auto_fn node returns a tuple of (None, result, scale).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
del kwargs["result_rms"] # not used in the fused op
fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
residual=None, # not used but required
**kwargs)
)
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
@ -415,8 +287,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def register(self, pm_pass: PatternMatcherPass):
def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
@ -464,137 +335,49 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 2
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract result, scale, and residual.
# The auto_fn node returns a tuple (None, result, scale, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
# residual_node_new = at[3]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
fused_return_mapping = {
1: (quant_node, 1), # result
2: (quant_node, 2), # scale
3: (rms_node, 2), # residual
}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
**kwargs)
)
class FusionPass(VllmInductorPass):
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
It also manually processes multi-output matches, as those are broken in
the torch pattern matcher.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
It also supports fused_add_rms_norm.
"""
_instance: 'Optional[FusionPass]' = None
@classmethod
def instance(cls, config: VllmConfig):
"""
Get the singleton instance of the FusionPass.
If the instance exists, the config is updated but
initialization is not repeated.
"""
if cls._instance is None:
cls._instance = FusionPass(config)
else:
cls._instance.pass_config = config.compilation_config.pass_config
return cls._instance
@enable_fake_mode
def __init__(self, config: VllmConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
super().__init__(config)
self.matches: list[MultiOutputMatch] = []
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="fusion_pass")
pass_name="rmsnorm_quant_fusion_pass")
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns)
# Matches for patterns below have 2 or more outputs,
# so we need to process them manually (see process_matches)
# Fuse rms_norm + static fp8 quant
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
self.patterns)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
RMSNormDynamicQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
def record_match(self, match: MultiOutputMatch) -> bool:
# Hijack the extra_check to record the match and
# save it for post-processing.
self.matches.append(match)
# Return False to prevent automatic replacement.
return False
def process_matches(self, graph: fx.Graph):
"""
Manually process multi-output matches and replace them with fused nodes.
See MultiOutputMatch for more details.
"""
for match in self.matches:
match.process()
# Finally, remove matched nodes
graph.eliminate_dead_code()
assert all(node not in graph.nodes for match in self.matches
for node in match.match.nodes)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_fusion")
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_pattern_match")
# Manually process multi-output matches (and run DCE)
self.process_matches(graph)
logger.debug("Post-processed %s matches", len(self.matches))
self.dump_graph(graph, "after_fusion")
self.matches.clear()
self.end_and_log()
def uuid(self) -> Any:
return self.hash_source(self, RMSNormQuantPattern,
RMSNormStaticQuantPattern,
RMSNormDynamicQuantPattern,
FusedAddRMSNormStaticQuantPattern,
FusedAddRMSNormDynamicQuantPattern)
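# Hedged note (not part of the diff): with the singleton machinery removed, the
# renamed RMSNormQuantFusionPass is constructed directly per config, roughly as
# the PostGradPassManager hunk later in this diff wires it up:
#
#     if pass_config.enable_fusion:
#         passes += [RMSNormQuantFusionPass(config), ActivationQuantFusionPass(config)]
#
# (variable names here are illustrative; see the pass-manager changes below).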

View File

@ -18,7 +18,7 @@ from vllm.utils import round_up
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@ -245,7 +245,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
pm_pass)
class AttnFusionPass(VllmInductorPass):
class AttnFusionPass(VllmPatternMatcherPass):
"""
This pass fuses post-attention quantization onto attention if supported.
@ -282,20 +282,12 @@ class AttnFusionPass(VllmInductorPass):
"were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered.")
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.graph.Graph) -> None:
self.begin()
self.dump_graph(graph, "before_attn_fusion")
count = self.patterns.apply(graph)
# TODO: Move this to pass_manager.py after the fx graph broken issue
# has been resolved.
# see https://github.com/vllm-project/vllm/issues/23091
graph.eliminate_dead_code()
logger.debug("Fused quantization onto %s attention nodes", count)
self.dump_graph(graph, "after_attn_fusion")
self.end_and_log()
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
def uuid(self):
return VllmInductorPass.hash_source(self, AttentionQuantPattern,

View File

@ -1,109 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import operator
from abc import abstractmethod
from collections.abc import Iterable
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor import pattern_matcher as pm
from torch._ops import OpOverload
from torch.fx import Node
from vllm.compilation.fx_utils import find_auto_fn
class MultiOutputMatch(abc.ABC):
"""
This class provides utilities to process multi-output matches and
manually insert replacements.
This is necessary because the automatic replacement for multi-output
matches is broken: https://github.com/pytorch/pytorch/issues/137280
"""
def __init__(self, match: pm.Match):
self.match = match
@abstractmethod
def process(self):
"""
Process a multi-output match and manually insert the replacement.
This method should:
1. Insert the replacement nodes after the last node in the match.
2. Rebind the users of nodes in the match to use the new nodes.
3. Set meta["val"] for de-functionalization.
The result of an auto-functionalized node is a tuple of tensors.
The first element is the return value of the function, usually None.
The remaining elements are the mutated args of the function.
All auto-functionalized nodes must contain a proper meta["val"],
as it is used by de-functionalization. meta["val"] has to contain the
value of the node (tuple of tensors) that would be returned by the
functionalized node during tracing.
Existing nodes in the graph all have this property set, but we have
to set it manually for new nodes we insert.
Example:
# op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None
at = auto_functionalized(torch.ops._C.foo.default, a, b, c)
# at.meta["val"] = (None, a, c)
"""
raise NotImplementedError
@property
def nodes(self) -> list[fx.Node]:
return self.match.nodes
@property
def graph(self) -> fx.Graph:
return self.match.graph
def find_auto_fn(self, op) -> fx.Node:
"""
Find the first auto_functionalized node with the given op in the match.
"""
return find_auto_fn(self.nodes, op)
def inserting_after_match(self):
"""
Insert nodes after the last node in the match.
This is done to avoid use-before-definition errors after inserting
replacement nodes.
"""
# match.nodes is not guaranteed to be sorted.
# Find the last node in the match.
for last_node_in_match in reversed(self.graph.nodes):
if last_node_in_match in self.match.nodes:
break
else:
raise ValueError("No nodes in graph")
return self.graph.inserting_after(last_node_in_match)
def insert_getitems(self, tuple_node: fx.Node,
indices: Iterable[int]) -> tuple[fx.Node, ...]:
"""
Insert operator.getitem nodes to extract elements from a tuple node.
:param tuple_node: The tuple node to extract elements from.
:param indices: The indices of the elements to extract.
:return: Tuple of the new getitem nodes, corresponding to the indices.
"""
with self.graph.inserting_after(tuple_node):
return tuple(
self.graph.call_function(operator.getitem, (tuple_node, idx))
for idx in indices)
def insert_auto_fn(self, op: OpOverload, kwargs) -> Node:
"""
Insert an auto_functionalized node with the given op and kwargs.
"""
return self.graph.call_function(auto_functionalized, (op, ),
kwargs=kwargs)

View File

@ -64,9 +64,8 @@ class NoOpEliminationPass(VllmInductorPass):
out: "f16[s0, 4096]" = at[1]
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_noop_elimination")
count = 0
# Remove no-op reshapes/views:
for node in graph.nodes:
@ -121,8 +120,6 @@ class NoOpEliminationPass(VllmInductorPass):
count += 1
logger.debug("Removed %s no-op reshapes and slices", count)
self.dump_graph(graph, "after_noop_elimination")
self.end_and_log()
# ---------------------- Reshape helpers ----------------------
def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],

View File

@ -1,15 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from torch import fx as fx
from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import set_env_var
from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import FusionPass
from .fusion import RMSNormQuantFusionPass
from .fusion_attn import AttnFusionPass
if current_platform.is_cuda():
@ -19,11 +25,28 @@ from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass
from .sequence_parallelism import SequenceParallelismPass
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
def with_pattern_match_debug(fn):
"""
Function decorator that turns on inductor pattern match debug
for the duration of the call.
Used to avoid logging builtin Inductor pattern matching.
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
# optionally check rank here
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
return fn(*args, **kwargs)
return fn(*args, **kwargs)
return wrapper
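# Hedged usage note (not part of the diff): the decorator above only forwards
# VLLM_PATTERN_MATCH_DEBUG into TORCHINDUCTOR_PATTERN_MATCH_DEBUG for the
# duration of the wrapped call, e.g. (assuming torch reads that variable to pick
# which pattern to trace):
#
#     VLLM_PATTERN_MATCH_DEBUG=pattern_name vllm serve <model> ...
#
# so builtin Inductor pattern matching outside PostGradPassManager.__call__
# stays quiet.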
class PostGradPassManager(CustomGraphPass):
"""
The pass manager for post-grad passes.
@ -40,16 +63,26 @@ class PostGradPassManager(CustomGraphPass):
"""
def __init__(self):
self.passes: list[VllmInductorPass] = []
self.passes: list[InductorPass] = []
@with_pattern_match_debug
def __call__(self, graph: fx.Graph):
VllmInductorPass.dump_prefix = 0 # reset dump index
shape = get_pass_context().runtime_shape
for pass_ in self.passes:
if pass_.is_applicable_for_shape(shape):
pass_(graph)
VllmInductorPass.dump_prefix += 1
# post-cleanup goes before fix_functionalization
# because it requires a functional graph
self.post_cleanup(graph)
VllmInductorPass.dump_prefix += 1
# always run fix_functionalization last
self.fix_functionalization(graph)
VllmInductorPass.dump_prefix = None # Cleanup index
def configure(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
@ -61,14 +94,18 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)]
self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass):

View File

@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from torch import fx
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
class PostCleanupPass(VllmInductorPass):
"""
This pass performs cleanup after custom passes.
It topologically sorts the graph and removes unused nodes.
This is needed because the pattern matcher does not guarantee producing
a topologically sorted graph, and there may be unused nodes left around.
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
from torch._inductor.pattern_matcher import stable_topological_sort
stable_topological_sort(graph)
graph.eliminate_dead_code()

View File

@ -15,7 +15,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@ -417,7 +417,7 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
pm.fwd_only, pm_pass)
class SequenceParallelismPass(VllmInductorPass):
class SequenceParallelismPass(VllmPatternMatcherPass):
"""
This pass enables sequence parallelism for models.
It identifies patterns where an AllReduce operation is followed by
@ -466,19 +466,13 @@ class SequenceParallelismPass(VllmInductorPass):
LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
self.dump_patterns(config, self.patterns)
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_sequence_parallelism_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns with sequence parallelism", count)
self.dump_graph(graph, "after_sequence_parallelism_pass")
self.end_and_log()
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)

View File

@ -1,10 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import operator
import time
from pathlib import Path
from typing import ClassVar, Optional
import regex as re
import torch
from torch._dynamo.utils import lazy_format_graph_code
from torch._inductor.pattern_matcher import (PatternMatcherPass,
PatternPrettyPrinter)
from vllm.config import VllmConfig
from vllm.logger import init_logger
@ -19,6 +25,8 @@ class VllmInductorPass(InductorPass):
An inductor pass with access to vLLM PassConfig.
It provides timing, logging, and dumping utilities.
"""
dump_prefix: ClassVar[Optional[int]] = None
"""Keep track of pass index for debug dump ordering."""
def __init__(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
@ -28,8 +36,24 @@ class VllmInductorPass(InductorPass):
else None
self.pass_name = self.__class__.__name__
@staticmethod
def time_and_log(call_fn):
@functools.wraps(call_fn)
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before")
call_fn(self, graph)
self.dump_graph(graph, "after")
self.end_and_log()
return wrapped
def dump_graph(self, graph: torch.fx.Graph, stage: str):
lazy_format_graph_code(stage, graph.owning_module)
i = VllmInductorPass.dump_prefix
i_str = "" if i is None else f".{i}"
lazy_format_graph_code(f"post_grad{i_str}.{self.pass_name}.{stage}",
graph.owning_module)
def begin(self):
self._start_time = time.perf_counter_ns()
@ -40,6 +64,88 @@ class VllmInductorPass(InductorPass):
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
class VllmPatternMatcherPass(VllmInductorPass):
"""
A VllmInductorPass that uses the Inductor pattern matcher.
Its main use is providing the dump_patterns utility that dumps the
Inductor pattern matcher patterns into a file, which greatly aids debugging.
TODO(luka) move more utilities to this pass.
"""
matched_count: int = 0
"""The number of matched patterns in the pass."""
_OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")
def _replace_op_overloads(self, string: str) -> str:
"""Replace <OpOverload(..., ...)> with nicer formulations"""
return self._OP_OVERLOAD_PATTERN.sub(
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
string,
)
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
"""
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
into the debug_dump_path folder next to the dumped fx graphs.
This method does its best to print something that looks like Python code,
making the output easier to debug and navigate. If any errors appear in
the output, please extend this method to handle them.
TODO(luka): use pattern object to manually produce pattern graph
"""
debug_dump_path = config.compilation_config.debug_dump_path
if not debug_dump_path:
return
rank = config.parallel_config.rank
debug_dump_path = Path(debug_dump_path) / f"rank_{rank}"
debug_dump_path.mkdir(parents=True, exist_ok=True)
from vllm.utils import unique_filepath
file_path = unique_filepath(
lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py")
with file_path.open("w") as f:
print(
f'# This file was produced by VllmPatternMatcherPass.'
f'dump_patterns for {self.pass_name}.\n'
f'# It does its best to produce valid-Python-looking code but'
f' please add to dump_patterns if there are any errors.\n\n'
f'from torch._higher_order_ops.auto_functionalize import '
f'auto_functionalized as auto_functionalized\n'
f'from torch._inductor.pattern_matcher import *',
file=f)
for node, patterns in pm_pass.patterns.items():
# fix the operator.getitem repr
if node[1] == operator.getitem:
node_repr = f"({repr(node[0])}, operator.getitem)"
else:
node_repr = repr(node)
node_repr = self._replace_op_overloads(node_repr)
print(f"\n\n# Patterns for op: {node_repr}", file=f)
for i, pattern in enumerate(patterns):
# reserve auto_functionalized ahead of time
pp = PatternPrettyPrinter()
pp.namespace.create_name("auto_functionalized", None)
# Assemble pattern
out_node = pp.pretty_print(pattern.pattern)
pattern_repr = "\n".join([f"def pattern_{i}():"] + [
f"{pp.memoized_objs_names[key]} = "
f"{pp.memoized_objs_pp[key]}"
for key in pp.memoized_objs_names
] + [f"return {out_node}"]).replace("\n", "\n ")
pattern_repr = self._replace_op_overloads(pattern_repr)
print(f"{pattern_repr}\n", file=f)
class PrinterInductorPass(VllmInductorPass):
def __init__(self, name: str, config: VllmConfig):

View File

@ -503,7 +503,7 @@ class VllmConfig:
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if current_platform.is_cuda_alike() or current_platform.is_xpu():
if current_platform.support_static_graph_mode():
# if cudagraph_mode is not explicitly set by users, set default
# value
if self.compilation_config.cudagraph_mode is None:
@ -905,10 +905,9 @@ def set_current_vllm_config(vllm_config: VllmConfig,
except Exception:
raise
else:
logger.debug("enabled custom ops: %s",
vllm_config.compilation_config.enabled_custom_ops)
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if check_compile:
vllm_config.compilation_config.custom_op_log_check()
if check_compile and \
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen:

View File

@ -487,6 +487,12 @@ class CompilationConfig:
"supported with torch>=2.9.0.dev. Set "
"use_inductor_graph_partition=False instead.")
for op in self.custom_ops:
if op[0] not in {'+', '-'} and op not in {'all', 'none'}:
raise ValueError(f"Invalid syntax '{op}' for custom op, "
"must be 'all', 'none', '+op' or '-op' "
"(where 'op' is the registered op name)")
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@ -532,8 +538,8 @@ class CompilationConfig:
for x in self.compile_sizes:
if isinstance(x, str):
assert x == "cudagraph_capture_sizes", \
"Unrecognized size type in compile_sizes, " \
f"expect 'cudagraph_capture_sizes', got {x}"
"Unrecognized size type in compile_sizes, " \
f"expect 'cudagraph_capture_sizes', got {x}"
computed_compile_sizes.extend(self.cudagraph_capture_sizes)
else:
assert isinstance(x, int)
@ -628,3 +634,41 @@ class CompilationConfig:
return use_fx_graph_piecewise_compilation or \
use_inductor_piecewise_compilation
def custom_op_log_check(self):
"""
This method logs the enabled/disabled custom ops and checks that the
passed custom_ops field only contains relevant ops.
It is called at the end of set_current_vllm_config,
after the custom ops have been instantiated.
"""
if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0:
logger.debug("No custom ops found in model.")
return
logger.debug("enabled custom ops: %s", self.enabled_custom_ops)
logger.debug("disabled custom ops: %s", self.disabled_custom_ops)
all_ops_in_model = (self.enabled_custom_ops | self.disabled_custom_ops)
for op in self.custom_ops:
if op in {"all", "none"}:
continue
assert op[0] in {'+', '-'}, "Invalid custom op syntax " \
"(should be checked during init)"
# check if op name exists in model
op_name = op[1:]
if op_name not in all_ops_in_model:
from vllm.model_executor.custom_op import CustomOp
# Does op exist at all or is it just not present in this model?
# Note: Only imported op classes appear in the registry.
missing_str = "doesn't exist (or wasn't imported/registered)" \
if op_name not in CustomOp.op_registry \
else "not present in model"
enable_str = "enabling" if op[0] == '+' else "disabling"
logger.warning_once("Op '%s' %s, %s with '%s' has no effect",
op_name, missing_str, enable_str, op)
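Taken together with the '+'/'-' syntax check added in config above, custom_ops entries behave like this (illustrative values; "rms_norm" and "silu_and_mul" are registered custom-op names, anything else here is hypothetical):

    custom_ops = ["all"]                # enable every custom op
    custom_ops = ["none", "+rms_norm"]  # disable all, then re-enable rms_norm
    custom_ops = ["-silu_and_mul"]      # keep defaults, disable silu_and_mul
    custom_ops = ["rms_norm"]           # rejected at init: missing '+'/'-' prefix
    # "+some_op" for an op that never shows up in the model only triggers the
    # warning above ("has no effect"); it does not raise.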

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import os
from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
@ -351,6 +352,10 @@ class ParallelConfig:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size
if self.distributed_executor_backend == "external_launcher":
logger.info("Using external launcher for distributed inference.")
self.world_size *= self.data_parallel_size
if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError(
f"data_parallel_size_local ({self.data_parallel_size_local}) "
@ -358,6 +363,13 @@ class ParallelConfig:
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args.
if self.distributed_executor_backend == "external_launcher":
# For external launcher,
# we need to set the data parallel rank automatically
self.data_parallel_rank = int(os.environ["RANK"]) \
// (self.world_size // self.data_parallel_size)
logger.info("Set data_parallel_rank to %d automatically.",
self.data_parallel_rank)
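A worked example of the automatic rank derivation (illustrative sizes): with tensor_parallel_size=2, pipeline_parallel_size=1 and data_parallel_size=2 under torchrun --nproc-per-node=4,

    world_size = 1 * 2 * 2          # pp * tp * dp = 4
    per_dp_group = world_size // 2  # = 2 ranks per DP replica
    # RANK=0,1 -> data_parallel_rank 0 ; RANK=2,3 -> data_parallel_rank 1
    data_parallel_rank = 3 // per_dp_group  # RANK=3 -> 1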
if not self._data_parallel_master_port_list:
self._data_parallel_master_port_list = get_open_ports_list(5)
self.data_parallel_master_port = \
@ -380,7 +392,6 @@ class ParallelConfig:
"be set when data_parallel_size > 1")
if self.distributed_executor_backend == "external_launcher":
import os
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.")

View File

@ -527,7 +527,7 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")
eagle3_target_supported = ["llama", "qwen"]
eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
if self.method == "eagle3" and self.target_model_config and not any(
supported_model in
self.target_model_config.hf_text_config.model_type

View File

@ -25,6 +25,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend != "naive":
logger.warning(
"`%s` all2all manager is not supported on XPU."
"Falling back to `naive` all2all manager for XPU.",
all2all_backend)
all2all_backend = "naive"
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
@ -67,3 +73,16 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group)
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states

View File

@ -58,6 +58,12 @@ except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
try:
from nixl._api import nixl_agent_config
except ImportError:
nixl_agent_config = None
logger.warning("NIXL agent config is not available")
# Supported platforms and types of kv transfer buffer.
# {device: tuple of supported kv buffer types}
_NIXL_SUPPORTED_DEVICE = {
@ -65,6 +71,8 @@ _NIXL_SUPPORTED_DEVICE = {
"tpu": ("cpu", ),
"xpu": ("cpu", ),
}
# support for out-of-tree (oot) platforms by providing a mapping in current_platform
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
class NixlAgentMetadata(
@ -242,6 +250,10 @@ class NixlConnector(KVConnectorBase_V1):
self.connector_worker.copy_blocks:
self.connector_worker.save_kv_to_host(self._connector_metadata)
def shutdown(self):
if self.connector_worker is not None:
self.connector_worker.shutdown()
class NixlConnectorScheduler:
"""Implementation of Scheduler side methods"""
@ -448,8 +460,15 @@ class NixlConnectorWorker:
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.nixl_backends = \
vllm_config.kv_transfer_config.get_from_extra_config(
"backends", ["UCX"])
# Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
config = nixl_agent_config(backends=self.nixl_backends) if len(
non_ucx_backends) > 0 and nixl_agent_config is not None else None
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
@ -486,11 +505,15 @@ class NixlConnectorWorker:
# used when device memory can not be registered under nixl
self.host_xfer_buffers: dict[str, torch.Tensor] = {}
self.use_host_buffer = self.kv_buffer_device == "cpu"
if self.kv_buffer_device == "cuda":
self.nixl_memory_type = "VRAM"
elif self.kv_buffer_device == "cpu":
self.nixl_memory_type = "DRAM"
else:
# support for out-of-tree platforms whose nixl memory type
# can't be derived from kv_buffer_device
self.nixl_memory_type = current_platform.get_nixl_memory_type()
if self.nixl_memory_type is None:
if self.kv_buffer_device == "cuda":
self.nixl_memory_type = "VRAM"
elif self.kv_buffer_device == "cpu":
self.nixl_memory_type = "DRAM"
if self.nixl_memory_type is None:
raise RuntimeError(
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
"is not supported.")
@ -567,13 +590,6 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats()
def __del__(self):
"""Cleanup background threads on destruction."""
if executor := getattr(self, "_handshake_initiation_executor", None):
executor.shutdown(wait=False)
if listener_t := getattr(self, "_nixl_handshake_listener_t", None):
listener_t.join(timeout=0)
@staticmethod
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
ready_event: threading.Event, base_port: int,
@ -766,7 +782,7 @@ class NixlConnectorWorker:
descs = self.nixl_wrapper.get_reg_descs(caches_data,
self.nixl_memory_type)
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs)
self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends)
logger.debug("Done registering descs")
self._registered_descs.append(descs)
@ -1327,6 +1343,30 @@ class NixlConnectorWorker:
return self.xfer_stats.clone_and_reset()
return None
def shutdown(self):
"""Shutdown the connector worker."""
self._handshake_initiation_executor.shutdown(wait=False)
if self._nixl_handshake_listener_t is not None:
self._nixl_handshake_listener_t.join(timeout=0)
self._nixl_handshake_listener_t = None
for handles in self._recving_transfers.values():
for handle, _ in handles:
self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear()
if self.src_xfer_side_handle:
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
self.src_xfer_side_handle = 0
for dst_xfer_side_handle in self.dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values():
self.nixl_wrapper.remove_remote_agent(agent_name)
self._remote_agents.clear()
for desc in self._registered_descs:
self.nixl_wrapper.deregister_memory(desc)
self._registered_descs.clear()
@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:

View File

@ -178,6 +178,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
# Load the KV for each request each layer
for request in metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank)
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
@ -191,7 +194,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
layer = kv_cache[forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name)
request.request_id + "#" + layer_name, remote_address)
if kv_cache is None:
logger.warning("🚧kv_cache is None, %s", request.request_id)

View File

@ -134,7 +134,6 @@ class P2pNcclEngine:
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
self.send_queue: deque[SendQueueItem] = deque()
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self.send_async,
daemon=True)
@ -143,6 +142,7 @@ class P2pNcclEngine:
# tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {}
self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.socks: dict[str, Any] = {} # remote_address: client socket
self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank)
@ -223,18 +223,26 @@ class P2pNcclEngine:
# GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
if tensor_size > self.buffer_size_threshold:
logger.warning(
"❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
"buffer size threshold :%d, skip send to %s, rank:%d",
tensor_id, tensor_size, self.buffer_size_threshold,
remote_address, self.rank)
return False
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store))
oldest_tenser = self.send_store.pop(oldest_tenser_id)
oldest_tenser_size = oldest_tenser.element_size(
) * oldest_tenser.numel()
self.buffer_size -= oldest_tenser_size
logger.info(
assert len(self.send_store) > 0
oldest_tensor_id = next(iter(self.send_store))
oldest_tensor = self.send_store.pop(oldest_tensor_id)
oldest_tensor_size = oldest_tensor.element_size(
) * oldest_tensor.numel()
self.buffer_size -= oldest_tensor_size
logger.debug(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
" buffer_size:%d, oldest_tensor_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, self.buffer_size,
oldest_tenser_size, self.rank)
oldest_tensor_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size

View File

@ -1032,7 +1032,9 @@ def init_distributed_environment(world_size: int = -1,
distributed_init_method, backend)
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None and config.parallel_config.data_parallel_size > 1:
if config is not None and config.parallel_config.data_parallel_size > 1 \
and config.parallel_config.distributed_executor_backend \
!= "external_launcher":
parallel_config = config.parallel_config
# adjust to take into account data parallelism
# offset the rank by the data parallel rank

View File

@ -1147,20 +1147,15 @@ class EngineArgs:
else:
envs.set_vllm_use_v1(use_v1)
# Set default arguments for V0 or V1 Engine.
if use_v1:
self._set_default_args_v1(usage_context, model_config)
# Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1
if current_platform.is_cpu(
) and current_platform.get_cpu_architecture() in (
CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM):
logger.info(
"Chunked prefill is not supported for ARM and POWER "
"and S390X CPUs; "
"disabling it for V1 backend.")
self.enable_chunked_prefill = False
else:
self._set_default_args_v0(model_config)
# Set default arguments for V1 Engine.
self._set_default_args(usage_context, model_config)
# Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1
if current_platform.is_cpu() and current_platform.get_cpu_architecture(
) in (CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM):
logger.info("Chunked prefill is not supported for ARM and POWER "
"and S390X CPUs; "
"disabling it for V1 backend.")
self.enable_chunked_prefill = False
assert self.enable_chunked_prefill is not None
sliding_window: Optional[int] = None
@ -1494,6 +1489,7 @@ class EngineArgs:
"FLEX_ATTENTION",
"TREE_ATTN",
"XFORMERS_VLLM_V1",
"ROCM_ATTN_VLLM_V1",
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
@ -1501,12 +1497,6 @@ class EngineArgs:
_raise_or_fallback(feature_name=name, recommend_to_remove=True)
return False
# Platforms must decide if they can support v1 for this model
if not current_platform.supports_v1(model_config=model_config):
_raise_or_fallback(
feature_name=f"device type={current_platform.device_type}",
recommend_to_remove=False)
return False
#############################################################
# Experimental Features - allow users to opt in.
@ -1523,12 +1513,6 @@ class EngineArgs:
recommend_to_remove=False)
return False
# The platform may be supported on V1, but off by default for now.
if not current_platform.default_v1( # noqa: SIM103
model_config=model_config) and _warn_or_fallback(
current_platform.device_name):
return False
if (current_platform.is_cpu()
and model_config.get_sliding_window() is not None):
_raise_or_fallback(feature_name="sliding window (CPU backend)",
@ -1539,64 +1523,8 @@ class EngineArgs:
return True
def _set_default_args_v0(self, model_config: ModelConfig) -> None:
"""Set Default Arguments for V0 Engine."""
max_model_len = model_config.max_model_len
use_long_context = max_model_len > 32768
if self.enable_chunked_prefill is None:
# Chunked prefill not supported for Multimodal or MLA in V0.
if model_config.is_multimodal_model or model_config.use_mla:
self.enable_chunked_prefill = False
# Enable chunked prefill by default for long context (> 32K)
# models to avoid OOM errors in initial memory profiling phase.
elif use_long_context:
is_gpu = current_platform.is_cuda()
use_sliding_window = (model_config.get_sliding_window()
is not None)
use_spec_decode = self.speculative_config is not None
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora):
self.enable_chunked_prefill = True
logger.warning(
"Chunked prefill is enabled by default for models "
"with max_model_len > 32K. Chunked prefill might "
"not work with some features or models. If you "
"encounter any issues, please disable by launching "
"with --enable-chunked-prefill=False.")
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = False
if not self.enable_chunked_prefill and use_long_context:
logger.warning(
"The model has a long context length (%s). This may cause"
"OOM during the initial memory profiling phase, or result "
"in low performance due to small KV cache size. Consider "
"setting --max-model-len to a smaller value.", max_model_len)
# Disable prefix caching for multimodal models for VLLM_V0.
if self.enable_prefix_caching and model_config.is_multimodal_model:
logger.warning(
"--enable-prefix-caching is not supported for multimodal "
"models in V0 and has been disabled.")
self.enable_prefix_caching = False
if self.enable_prompt_embeds:
logger.warning(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V0. Prefix caching has "
"been disabled.")
self.enable_prefix_caching = False
# Set max_num_seqs to 256 for VLLM_V0.
if self.max_num_seqs is None:
self.max_num_seqs = 256
def _set_default_args_v1(self, usage_context: UsageContext,
model_config: ModelConfig) -> None:
def _set_default_args(self, usage_context: UsageContext,
model_config: ModelConfig) -> None:
"""Set Default Arguments for V1 Engine."""
# V1 always uses chunked prefills and prefix caching
@ -1795,21 +1723,6 @@ def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
logger.warning(msg)
def _warn_or_fallback(feature_name: str) -> bool:
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
logger.warning(
"Detected VLLM_USE_V1=1 with %s. Usage should "
"be considered experimental. Please report any "
"issues on Github.", feature_name)
should_exit = False
else:
logger.info(
"%s is experimental on VLLM_USE_V1=1. "
"Falling back to V0 Engine.", feature_name)
should_exit = True
return should_exit
def human_readable_int(value):
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.

View File

@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm.entrypoints.harmony_utils import (
@ -21,6 +22,24 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# This is currently needed as the tool type doesn't 1:1 match the
# tool namespace, which is what is used to look up the
# connection to the tool server
_TOOL_NAME_TO_TYPE_MAP = {
"browser": "web_search_preview",
"python": "code_interpreter",
"container": "container",
}
def _map_tool_name_to_tool_type(tool_name: str) -> str:
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys())
raise ValueError(
f"Built-in tool name '{tool_name}' not defined in mapping. "
f"Available tools: {available_tools}")
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
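For example (illustrative calls):

    _map_tool_name_to_tool_type("browser")  # -> "web_search_preview"
    _map_tool_name_to_tool_type("python")   # -> "code_interpreter"
    _map_tool_name_to_tool_type("sql")      # raises ValueError listing the available tools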
class TurnTokens:
"""Tracks token counts for a single conversation turn."""
@ -59,8 +78,8 @@ class ConversationContext(ABC):
@abstractmethod
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack,
request_id: str) -> None:
exit_stack: AsyncExitStack, request_id: str,
mcp_tools: dict[str, Mcp]) -> None:
pass
@abstractmethod
@ -96,8 +115,8 @@ class SimpleContext(ConversationContext):
raise NotImplementedError("Should not be called.")
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack,
request_id: str) -> None:
exit_stack: AsyncExitStack, request_id: str,
mcp_tools: dict[str, Mcp]) -> None:
pass
async def cleanup_session(self) -> None:
@ -318,13 +337,17 @@ class HarmonyContext(ConversationContext):
]
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack,
request_id: str) -> None:
exit_stack: AsyncExitStack, request_id: str,
mcp_tools: dict[str, Mcp]):
if tool_server:
for tool_name in self.available_tools:
if tool_name not in self._tool_sessions:
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = mcp_tools[
tool_type].headers if tool_type in mcp_tools else None
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id))
tool_server.new_session(tool_name, request_id,
headers))
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)

View File

@ -126,8 +126,10 @@ def get_developer_message(
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
for tool in tools:
if tool.type in ("web_search_preview", "code_interpreter",
"container"):
"container", "mcp"):
# These are built-in tools that are added to the system message.
# Adding in MCP for now until we support MCP tools executed
# server side
pass
elif tool.type == "function":

View File

@ -1468,7 +1468,7 @@ class LLM:
def _validate_and_add_requests(
self,
prompts: Union[PromptType, Sequence[PromptType]],
prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
*,
@ -1478,7 +1478,7 @@ class LLM:
) -> None:
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts]
prompts = [prompts] # type: ignore[list-item]
num_requests = len(prompts)
if isinstance(params, Sequence) and len(params) != num_requests:

View File

@ -460,8 +460,12 @@ class OpenAIServingResponses(OpenAIServing):
async with AsyncExitStack() as exit_stack:
try:
mcp_tools = {
tool.server_label: tool
for tool in request.tools if tool.type == "mcp"
}
await context.init_tool_sessions(self.tool_server, exit_stack,
request.request_id)
request.request_id, mcp_tools)
async for _ in result_generator:
pass
except asyncio.CancelledError:
@ -748,11 +752,16 @@ class OpenAIServingResponses(OpenAIServing):
# New conversation.
reasoning_effort = (request.reasoning.effort
if request.reasoning else None)
# Temporary: OpenAI types don't have a container tool,
# so we use MCP to cover that; subject to change.
tool_types = [tool.type for tool in request.tools]
if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
tool_types.append("container")
# Allow the MCP Tool type to enable built-in tools if the
# server_label is allowlisted in
# envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS
if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
for tool in request.tools:
if (tool.type == "mcp" and tool.server_label
in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS):
tool_types.append(tool.server_label)
enable_browser = ("web_search_preview" in tool_types
and self.tool_server is not None
and self.tool_server.has_tool("browser"))
@ -1653,8 +1662,12 @@ class OpenAIServingResponses(OpenAIServing):
async with AsyncExitStack() as exit_stack:
processer = None
if self.use_harmony:
mcp_tools = {
tool.server_label: tool
for tool in request.tools if tool.type == "mcp"
}
await context.init_tool_sessions(self.tool_server, exit_stack,
request.request_id)
request.request_id, mcp_tools)
processer = self._process_harmony_streaming_events
else:
processer = self._process_simple_streaming_events

View File

@ -20,6 +20,7 @@ from .openai_tool_parser import OpenAIToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser
from .qwen3xml_tool_parser import Qwen3XMLToolParser
from .seed_oss_tool_parser import SeedOssToolParser
from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser
@ -45,6 +46,7 @@ __all__ = [
"HunyuanA13BToolParser",
"Glm4MoeModelToolParser",
"Qwen3CoderToolParser",
"Qwen3XMLToolParser",
"SeedOssToolParser",
"Step3ToolParser",
"OpenAIToolParser",

View File

@ -368,16 +368,32 @@ class Hermes2ProToolParser(ToolParser):
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
# extract the content after {"name": ..., "arguments":
# directly from tool_call_portion as cur_arguments_json,
# since cur_arguments may differ from the original text
# due to partial JSON parsing
# for example, tool_call_portion =
# {"name": "search", "arguments": {"search_request": {"
# but cur_arguments =
# {"search_request": {}}
function_name = current_tool_call.get("name")
match = re.search(
r'\{"name":\s*"' +
re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
tool_call_portion.strip(), re.DOTALL)
if match:
cur_arguments_json = match.group(1)
else:
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
logger.debug("finding %s in %s", delta_text,
cur_arguments_json)
# get the location where previous args differ from current
if (delta_text not in cur_arguments_json[:-2]):
# get the location where previous args differ from current.
if (delta_text not in cur_arguments_json):
return None
args_delta_start_loc = cur_arguments_json[:-2]. \
args_delta_start_loc = cur_arguments_json. \
rindex(delta_text) + \
len(delta_text)
@ -397,8 +413,20 @@ class Hermes2ProToolParser(ToolParser):
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
if isinstance(delta_text, str) and len(delta_text.rstrip(
)) >= 1 and delta_text.rstrip()[-1] == '}':
# judge whether the tool_call_portion is a complete JSON
try:
json.loads(tool_call_portion)
is_complete_json = True
except Exception:
is_complete_json = False
# if the delta_text ends with a '}' and tool_call_portion is a
# complete JSON, then the last '}' does not belong to the
# arguments, so we should trim it off
if isinstance(delta_text, str) \
and len(delta_text.rstrip()) >= 1 \
and delta_text.rstrip()[-1] == '}' \
and is_complete_json:
delta_text = delta_text.rstrip()[:-1]
logger.debug("got diff %s", delta_text)

File diff suppressed because it is too large Load Diff

View File

@ -280,7 +280,7 @@ class CompletionRenderer(BaseRenderer):
if truncate_prompt_tokens < 0:
truncate_prompt_tokens = self.model_config.max_model_len
if max_length is not None and truncate_prompt_tokens > max_length:
if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator]
raise ValueError(
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
f"cannot be greater than max_length ({max_length}). "

View File

@ -18,7 +18,6 @@ if TYPE_CHECKING:
async def list_server_and_tools(server_url: str):
from mcp import ClientSession
from mcp.client.sse import sse_client
async with sse_client(url=server_url) as streams, ClientSession(
*streams) as session:
initialize_response = await session.initialize()
@ -86,8 +85,12 @@ class ToolServer(ABC):
pass
@abstractmethod
def new_session(self, tool_name: str,
session_id: str) -> AbstractAsyncContextManager[Any]:
def new_session(
self,
tool_name: str,
session_id: str,
headers: Optional[dict[str, str]] = None
) -> AbstractAsyncContextManager[Any]:
"""
Create a session for the tool.
"""
@ -144,16 +147,21 @@ class MCPToolServer(ToolServer):
return self.harmony_tool_descriptions.get(tool_name)
@asynccontextmanager
async def new_session(self, tool_name: str, session_id: str):
async def new_session(self,
tool_name: str,
session_id: str,
headers: Optional[dict[str, str]] = None):
from mcp import ClientSession
from mcp.client.sse import sse_client
url = self.urls.get(tool_name)
headers = {"x-session-id": session_id}
request_headers = {"x-session-id": session_id}
if headers is not None:
request_headers.update(headers)
if not url:
raise KeyError(f"Tool '{tool_name}' is not supported")
async with sse_client(url=url,
headers=headers) as streams, ClientSession(
*streams) as session:
async with sse_client(
url=url, headers=request_headers) as streams, ClientSession(
*streams) as session:
await session.initialize()
yield session
@ -189,7 +197,10 @@ class DemoToolServer(ToolServer):
raise ValueError(f"Unknown tool {tool_name}")
@asynccontextmanager
async def new_session(self, tool_name: str, session_id: str):
async def new_session(self,
tool_name: str,
session_id: str,
headers: Optional[dict[str, str]] = None):
if tool_name not in self.tools:
raise KeyError(f"Tool '{tool_name}' is not supported")
yield self.tools[tool_name]

View File

@ -119,12 +119,14 @@ if TYPE_CHECKING:
VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
VLLM_MLA_DISABLE: bool = False
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_DP_RANK: int = 0
VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1
VLLM_USE_STANDALONE_COMPILE: bool = False
VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0
VLLM_MOE_DP_CHUNK_SIZE: int = 256
@ -183,11 +185,12 @@ if TYPE_CHECKING:
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
def get_default_cache_root():
@ -258,6 +261,58 @@ def env_with_choices(
return _get_validated_env
def env_list_with_choices(
env_name: str,
default: list[str],
choices: Union[list[str], Callable[[], list[str]]],
case_sensitive: bool = True) -> Callable[[], list[str]]:
"""
Create a lambda that validates an environment variable
containing comma-separated values against allowed choices.
Args:
env_name: Name of the environment variable
default: Default list of values if not set
choices: List of valid string options or callable that returns list
case_sensitive: Whether validation should be case sensitive
Returns:
Lambda function for the environment_variables dict
that returns a list of strings
"""
def _get_validated_env_list() -> list[str]:
value = os.getenv(env_name)
if value is None:
return default
# Split comma-separated values and strip whitespace
values = [v.strip() for v in value.split(",") if v.strip()]
if not values:
return default
# Resolve choices if it's a callable (for lazy loading)
actual_choices = choices() if callable(choices) else choices
# Validate each value
for val in values:
if not case_sensitive:
check_value = val.lower()
check_choices = [choice.lower() for choice in actual_choices]
else:
check_value = val
check_choices = actual_choices
if check_value not in check_choices:
raise ValueError(f"Invalid value '{val}' in {env_name}. "
f"Valid options: {actual_choices}.")
return values
return _get_validated_env_list
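A minimal usage sketch (the variable name MY_TOOLS is hypothetical; the real consumer is GPT_OSS_SYSTEM_TOOL_MCP_LABELS below):

    getter = env_list_with_choices("MY_TOOLS", [], ["container", "code_interpreter"])
    # MY_TOOLS unset             -> []
    # MY_TOOLS="container"       -> ["container"]
    # MY_TOOLS="container,bogus" -> ValueError: Invalid value 'bogus' in MY_TOOLS...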
def get_vllm_port() -> Optional[int]:
"""Get the port from VLLM_PORT environment variable.
@ -436,9 +491,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Feature flag to enable/disable Inductor standalone compile.
# In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is
# enabled by default.
# disabled by default.
"VLLM_USE_STANDALONE_COMPILE":
lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1",
lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "0") == "1",
# Debug pattern matching inside custom passes.
# Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3').
"VLLM_PATTERN_MATCH_DEBUG":
lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
@ -946,6 +1006,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MLA_DISABLE":
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
# If set, vLLM will pick up the provided Flash Attention MLA
# max number splits for cuda graph decode
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
"16")),
# Number of GPUs per worker in Ray, if it is set to be a fraction,
# it allows ray to schedule multiple actors on a single GPU,
# so that users can colocate other actors on the same GPUs as vLLM.
@ -1306,10 +1372,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TUNED_CONFIG_FOLDER":
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
# Allows vllm use container tool
"VLLM_GPT_OSS_USE_CONTAINER_TOOL":
lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))),
# Allows harmony instructions to be injected on system messages
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
lambda: bool(
@ -1329,6 +1391,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME":
lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
"VLLM_OBJECT_STORAGE_SHM_BUFFER"),
# Valid values are container, code_interpreter, web_search_preview,
# e.g. GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
"GPT_OSS_SYSTEM_TOOL_MCP_LABELS":
env_list_with_choices("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", [],
["container",
"code_interpreter",
"web_search_preview"]),
}
# --8<-- [end:env-vars-definition]
@ -1379,6 +1449,7 @@ def compute_hash() -> str:
environment_variables_to_hash = [
"VLLM_PP_LAYER_PARTITION",
"VLLM_MLA_DISABLE",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
"VLLM_USE_TRITON_FLASH_ATTN",
"VLLM_USE_TRITON_AWQ",
"VLLM_DP_RANK",

View File

@ -121,18 +121,18 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
lora_bias = self.slice_bias(lora_bias)
self.lora_a_stacked[0][index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[0][index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked)
assert len(self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
lora_bias.T, non_blocking=True)
lora_bias, non_blocking=True)
def apply(self,
x: torch.Tensor,

View File

@ -99,13 +99,13 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
if self.is_merged_col_linear:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size // 2
offset = lora_b.shape[-1] // 2
offset = lora_b.shape[0] // 2
left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
shard_size]
right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
(tp_rank + 1) * shard_size]
lora_b = torch.cat([left_weight, right_weight], dim=1)
left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
shard_size, :]
right_weight = lora_b[offset + tp_rank * shard_size:offset +
(tp_rank + 1) * shard_size, :]
lora_b = torch.cat([left_weight, right_weight], dim=0)
# Applicable to cases where the base_layer is
# ColumnParallelLinear.
else:
@ -113,7 +113,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
lora_b = lora_b[start_idx:end_idx, :]
return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
@ -251,9 +251,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)):
if (lora_b_i := lora_b[i]) is not None:
sliced_lora_b[i] = lora_b_i[:,
shard_size * shard_id:shard_size *
(shard_id + 1)]
sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size *
(shard_id + 1), :]
return sliced_lora_b
def slice_bias(
@ -285,12 +284,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
for i in range(self.n_slices):
if (lora_a_i := lora_a[i]) is not None:
self.lora_a_stacked[i][
index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
lora_a_i.T, non_blocking=True)
index, 0, :lora_a_i.shape[0], :lora_a_i.shape[1]].copy_(
lora_a_i, non_blocking=True)
if (lora_b_i := lora_b[i]) is not None:
self.lora_b_stacked[i][
index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
lora_b_i.T, non_blocking=True)
index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_(
lora_b_i, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
@ -299,7 +298,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
if (lora_bias_i := lora_bias[i]) is not None:
self.lora_bias_stacked[i][index,
0, :lora_bias_i.shape[0]].copy_(
lora_bias_i.T,
lora_bias_i,
non_blocking=True)
@classmethod
@ -345,18 +344,18 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
lora_b_q = lora_b[self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
(self.q_shard_id + 1), :]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset +
lora_b_k = lora_b[k_offset +
self.kv_proj_shard_size * self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset +
lora_b_v = lora_b[v_offset +
self.kv_proj_shard_size * self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
@ -465,7 +464,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a
def apply(self,
@ -508,10 +507,10 @@ class MergedColumnParallelLinearWithShardedLoRA(
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[0][:, output_start_idx:output_start_idx +
output_shard_size] if lora_a[0] is not None else None,
lora_a[1][:, output_start_idx:output_start_idx +
output_shard_size] if lora_a[1] is not None else None,
lora_a[0][output_start_idx:output_start_idx +
output_shard_size, :] if lora_a[0] is not None else None,
lora_a[1][output_start_idx:output_start_idx +
output_shard_size, :] if lora_a[1] is not None else None,
]
return lora_a
@ -551,7 +550,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a
def apply(self,
@ -589,12 +588,12 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[0][:, start_idx[0]:start_idx[0] +
shard_size[0]] if lora_a[0] is not None else None,
lora_a[1][:, start_idx[1]:start_idx[1] +
shard_size[1]] if lora_a[1] is not None else None,
lora_a[2][:, start_idx[2]:start_idx[2] +
shard_size[2]] if lora_a[2] is not None else None,
lora_a[0][start_idx[0]:start_idx[0] +
shard_size[0], :] if lora_a[0] is not None else None,
lora_a[1][start_idx[1]:start_idx[1] +
shard_size[1], :] if lora_a[1] is not None else None,
lora_a[2][start_idx[2]:start_idx[2] +
shard_size[2], :] if lora_a[2] is not None else None,
]
return lora_a

View File

@ -140,11 +140,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,

View File

@ -39,7 +39,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
shard_size = self.input_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
lora_a = lora_a[:, start_idx:end_idx]
return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
@ -122,7 +122,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
shard_size = self.lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
lora_b = lora_b[start_idx:end_idx, :]
return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:

View File

@ -95,11 +95,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
# NOTE: self.lora_a_stacked is row-major and lora_a is col-major,
# so we need to transpose here
self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,

View File

@ -86,11 +86,11 @@ class LoRALayerWeights:
embeddings_tensor_dim: Optional[int] = None,
bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank],
lora_a = torch.zeros([rank, input_dim],
dtype=dtype,
device=device,
pin_memory=pin_memory)
lora_b = torch.zeros([rank, output_dim],
lora_b = torch.zeros([output_dim, rank],
dtype=dtype,
device=device,
pin_memory=pin_memory)
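A minimal sketch of the weight layout implied by this change (shapes only; sizes are illustrative):

    import torch
    rank, input_dim, output_dim = 8, 1024, 4096  # hypothetical sizes
    lora_a = torch.zeros(rank, input_dim)         # previously [input_dim, rank]
    lora_b = torch.zeros(output_dim, rank)        # previously [rank, output_dim]
    delta_w = lora_b @ lora_a                     # [output_dim, input_dim], i.e. W + B @ A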

View File

@ -152,30 +152,29 @@ class LoRAModel:
module_name, peft_helper, lora_embeddings_tensor)
if is_bias:
loras[module_name].bias = tensor.to(device=device,
dtype=dtype).t()
bias = tensor.to(device=device, dtype=dtype).t()
loras[module_name].bias = tensor.to(device=device, dtype=dtype)
bias = tensor.to(device=device, dtype=dtype)
if pin_memory:
bias = bias.pin_memory()
loras[module_name].bias = bias
elif is_lora_a:
loras[module_name].lora_a = tensor.to(device=device,
dtype=dtype).t()
dtype=dtype)
if pin_memory:
loras[module_name].lora_a = loras[
module_name].lora_a.pin_memory()
else:
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
dtype=dtype)
assert embedding_padding_modules is not None
if any(name in module_name
for name in embedding_padding_modules
) and target_embedding_padding is not None:
lora_b = loras[module_name].lora_b
assert target_embedding_padding >= lora_b.shape[1]
addition = target_embedding_padding - lora_b.shape[1]
assert target_embedding_padding >= lora_b.shape[0]
addition = target_embedding_padding - lora_b.shape[0]
loras[module_name].lora_b = torch.nn.functional.pad(
lora_b, (0, addition))
lora_b, (0, 0, 0, addition))
if pin_memory:
loras[module_name].lora_b = loras[
module_name].lora_b.pin_memory()
@ -585,7 +584,6 @@ class LoRAModelManager:
"cpu",
bias_enabled=bias_enabled,
)
lora.optimize()
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
@ -600,7 +598,6 @@ class LoRAModelManager:
"cpu",
bias_enabled=bias_enabled,
)
lora.optimize()
subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora

Some files were not shown because too many files have changed in this diff Show More