diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh index 1073a4ee30afa..e76528a178205 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh @@ -62,7 +62,7 @@ echo "--- Installing Python dependencies ---" python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ - && python3 -m pip install --progress-bar off hf-transfer + && python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0 echo "--- Python dependencies installed ---" export VLLM_USE_V1=1 export VLLM_XLA_CHECK_RECOMPILATION=1 diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 505664f3aecd0..69366cd503219 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -62,7 +62,7 @@ echo "--- Installing Python dependencies ---" python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ - && python3 -m pip install --progress-bar off hf-transfer + && python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0 echo "--- Python dependencies installed ---" export VLLM_USE_V1=1 export VLLM_XLA_CHECK_RECOMPILATION=1 diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index fe4796b35786c..c4ea4b675649c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -165,10 +165,18 @@ steps: - tests/v1/test_hybrid_lb_dp.py - tests/v1/engine/test_engine_core_client.py commands: - # test with tp=2 and external_dp=2 + # test with torchrun tp=2 and external_dp=2 - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - # test with tp=2 and pp=2 + # test with torchrun tp=2 and pp=2 - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with torchrun tp=4 and dp=1 + - TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=2, pp=2 and dp=1 + - PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=1 and dp=4 with ep + - DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=2 and dp=2 with ep + - TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py # test with internal dp - python3 ../examples/offline_inference/data_parallel.py --enforce-eager - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 37bd0ace98a97..9d749fe8d3238 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -72,6 +72,7 @@ mkdocs.yaml @hmellor # Linting .markdownlint.yaml @hmellor .pre-commit-config.yaml @hmellor +/tools/pre_commit @hmellor # CPU /vllm/v1/worker/cpu* @bigPYJ1151 diff --git a/.github/ISSUE_TEMPLATE/750-RFC.yml b/.github/ISSUE_TEMPLATE/750-RFC.yml index 7ee57c42895ca..c0e009855964a 100644 --- 
a/.github/ISSUE_TEMPLATE/750-RFC.yml +++ b/.github/ISSUE_TEMPLATE/750-RFC.yml @@ -43,10 +43,6 @@ body: Any other things you would like to mention. validations: required: false -- type: markdown - attributes: - value: > - Thanks for contributing 🎉! The vLLM core team hosts a biweekly RFC review session at 9:30AM Pacific Time, while most RFCs can be discussed online, you can optionally sign up for a slot to discuss your RFC online [here](https://docs.google.com/document/d/1CiLVBZeIVfR7_PNAKVSusxpceywkoOOB78qoWqHvSZc/edit). - type: checkboxes id: askllm attributes: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a4ea888af3f3e..8ca414ee4269b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -60,38 +60,32 @@ repos: files: ^requirements/test\.(in|txt)$ - id: mypy-local name: Run mypy for local Python installation - entry: tools/mypy.sh 0 "local" - language: python - types: [python] - additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic] + entry: python tools/pre_commit/mypy.py 0 "local" stages: [pre-commit] # Don't run in CI + <<: &mypy_common + language: python + types_or: [python, pyi] + require_serial: true + additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.9 - entry: tools/mypy.sh 1 "3.9" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.9" + <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.10 - entry: tools/mypy.sh 1 "3.10" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.10" + <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.11 - entry: tools/mypy.sh 1 "3.11" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.11" + <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.12 - entry: tools/mypy.sh 1 "3.12" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.12" + <<: *mypy_common stages: [manual] # Only run in CI - id: shellcheck name: Lint shell scripts @@ -155,11 +149,10 @@ repos: additional_dependencies: [regex] - id: check-pickle-imports name: Prevent new pickle/cloudpickle imports - entry: python tools/check_pickle_imports.py + entry: python tools/pre_commit/check_pickle_imports.py language: python types: [python] - pass_filenames: false - additional_dependencies: [pathspec, regex] + additional_dependencies: [regex] - id: validate-config name: Validate configuration has default values and that each field has a docstring entry: python tools/validate_config.py diff --git a/docs/contributing/benchmarks.md b/docs/contributing/benchmarks.md index 2a03ce1dffd63..a97d1fa6a3a55 100644 --- a/docs/contributing/benchmarks.md +++ b/docs/contributing/benchmarks.md @@ -680,7 +680,7 @@ vllm bench serve \ 
    --save-result \
    --result-dir ~/vllm_benchmark_results \
    --save-detailed \
-    --endpoint /v1/chat/completion
+    --endpoint /v1/chat/completions
```

##### Videos (ShareGPT4Video)
@@ -707,7 +707,7 @@ vllm bench serve \
    --save-result \
    --result-dir ~/vllm_benchmark_results \
    --save-detailed \
-    --endpoint /v1/chat/completion
+    --endpoint /v1/chat/completions
```

##### Synthetic Random Images (random-mm)
diff --git a/docs/features/disagg_prefill.md b/docs/features/disagg_prefill.md
index 996ef00a6b960..2c69304db3393 100644
--- a/docs/features/disagg_prefill.md
+++ b/docs/features/disagg_prefill.md
@@ -23,7 +23,7 @@ Now supports 5 types of connectors:
 - **SharedStorageConnector**: refer to for the example usage of SharedStorageConnector disaggregated prefilling.
 - **LMCacheConnectorV1**: refer to for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
-- **NixlConnector**: refer to for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv.
+- **NixlConnector**: refer to for the example usage of NixlConnector disaggregated prefilling, which supports fully async send/recv. For a detailed usage guide, see the [NixlConnector Usage Guide](nixl_connector_usage.md).
 - **P2pNcclConnector**: refer to for the example usage of P2pNcclConnector disaggregated prefilling.
 - **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as:
@@ -31,6 +31,18 @@ Now supports 5 types of connectors:
 --kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
 ```
 
+For NixlConnector, you may also specify one or more NIXL backends, for example:
+
+  ```bash
+  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}}'
+  ```
+
+- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and the number of blocks to allocate (per worker):
+
+  ```bash
+  --kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 64, "num_cpu_blocks": 1000}}'
+  ```
+
 ## Benchmarks
 
 Please refer to for disaggregated prefilling benchmarks.
diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md
new file mode 100644
index 0000000000000..de50f091df428
--- /dev/null
+++ b/docs/features/nixl_connector_usage.md
@@ -0,0 +1,159 @@
+# NixlConnector Usage Guide
+
+NixlConnector is a high-performance KV cache transfer connector for vLLM's disaggregated prefilling feature. It provides fully asynchronous send/receive operations using the NIXL library for efficient cross-process KV cache transfer.
+
+## Prerequisites
+
+### Installation
+
+As a quick start, install the NIXL library with `uv pip install nixl`.
+
+- Refer to the [NIXL official repository](https://github.com/ai-dynamo/nixl) for more installation instructions
+- The required NIXL version can be found in [requirements/kv_connectors.txt](../../requirements/kv_connectors.txt) and other relevant config files
+
+### Transport Configuration
+
+NixlConnector uses the NIXL library for the underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport used by NIXL. Configure transport environment variables:
+
+```bash
+# Example UCX configuration, adjust according to your environment
+export UCX_TLS=all  # or specify particular transports like "rc,ud,sm,^cuda_ipc", etc.
+export UCX_NET_DEVICES=all  # or specify network devices like "mlx5_0:1,mlx5_1:1"
+```
+
+!!! tip
+    When using UCX as the transport backend, NCCL environment variables (like `NCCL_IB_HCA`, `NCCL_SOCKET_IFNAME`) are not applicable to NixlConnector, so configure UCX-specific environment variables instead of NCCL ones.
+
+## Basic Usage (on the same host)
+
+### Producer (Prefiller) Configuration
+
+Start a prefiller instance that produces KV caches:
+
+```bash
+# 1st GPU as prefiller
+CUDA_VISIBLE_DEVICES=0 \
+UCX_NET_DEVICES=all \
+VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
+vllm serve Qwen/Qwen3-0.6B \
+  --port 8100 \
+  --enforce-eager \
+  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
+```
+
+### Consumer (Decoder) Configuration
+
+Start a decoder instance that consumes KV caches:
+
+```bash
+# 2nd GPU as decoder
+CUDA_VISIBLE_DEVICES=1 \
+UCX_NET_DEVICES=all \
+VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \
+vllm serve Qwen/Qwen3-0.6B \
+  --port 8200 \
+  --enforce-eager \
+  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
+```
+
+### Proxy Server
+
+Use a proxy server to route requests between the prefiller and the decoder:
+
+```bash
+python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
+  --port 8192 \
+  --prefiller-hosts localhost \
+  --prefiller-ports 8100 \
+  --decoder-hosts localhost \
+  --decoder-ports 8200
+```
+
+## Environment Variables
+
+- `VLLM_NIXL_SIDE_CHANNEL_PORT`: Port for NIXL handshake communication
+  - Default: 5600
+  - **Required for both prefiller and decoder instances**
+  - Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
+  - For TP/DP deployments, each worker's port on a node is computed as base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node)
+  - Used for the initial NIXL handshake between the prefiller and the decoder
+
+- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
+  - Default: "localhost"
+  - Set when prefiller and decoder are on different machines
+  - Connection info is passed via KVTransferParams from the prefiller to the decoder for the handshake
+
+- `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller's KV cache for a particular request (optional)
+  - Default: 120
+  - If a request is aborted and the decoder has not yet read the KV-cache blocks through the NIXL channel, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
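To make the port layout concrete, here is a small sketch (plain Python, assuming one worker process per rank, all colocated on one node) that enumerates the side-channel ports implied by the formula above:

```python
# Sketch: ports implied by base_port + dp_rank * tp_size + tp_rank
# (assumes one worker process per rank on a single node).
base_port = 5600          # VLLM_NIXL_SIDE_CHANNEL_PORT
tp_size, dp_size = 4, 2   # hypothetical deployment sizes

for dp_rank in range(dp_size):
    for tp_rank in range(tp_size):
        port = base_port + dp_rank * tp_size + tp_rank
        print(f"dp_rank={dp_rank} tp_rank={tp_rank} -> side-channel port {port}")
```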
+ +## Multi-Instance Setup + +### Multiple Prefiller Instances on Different Machines + +```bash +# Prefiller 1 on Machine A (example IP: ${IP1}) +VLLM_NIXL_SIDE_CHANNEL_HOST=${IP1} \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \ +UCX_NET_DEVICES=all \ +vllm serve Qwen/Qwen3-0.6B --port 8000 \ + --tensor-parallel-size 8 \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}' + +# Prefiller 2 on Machine B (example IP: ${IP2}) +VLLM_NIXL_SIDE_CHANNEL_HOST=${IP2} \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \ +UCX_NET_DEVICES=all \ +vllm serve Qwen/Qwen3-0.6B --port 8000 \ + --tensor-parallel-size 8 \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}' +``` + +### Multiple Decoder Instances on Different Machines + +```bash +# Decoder 1 on Machine C (example IP: ${IP3}) +VLLM_NIXL_SIDE_CHANNEL_HOST=${IP3} \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \ +UCX_NET_DEVICES=all \ +vllm serve Qwen/Qwen3-0.6B --port 8000 \ + --tensor-parallel-size 8 \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}' + +# Decoder 2 on Machine D (example IP: ${IP4}) +VLLM_NIXL_SIDE_CHANNEL_HOST=${IP4} \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \ +UCX_NET_DEVICES=all \ +vllm serve Qwen/Qwen3-0.6B --port 8000 \ + --tensor-parallel-size 8 \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}' +``` + +### Proxy for Multiple Instances + +```bash +python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \ + --port 8192 \ + --prefiller-hosts ${IP1} ${IP2} \ + --prefiller-ports 8000 8000 \ + --decoder-hosts ${IP3} ${IP4} \ + --decoder-ports 8000 8000 +``` + +### KV Role Options + +- **kv_producer**: For prefiller instances that generate KV caches +- **kv_consumer**: For decoder instances that consume KV caches from prefiller +- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined. + +!!! tip + NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`). + Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior. + +## Example Scripts/Code + +Refer to these example scripts in the vLLM repository: + +- [run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) +- [toy_proxy_server.py](../../tests/v1/kv_connector/nixl_integration/toy_proxy_server.py) +- [test_accuracy.py](../../tests/v1/kv_connector/nixl_integration/test_accuracy.py) diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 2a48596571d1d..291c313cd57af 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -319,6 +319,15 @@ Supported models: Flags: `--tool-call-parser glm45` +### Qwen3-Coder Models (`qwen3_xml`) + +Supported models: + +* `Qwen/Qwen3-480B-A35B-Instruct` +* `Qwen/Qwen3-Coder-30B-A3B-Instruct` + +Flags: `--tool-call-parser qwen3_xml` + ### Models with Pythonic Tool Calls (`pythonic`) A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. 
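As an illustration, pythonic tool-call output looks roughly like the sketch below (hypothetical tool names; parsed here with the standard `ast` module rather than vLLM's parser):

```python
import ast

# Hypothetical model output in the "pythonic" style: a plain Python list of
# call expressions, which naturally expresses parallel tool calls without a
# JSON schema.
model_output = "[get_weather(city='Boston', unit='celsius'), get_time(timezone='EST')]"

calls = ast.parse(model_output, mode="eval").body  # an ast.List of ast.Call nodes
for call in calls.elts:
    name = call.func.id
    kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
    print(name, kwargs)  # e.g. get_weather {'city': 'Boston', 'unit': 'celsius'}
```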
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index cbc0a56a645ea..9d288667a318f 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -352,6 +352,7 @@ th {
 | `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ |
+| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | | ✅︎ | ✅︎ |
 | `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ |
 | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md
index 7489fc2609831..f823d33df80ea 100644
--- a/docs/serving/expert_parallel_deployment.md
+++ b/docs/serving/expert_parallel_deployment.md
@@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok
 1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.
-2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`
+2. **Configure Both Instances**: Add this flag to both prefill and decode instances: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`. Note that you may also specify one or more NIXL backends, for example: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}}'`
 3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.
diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 98fe36d0fb796..0076d4d30ee8e 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -101,6 +101,13 @@ def parse_args(): "--quantization", type=str, ) + parser.add_argument( + "--disable-expert-parallel", + dest="enable_expert_parallel", + action="store_false", + help="Disable expert parallel (default: enabled).", + ) + parser.set_defaults(enable_expert_parallel=True) return parser.parse_args() @@ -113,6 +120,7 @@ def main( dp_master_port, GPUs_per_dp_rank, enforce_eager, + enable_expert_parallel, trust_remote_code, max_num_seqs, max_model_len, @@ -168,7 +176,7 @@ def main( model=model, tensor_parallel_size=GPUs_per_dp_rank, enforce_eager=enforce_eager, - enable_expert_parallel=True, + enable_expert_parallel=enable_expert_parallel, trust_remote_code=trust_remote_code, max_num_seqs=max_num_seqs, max_model_len=max_model_len, @@ -229,6 +237,7 @@ if __name__ == "__main__": dp_master_port, tp_size, args.enforce_eager, + args.enable_expert_parallel, args.trust_remote_code, args.max_num_seqs, args.max_model_len, diff --git a/examples/offline_inference/torchrun_dp_example.py b/examples/offline_inference/torchrun_dp_example.py new file mode 100644 index 0000000000000..8e888a100254e --- /dev/null +++ b/examples/offline_inference/torchrun_dp_example.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +experimental support for data-parallel inference with torchrun +Note the data load balancing and distribution is done out of the vllm engine, +no internal lb supported in external_launcher mode. +""" + +from vllm import LLM, SamplingParams + +# Create prompts, the same across all ranks +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] * 50 + +# Create sampling parameters, the same across all ranks +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Use `distributed_executor_backend="external_launcher"` so that +# this llm engine/instance only creates one worker. +# it is important to set an explicit seed to make sure that +# all ranks have the same random seed, so that sampling can be +# deterministic across ranks. +llm = LLM( + model="microsoft/Phi-mini-MoE-instruct", + tensor_parallel_size=1, + data_parallel_size=2, + pipeline_parallel_size=1, + enable_expert_parallel=False, + distributed_executor_backend="external_launcher", + max_model_len=4096, + gpu_memory_utilization=0.6, + seed=1, +) + +dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank +dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size + +prompts = [ + f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank +] + +outputs = llm.generate(prompts, sampling_params) + + +# all ranks will have the same outputs +print("-" * 50) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n") + print("-" * 50) +""" +Further tips: + +1. to communicate control messages across all ranks, use the cpu group, +a PyTorch ProcessGroup with GLOO backend. 
+ +```python +from vllm.distributed.parallel_state import get_world_group +cpu_group = get_world_group().cpu_group +torch_rank = dist.get_rank(group=cpu_group) +if torch_rank == 0: + # do something for rank 0, e.g. saving the results to disk. +``` + +2. to communicate data across all ranks, use the model's device group, +a PyTorch ProcessGroup with NCCL backend. +```python +from vllm.distributed.parallel_state import get_world_group +device_group = get_world_group().device_group +``` + +3. to access the model directly in every rank, use the following code: +```python +llm.llm_engine.model_executor.driver_worker.worker.model_runner.model +``` +""" diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index de3f3afc17948..f8ddb5a22b31a 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -126,6 +126,23 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: ) +# Dots-OCR +def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions] + engine_args = EngineArgs( + model="rednote-hilab/dots.ocr", + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1676,6 +1693,7 @@ model_example_map = { "aya_vision": run_aya_vision, "blip-2": run_blip2, "chameleon": run_chameleon, + "dots_ocr": run_dots_ocr, "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, "ernie45_vl": run_ernie45_vl, diff --git a/pyproject.toml b/pyproject.toml index f43ae69e00bdd..88c5c4067f5ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,27 +110,6 @@ ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" -# After fixing type errors resulting from follow_imports: "skip" -> "silent", -# move the directory here and remove it from tools/mypy.sh -files = [ - "vllm/*.py", - "vllm/assets", - "vllm/entrypoints", - "vllm/inputs", - "vllm/logging_utils", - "vllm/multimodal", - "vllm/platforms", - "vllm/transformers_utils", - "vllm/triton_utils", - "vllm/usage", -] -# TODO(woosuk): Include the code from Megatron and HuggingFace. -exclude = [ - "vllm/model_executor/parallel_utils/|vllm/model_executor/models/", - # Ignore triton kernels in ops. 
- 'vllm/attention/ops/.*\.py$' -] - [tool.isort] skip_glob = [ ".buildkite/*", diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 7ea239b48ea26..4241cbb2b0333 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -14,14 +14,4 @@ nixl==0.3.0 tpu_info==0.4.0 # Install torch_xla ---pre ---extra-index-url https://download.pytorch.org/whl/nightly/cpu ---find-links https://storage.googleapis.com/libtpu-wheels/index.html ---find-links https://storage.googleapis.com/libtpu-releases/index.html ---find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html ---find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.9.0.dev20250730 -torchvision==0.24.0.dev20250730 -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp312-cp312-linux_x86_64.whl ; python_version == "3.12" - +torch_xla[tpu, pallas]==2.8.0 \ No newline at end of file diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 2c4287950dcfe..f25c367433f41 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import weakref from collections.abc import Sequence from copy import deepcopy from typing import Callable, Union @@ -10,7 +11,26 @@ from torch._ops import OpOverload from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.inductor_pass import InductorPass -from vllm.config import get_current_vllm_config +from vllm.compilation.pass_manager import with_pattern_match_debug +from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from vllm.config import VllmConfig, get_current_vllm_config + + +class LazyInitPass(InductorPass): + """ + If there's a pass that we want to initialize lazily in a test, + we can wrap it in LazyInitPass, which will initialize the pass when invoked + and then immediately invoke it. 
+ """ + + def __init__(self, pass_cls: type[VllmInductorPass], + vllm_config: VllmConfig): + self.pass_cls = pass_cls + self.vllm_config = weakref.proxy(vllm_config) # avoid cycle + + def __call__(self, graph: fx.Graph) -> None: + self.pass_ = self.pass_cls(self.vllm_config) + self.pass_(graph) class TestBackend: @@ -40,10 +60,16 @@ class TestBackend: example_inputs, config_patches=self.inductor_config) + @with_pattern_match_debug def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) + + VllmInductorPass.dump_prefix = 0 for pass_ in self.custom_passes: pass_(graph) + VllmInductorPass.dump_prefix += 1 + + VllmInductorPass.dump_prefix = None self.graph_post_pass = deepcopy(graph) # assign by reference, will reflect the final state of the graph diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 2454f85342eba..780a0d6b5c0e4 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -46,7 +46,10 @@ backend_configs = { # FA3 on Hopper "FA3": BackendConfig(name="FA3", - env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, + env_vars={ + "VLLM_FLASH_ATTN_VERSION": "3", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + }, comp_config={ "cudagraph_mode": "FULL", }, @@ -66,6 +69,7 @@ backend_configs = { BackendConfig(name="FlashAttentionMLA", env_vars={ "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", }, comp_config={ "cudagraph_mode": "FULL_DECODE_ONLY", @@ -89,7 +93,10 @@ backend_configs = { # FA2 "FA2": BackendConfig(name="FA2", - env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, + env_vars={ + "VLLM_FLASH_ATTN_VERSION": "2", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + }, comp_config={ "cudagraph_mode": "FULL", }), diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 9a51e6b3514f4..1dc21365d5577 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -294,6 +294,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states) + assert async_tp_pass.matched_count == 1 + # In pre-nodes, all gather or reduce scatter should exist, # fused_matmul_reduce_scatter or fused_all_gather_matmul should not backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 90e8e0ff95858..7afd6251bbbd5 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -4,7 +4,7 @@ import pytest import vllm from vllm.compilation.counter import compilation_counter -from vllm.config import VllmConfig +from vllm.config import CompilationConfig, VllmConfig from vllm.utils import _is_torch_equal_or_newer @@ -26,6 +26,14 @@ def test_use_cudagraphs_dynamic(monkeypatch): assert not vllm_config.compilation_config.use_cudagraph +def test_custom_op(): + # proper syntax + _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"]) + + with pytest.raises(ValueError, match="Invalid syntax '"): + _ = CompilationConfig(custom_ops=["quant_fp8"]) + + # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 @pytest.mark.forked # NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 0c7e6fbccf20c..2ee9aa7476beb 100644 --- 
a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -8,9 +8,10 @@ import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import FUSED_OPS, FusionPass +from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) @@ -58,11 +59,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, vllm_config.compilation_config = CompilationConfig( pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = FusionPass.instance(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) - passes = [noop_pass, fusion_pass, act_quant_fusion_pass - ] if do_fusion else [noop_pass] + passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass + ] if do_fusion else [noop_pass, cleanup_pass] func_pass = FixFunctionalizationPass(vllm_config) backend_func = TestBackend(*passes, func_pass) backend_no_func = TestBackend(*passes) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index eedb9bdcd5299..3d8897d3f18b8 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -4,11 +4,11 @@ import pytest import torch -import vllm.envs as envs import vllm.plugins from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, - FusionPass) + RMSNormQuantFusionPass) from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm @@ -79,15 +79,15 @@ class TestModel(torch.nn.Module): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) -@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) +@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. 
@pytest.mark.parametrize("cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], +@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm") def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, cuda_force_torch): @@ -104,9 +104,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = FusionPass.instance(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - backend = TestBackend(noop_pass, fusion_pass) + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) model = TestModel(hidden_size, eps, static, cuda_force_torch) # First dimension dynamic @@ -128,6 +129,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + assert fusion_pass.matched_count == 2 + # In pre-nodes, fp8 quant should be there and fused kernels should not backend.check_before_ops(model.ops_in_model_before()) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index dd31e0db1f59f..60f32c863208d 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -9,6 +9,7 @@ import vllm.envs as envs from vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, ModelConfig, PassConfig, VllmConfig) from vllm.distributed import tensor_model_parallel_all_reduce @@ -215,8 +216,10 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass) + backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, + cleanup_pass) token_num = batch_size * seq_len model = test_model_cls(hidden_size, token_num) @@ -227,6 +230,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states, residual) + assert all_reduce_fusion_pass.matched_count == 1 backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) del all_reduce_fusion_pass diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index c3f1c7481d1b3..c4cac95531926 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -6,18 +6,19 @@ from typing import Optional import pytest import torch._dynamo -from tests.compile.backend import TestBackend +from tests.compile.backend import LazyInitPass, TestBackend from tests.models.utils import check_outputs_equal from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata) from vllm import LLM, SamplingParams from vllm._custom_ops import cutlass_scaled_fp4_mm, 
scaled_fp4_quant -from vllm.attention import Attention +from vllm.attention import Attention, AttentionMetadata from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, ModelConfig, PassConfig, SchedulerConfig, VllmConfig, set_current_vllm_config) @@ -104,7 +105,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, # AttnFusionPass needs attention layers to be registered in config upon init # so we initialize it during compilation. - attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw) + attn_pass = LazyInitPass(AttnFusionPass, vllm_config) backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass) llm2 = LLM(model, enforce_eager=True, @@ -197,7 +198,8 @@ class AttentionQuantPatternModel(torch.nn.Module): device=self.device, ) - def build_attn_metadata(self, batch_size: int, use_hnd: bool): + def build_attn_metadata(self, batch_size: int, use_hnd: bool) \ + -> AttentionMetadata: """Initialize attention metadata.""" # Create common attn metadata @@ -447,9 +449,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, # Create test backend with fusion passes enabled noop_pass = NoOpEliminationPass(vllm_config) - attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw - ) - test_backend = TestBackend(noop_pass, attn_pass) + attn_pass = LazyInitPass(AttnFusionPass, vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass) # Compile model with fusion enabled model_compiled = torch.compile(model_fused, @@ -485,6 +488,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True) + # access the underlying `AttnFusionPass` on the `LazyInitPass` + assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) + # Check attention ops in the graph before and after fusion attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) attn_nodes_post = list(find_op_nodes(ATTN_OP, diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index fb9f9dde22799..b2734e915bbbf 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -6,10 +6,12 @@ import torch import vllm.envs as envs from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import FusionPass +from vllm.compilation.fusion import RMSNormQuantFusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass +from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, PassConfig, VllmConfig) from vllm.distributed import tensor_model_parallel_all_reduce @@ -104,7 +106,7 @@ class TestQuantModel(torch.nn.Module): # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) - 
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False) + self.fp8_linear = Fp8LinearOp(act_quant_static=True) self.scale = torch.rand(1, dtype=torch.float32) # Create a weight that is compatible with torch._scaled_mm, @@ -137,8 +139,7 @@ class TestQuantModel(torch.nn.Module): # layer normalization norm_output, residual_output = self.norm(all_reduce, residual) - # for static input quantization - # self.fp8_linear is initialized with use_per_token_if_dynamic=False + # scaled_mm with static input quantization fp8_linear_result = self.fp8_linear.apply(norm_output, self.w, self.wscale, @@ -253,16 +254,20 @@ def sequence_parallelism_pass_on_test_model( dtype=dtype, seed=42) - sequence_parallelism_pass = SequenceParallelismPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config) + sequence_parallelism_pass = SequenceParallelismPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - passes_for_backend = [noop_pass, sequence_parallelism_pass] + passes_for_backend: list[VllmInductorPass] = \ + [noop_pass, sequence_parallelism_pass] if enable_fusion: - fusion_pass = FusionPass.instance(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) passes_for_backend.append(fusion_pass) + passes_for_backend.append(cleanup_pass) + backend_no_func = TestBackend(*passes_for_backend) backend_func = TestBackend(*passes_for_backend, func_pass) @@ -279,6 +284,8 @@ def sequence_parallelism_pass_on_test_model( compiled_model_func = torch.compile(model, backend=backend_func) compiled_model_func(hidden_states, residual) + assert sequence_parallelism_pass.matched_count == 1 + # In pre-nodes, all reduce should be there, # reduce scatter and all gather should not backend_no_func.check_before_ops(model.ops_in_model_before()) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index ae190d25cad62..c445f4dde2cc4 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -15,6 +15,7 @@ from vllm.compilation.activation_quant_fusion import ( # yapf: enable from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -69,6 +70,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module): def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs): super().__init__() + from vllm.compilation.activation_quant_fusion import ( + silu_and_mul_nvfp4_quant_supported) + assert silu_and_mul_nvfp4_quant_supported + self.silu_and_mul = SiluAndMul() # create nvfp4 weight @@ -127,7 +132,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class, pass_config=PassConfig(enable_fusion=True, enable_noop=True)) fusion_pass = ActivationQuantFusionPass(config) - backend = TestBackend(NoOpEliminationPass(config), fusion_pass) + passes = [ + NoOpEliminationPass(config), fusion_pass, + PostCleanupPass(config) + ] + backend = TestBackend(*passes) model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x) @@ -151,6 +160,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class, atol=atol, rtol=rtol) + assert fusion_pass.matched_count == 1 + # In pre-nodes, quant op should be present and fused 
kernels should not backend.check_before_ops(model.ops_in_model_before()) diff --git a/tests/distributed/test_torchrun_example_moe.py b/tests/distributed/test_torchrun_example_moe.py new file mode 100644 index 0000000000000..2d6b930fcc07e --- /dev/null +++ b/tests/distributed/test_torchrun_example_moe.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# unit test for `examples/offline_inference/torchrun_example.py` +import os +import random + +import torch.distributed as dist + +from vllm import LLM, SamplingParams +from vllm.distributed.parallel_state import get_tp_group, get_world_group + +dist.init_process_group(backend="gloo") + +# Create prompts +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] * 10 +dp_size = int(os.getenv("DP_SIZE", "1")) +dp_rank = int(os.getenv("DP_RANK", "0")) + +if dp_size > 1: + # distribute the prompts across the data parallel ranks + prompts = [ + prompt for idx, prompt in enumerate(prompts) + if idx % dp_size == dp_rank + ] + +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# set different `gpu_memory_utilization` and `swap_space` for different ranks, +# to test if all ranks agree on the same kv cache configuration. +llm = LLM(model="microsoft/Phi-mini-MoE-instruct", + tensor_parallel_size=int(os.getenv("TP_SIZE", "1")), + pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")), + enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1, + distributed_executor_backend="external_launcher", + gpu_memory_utilization=random.uniform(0.7, 0.9), + swap_space=random.randint(1, 4), + seed=0) + +outputs = llm.generate(prompts, sampling_params) + +group = get_world_group() if dp_size == 1 else get_tp_group() +cpu_group = group.cpu_group +group_rank = dist.get_rank(group=cpu_group) + + +def test_consistent_across_ranks(obj): + if group_rank == 0: + dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group) + else: + container = [None] + dist.broadcast_object_list(container, + src=group.ranks[0], + group=cpu_group) + assert container[0] == obj + + +test_consistent_across_ranks( + llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) +test_consistent_across_ranks( + llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) + +# make sure we can access the model parameters from the calling process +# of the `LLM` instance. +params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner. 
+ model.parameters()) +test_consistent_across_ranks(len(params)) + +# all ranks should have the same outputs +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + test_consistent_across_ranks(prompt) + test_consistent_across_ranks(generated_text) + print(f"Rank {group_rank}, Prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") diff --git a/tests/entrypoints/openai/test_response_api_mcp_tools.py b/tests/entrypoints/openai/test_response_api_mcp_tools.py new file mode 100644 index 0000000000000..b0eb84712c199 --- /dev/null +++ b/tests/entrypoints/openai/test_response_api_mcp_tools.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import pytest_asyncio +from openai import OpenAI + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "openai/gpt-oss-20b" + + +@pytest.fixture(scope="module") +def monkeypatch_module(): + from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() + yield mpatch + mpatch.undo() + + +@pytest.fixture(scope="module") +def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch): + args = ["--enforce-eager", "--tool-server", "demo"] + + with monkeypatch_module.context() as m: + m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") + m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="function") +def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch): + args = ["--enforce-eager", "--tool-server", "demo"] + + with monkeypatch_module.context() as m: + m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") + m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv") + m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", + "code_interpreter,container") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def mcp_disabled_client(mcp_disabled_server): + async with mcp_disabled_server.get_async_client() as async_client: + yield async_client + + +@pytest_asyncio.fixture +async def mcp_enabled_client(mcp_enabled_server): + async with mcp_enabled_server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") +async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, + model_name: str): + response = await mcp_enabled_client.responses.create( + model=model_name, + # TODO: Ideally should be able to set max tool calls + # to prevent multi-turn, but it is not currently supported + # would speed up the test + input=("What's the first 4 digits after the decimal point of " + "cube root of `19910212 * 20250910`? " + "Show only the digits. 
The python interpreter is not stateful " + "and you must print to see the output."), + tools=[{ + "type": "mcp", + "server_label": "code_interpreter", + # URL unused for DemoToolServer + "server_url": "http://localhost:8888" + }], + ) + assert response is not None + assert response.status == "completed" + assert response.usage.output_tokens_details.tool_output_tokens > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") +async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, + model_name: str): + response = await mcp_disabled_client.responses.create( + model=model_name, + # TODO: Ideally should be able to set max tool calls + # to prevent multi-turn, but it is not currently supported + # would speed up the test + input=("What's the first 4 digits after the decimal point of " + "cube root of `19910212 * 20250910`? " + "Show only the digits. The python interpreter is not stateful " + "and you must print to see the output."), + tools=[{ + "type": "mcp", + "server_label": "code_interpreter", + # URL unused for DemoToolServer + "server_url": "http://localhost:8888" + }], + ) + assert response is not None + assert response.status == "completed" + assert response.usage.output_tokens_details.tool_output_tokens == 0 diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index f3c3148577b85..23d8373d97809 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -454,7 +454,13 @@ async def test_web_search(client: OpenAI, model_name: str): async def test_code_interpreter(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, - input="Multiply 64548*15151 using builtin python interpreter.", + # TODO: Ideally should be able to set max tool calls + # to prevent multi-turn, but it is not currently supported + # would speed up the test + input=("What's the first 4 digits after the decimal point of " + "cube root of `19910212 * 20250910`? " + "Show only the digits. 
The python interpreter is not stateful " + "and you must print to see the output."), tools=[{ "type": "code_interpreter", "container": { @@ -464,6 +470,7 @@ async def test_code_interpreter(client: OpenAI, model_name: str): ) assert response is not None assert response.status == "completed" + assert response.usage.output_tokens_details.tool_output_tokens > 0 def get_weather(latitude, longitude): diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py index 4bab849f47c27..1da06be2eba92 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py @@ -5,6 +5,11 @@ import json import pytest +from vllm.entrypoints.openai.protocol import ChatCompletionRequest +from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import ( + Hermes2ProToolParser) +from vllm.transformers_utils.tokenizer import AnyTokenizer + from ....utils import RemoteOpenAIServer MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @@ -37,7 +42,7 @@ TOOLS = [{ }, "unit": { "type": "string", - "enum": ["celsius", "fahrenheit"] + "enum": ["celsius", "fahrenheit"], }, }, "required": ["location"], @@ -45,8 +50,39 @@ TOOLS = [{ }, }] +PRODUCT_TOOLS = [{ + "type": "function", + "function": { + "name": "get_product_info", + "description": "Get detailed information of a product based on its " + "product ID.", + "parameters": { + "type": "object", + "properties": { + "inserted": { + "type": "boolean", + "description": "inserted.", + }, + "product_id": { + "type": "integer", + "description": "The product ID of the product.", + }, + }, + "required": ["product_id", "inserted"], + }, + }, +}] + MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}] +PRODUCT_MESSAGES = [{ + "role": + "user", + "content": + "Hi! 
Do you have any detailed information about the product id " + "7355608 and inserted true?", +}] + @pytest.mark.asyncio async def test_non_streaming_tool_call(): @@ -113,8 +149,8 @@ async def test_streaming_tool_call(): if tool_chunk.function.name: tool_call_chunks[index]["name"] += tool_chunk.function.name if tool_chunk.function.arguments: - tool_call_chunks[index][ - "arguments"] += tool_chunk.function.arguments + tool_call_chunks[index]["arguments"] += ( + tool_chunk.function.arguments) assert len(tool_call_chunks) == 1 reconstructed_tool_call = tool_call_chunks[0] @@ -127,3 +163,295 @@ async def test_streaming_tool_call(): print("\n[Streaming Test Passed]") print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") print(f"Reconstructed Arguments: {arguments}") + + +@pytest.mark.asyncio +async def test_non_streaming_product_tool_call(): + """Test tool call integer and boolean parameters in non-streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + response = await client.chat.completions.create( + model=LORA_MODEL, + messages=PRODUCT_MESSAGES, + tools=PRODUCT_TOOLS, + tool_choice="auto", + temperature=0.66, + ) + + assert response.choices + choice = response.choices[0] + message = choice.message + + assert choice.finish_reason == "tool_calls" + assert message.tool_calls is not None + + tool_call = message.tool_calls[0] + assert tool_call.type == "function" + assert tool_call.function.name == "get_product_info" + + arguments = json.loads(tool_call.function.arguments) + assert "product_id" in arguments + assert "inserted" in arguments + + product_id = arguments.get("product_id") + inserted = arguments.get("inserted") + + assert isinstance(product_id, int) + assert product_id == 7355608 + assert isinstance(inserted, bool) + assert inserted is True + + print("\n[Non-Streaming Product Test Passed]") + print(f"Tool Call: {tool_call.function.name}") + print(f"Arguments: {arguments}") + + +@pytest.mark.asyncio +async def test_streaming_product_tool_call(): + """Test tool call integer and boolean parameters in streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + stream = await client.chat.completions.create( + model=LORA_MODEL, + messages=PRODUCT_MESSAGES, + tools=PRODUCT_TOOLS, + tool_choice="auto", + temperature=0.66, + stream=True, + ) + + tool_call_chunks = {} + async for chunk in stream: + if not chunk.choices: + continue + + delta = chunk.choices[0].delta + if not delta or not delta.tool_calls: + continue + + for tool_chunk in delta.tool_calls: + index = tool_chunk.index + if index not in tool_call_chunks: + tool_call_chunks[index] = {"name": "", "arguments": ""} + + if tool_chunk.function.name: + tool_call_chunks[index]["name"] += tool_chunk.function.name + if tool_chunk.function.arguments: + tool_call_chunks[index]["arguments"] += ( + tool_chunk.function.arguments) + + assert len(tool_call_chunks) == 1 + reconstructed_tool_call = tool_call_chunks[0] + + assert reconstructed_tool_call["name"] == "get_product_info" + + arguments = json.loads(reconstructed_tool_call["arguments"]) + assert "product_id" in arguments + assert "inserted" in arguments + + # Handle type coercion for streaming test as well + product_id = arguments.get("product_id") + inserted = arguments.get("inserted") + + assert isinstance(product_id, int) + assert product_id == 7355608 + assert isinstance(inserted, bool) + assert inserted is True + + print("\n[Streaming Product Test 
Passed]") + print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") + print(f"Reconstructed Arguments: {arguments}") + + +@pytest.fixture +def qwen_tokenizer() -> AnyTokenizer: + from vllm.transformers_utils.tokenizer import get_tokenizer + + return get_tokenizer("Qwen/Qwen3-32B") + + +@pytest.fixture +def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser: + return Hermes2ProToolParser(qwen_tokenizer) + + +@pytest.fixture +def any_chat_request() -> ChatCompletionRequest: + return ChatCompletionRequest( + seed=42, + model="Qwen/Qwen3-32B", + messages=[], + ) + + +def test_hermes_parser_streaming_just_forward_text( + qwen_tokenizer: AnyTokenizer, + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = ( + """This is some prior text that has nothing to do with tool calling.""" + ) + tokens = qwen_tokenizer.encode(text) + previous_text = "" + delta_messages = [] + for token in tokens: + delta_text = qwen_tokenizer.decode([token]) + current_text = previous_text + delta_text + delta = hermes_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=any_chat_request, + ) + previous_text = current_text + delta_messages.append(delta) + + for delta in delta_messages: + assert delta is not None + assert not delta.tool_calls + + print(delta_messages) + assert "".join([delta.content for delta in delta_messages]) == text + + +def test_hermes_parser_streaming_failure_case_bug_19056( + qwen_tokenizer: AnyTokenizer, + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = """ +{"name": "final_answer", "arguments": {"trigger": true}} +""" + tokens = qwen_tokenizer.encode(text) + previous_text = "" + delta_messages = [] + for token in tokens: + text = qwen_tokenizer.decode([token]) + current_text = previous_text + text + delta = hermes_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=any_chat_request, + ) + previous_text = current_text + if delta is not None: + delta_messages.append(delta) + + assert delta_messages[0].tool_calls[0].function.name == "final_answer" + tool_call_args = "".join(delta.tool_calls[0].function.arguments or "" + for delta in delta_messages) + assert tool_call_args == '{"trigger": true}' + + +def test_hermes_parser_streaming( + qwen_tokenizer: AnyTokenizer, + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = '\ +{"name": "get_current_temperature",\ +"arguments": {"location":\ +"San Francisco, California, United States", "unit": "celsius"}}\ +' + + tokens = qwen_tokenizer.encode(text) + previous_text = "" + delta_messages = [] + for token in tokens: + text = qwen_tokenizer.decode([token]) + current_text = previous_text + text + delta = hermes_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=any_chat_request, + ) + previous_text = current_text + if delta is not None: + delta_messages.append(delta) + print(delta_messages) + assert (delta_messages[0].tool_calls[0].function.name == + "get_current_temperature") + tool_call_args = "".join(delta.tool_calls[0].function.arguments or "" + for 
delta in delta_messages) + assert tool_call_args == ( + '{"location":"San Francisco, California, United States", ' + '"unit": "celsius"}') + + +def test_hermes_parser_non_streaming_no_tool_call( + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = """This is not a tool call.""" + tool_call = hermes_parser.extract_tool_calls( + model_output=text, + request=any_chat_request, + ) + + assert tool_call is not None + assert not tool_call.tools_called + + +def test_hermes_parser_non_streaming_tool_call_between_tags( + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = """ +{"name": "final_answer", "arguments": {"trigger": true}} +""" + tool_call = hermes_parser.extract_tool_calls( + model_output=text, + request=any_chat_request, + ) + + assert tool_call is not None + assert tool_call.tools_called + assert tool_call.tool_calls[0].function.name == "final_answer" + assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}' + + +def test_hermes_parser_non_streaming_tool_call_until_eos( + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = """ +{"name": "final_answer", "arguments": {"trigger": true}}""" + tool_call = hermes_parser.extract_tool_calls( + model_output=text, + request=any_chat_request, + ) + + assert tool_call is not None + assert tool_call.tools_called + assert tool_call.tool_calls[0].function.name == "final_answer" + assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}' + + +def test_hermes_parser_non_streaming_tool_call_invalid_json( + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + # Missing closing brace to trigger exception + text = """ +{"name": "final_answer", "arguments": {"trigger": true}""" + tool_call = hermes_parser.extract_tool_calls( + model_output=text, + request=any_chat_request, + ) + + assert tool_call is not None + assert not tool_call.tools_called diff --git a/tests/evals/gsm8k/README.md b/tests/evals/gsm8k/README.md index 58572c3a6fbc1..29c5199e1e87a 100644 --- a/tests/evals/gsm8k/README.md +++ b/tests/evals/gsm8k/README.md @@ -19,7 +19,7 @@ pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \ vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000 # Run evaluation -python tests/gsm8k/gsm8k_eval.py --port 8000 +python tests/evals/gsm8k/gsm8k_eval.py --port 8000 ``` ## Configuration Format diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 38ab40f88ae0b..a4e200775c09d 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -67,7 +67,6 @@ def generate_params(): return params -@pytest.mark.skip(reason="Skipped for now. 
Should be revisited.") @pytest.mark.parametrize("device, name, use_mla, block_size", generate_params()) def test_env( @@ -189,7 +188,7 @@ def test_env( # FlashMLA only supports block_size == 64 pytest.skip("FlashMLA only supports block_size 64") else: - from vllm.attention.backends.flashmla import ( + from vllm.v1.attention.backends.mla.flashmla import ( # noqa: E501 is_flashmla_supported) is_supported, _ = is_flashmla_supported() if not is_supported: diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 8d6ce381976b5..39ea07309134b 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -959,7 +959,6 @@ def make_test_metadata( return attn_backend_obj.make_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), - multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, @@ -1009,7 +1008,6 @@ def make_test_metadata( return attn_backend_obj.make_metadata( num_prefills=num_prefills, slot_mapping=kv_mmap.slot_mapping, - multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 6735b7cd9e436..ced0afc50cb91 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -164,8 +164,8 @@ def populate_loras( weight=layer_weights, generate_embeddings_tensor=generate_embeddings_tensor, ) - sublora.lora_b = sublora.lora_b[:, (sublora_len * - i):(sublora_len * (i + 1))] + sublora.lora_b = sublora.lora_b[(sublora_len * + i):(sublora_len * (i + 1)), :] sublora.optimize() subloras.append(sublora) @@ -304,9 +304,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: result = embedding(input_) after_a = F.embedding( input_, - lora.lora_a, + lora.lora_a.T, ) - result += (after_a @ lora.lora_b) + result += (after_a @ lora.lora_b.T) expected_results.append(result) expected_result = torch.cat(expected_results) @@ -445,9 +445,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, result = expanded_embedding(input_) after_a = F.embedding( original_input_, - lora.lora_a, + lora.lora_a.T, ) - result += (after_a @ lora.lora_b) + result += (after_a @ lora.lora_b.T) expected_results.append(result) expected_result = torch.cat(expected_results) @@ -575,7 +575,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, lm_head=linear, embedding_bias=None) result[:, vocab_size + embeddings_tensor_len:] = float("-inf") - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) logits_processor.org_vocab_size = vocab_size @@ -692,9 +692,10 @@ def test_linear_replicated( expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] result = linear(input_)[0] - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) @@ -817,7 +818,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = linear(input_)[0] - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + 
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) @@ -965,9 +966,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, result = linear(input_)[0] subloras = sublora_dict[lora_id] for i, sublora in enumerate(subloras): - result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * - (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * - sublora.scaling) + result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] * + (i + 1)] += ( + input_ @ sublora.lora_a.T @ sublora.lora_b.T * + sublora.scaling) expected_results.append(result) expected_result = torch.cat(expected_results) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index d7684fbf34abb..6f0a852314081 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device): assert lora.lora_b is not None assert lora.lora_a.device == torch.device(device) assert lora.lora_b.device == torch.device(device) - assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] + assert (lora.lora_a.shape[0] == lora.lora_b.shape[1] ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" - assert lora.lora_a.shape[1] == 8 + assert lora.lora_a.shape[0] == 8 embeddings_module = next( (k for k in EMBEDDING_MODULES if k in module_name), None) if embeddings_module: @@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str], name, 8, 16, - torch.rand([w.shape[1], 8], device=device), - torch.rand([8, w.shape[0]], device=device), + torch.rand([8, w.shape[1]], device=device), + torch.rand([w.shape[0], 8], device=device), ) return LoRAModel(lora_id, 8, loras) @@ -109,8 +109,8 @@ def create_packed_lora( replaced_module_name, 8, 16, - torch.rand([w.shape[1], 8], device=device), - torch.rand([8, w.shape[0] // len(replaced_module_names)], + torch.rand([8, w.shape[1]], device=device), + torch.rand([w.shape[0] // len(replaced_module_names), 8], device=device), ) return LoRAModel(lora_id, 8, loras) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index ab475904d4938..0432a1a9bba07 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -36,10 +36,10 @@ class DummyLoRAManager: module_name, rank=rank, lora_alpha=1, - lora_a=torch.rand([weight.shape[1], rank], + lora_a=torch.rand([rank, weight.shape[1]], dtype=weight.dtype, device=self._device), - lora_b=torch.rand([rank, weight.shape[0]], + lora_b=torch.rand([weight.shape[0], rank], dtype=weight.dtype, device=self._device), ) @@ -67,8 +67,8 @@ class DummyLoRAManager: module_name, rank=rank, lora_alpha=1, - lora_a=torch.rand([input_dim, rank], device="cuda"), - lora_b=torch.rand([rank, output_dim], device="cuda"), + lora_a=torch.rand([rank, input_dim], device="cuda"), + lora_b=torch.rand([output_dim, input_dim], device="cuda"), embeddings_tensor=embeddings_tensor, ) self.set_module_lora(module_name, lora) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 86139d598582d..92ce10a9efc0b 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional import pytest import torch @@ -34,15 +35,15 @@ class Relu3(ReLUSquaredActivation): [ # Default values based on compile level # - All by default (no 
Inductor compilation) - ("", 0, False, [True] * 4, True), - ("", 1, True, [True] * 4, True), - ("", 2, False, [True] * 4, True), + (None, 0, False, [True] * 4, True), + (None, 1, True, [True] * 4, True), + (None, 2, False, [True] * 4, True), # - None by default (with Inductor) - ("", 3, True, [False] * 4, False), - ("", 4, True, [False] * 4, False), + (None, 3, True, [False] * 4, False), + (None, 4, True, [False] * 4, False), # - All by default (without Inductor) - ("", 3, False, [True] * 4, True), - ("", 4, False, [True] * 4, True), + (None, 3, False, [True] * 4, True), + (None, 4, False, [True] * 4, True), # Explicitly enabling/disabling # # Default: all @@ -54,7 +55,7 @@ class Relu3(ReLUSquaredActivation): # All but SiluAndMul ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True), # All but ReLU3 (even if ReLU2 is on) - ("-relu3,relu2", 3, False, [1, 1, 1, 0], True), + ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True), # RMSNorm and SiluAndMul ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False), # All but RMSNorm @@ -67,12 +68,13 @@ class Relu3(ReLUSquaredActivation): # All but RMSNorm ("all,-rms_norm", 4, True, [0, 1, 1, 1], True), ]) -def test_enabled_ops(env: str, torch_level: int, use_inductor: bool, +def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool, ops_enabled: list[int], default_on: bool): + custom_ops = env.split(',') if env else [] vllm_config = VllmConfig( compilation_config=CompilationConfig(use_inductor=bool(use_inductor), level=torch_level, - custom_ops=env.split(","))) + custom_ops=custom_ops)) with set_current_vllm_config(vllm_config): assert CustomOp.default_on() == default_on diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 0b1f90e27db82..e60a86075b8bc 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model SSM_MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", - "yujiepan/mamba2-codestral-v0.1-tiny-random", + # mamba2-codestral in transformers is broken pending: + # https://github.com/huggingface/transformers/pull/40861 + #"yujiepan/mamba2-codestral-v0.1-tiny-random", ] HYBRID_MODELS = [ @@ -31,18 +33,7 @@ HYBRID_MODELS = [ "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", "LiquidAI/LFM2-1.2B", -] - -V1_SUPPORTED_MODELS = [ - "state-spaces/mamba-130m-hf", - "ai21labs/Jamba-tiny-dev", - "pfnet/plamo-2-1b", - "yujiepan/mamba2-codestral-v0.1-tiny-random", - "Zyphra/Zamba2-1.2B-instruct", - "hmellor/tiny-random-BambaForCausalLM", - "ibm-granite/granite-4.0-tiny-preview", - "tiiuae/Falcon-H1-0.5B-Base", - "LiquidAI/LFM2-1.2B", + "tiny-random/qwen3-next-moe", ] FULL_CUDA_GRAPH_MODELS = [ @@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [ "Zyphra/Zamba2-1.2B-instruct", ] -V0_UNSUPPORTED_MODELS = [ - "LiquidAI/LFM2-1.2B", -] - FP32_STATE_MODELS = [ "state-spaces/mamba-130m-hf", "Zyphra/Zamba2-1.2B-instruct", @@ -88,20 +75,16 @@ def test_models( hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - if model in V1_SUPPORTED_MODELS: - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - else: - vllm_v1_outputs = None + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, 
max_tokens, num_logprobs) - if model in V1_SUPPORTED_MODELS: - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v1_outputs, - name_0="hf", - name_1="vllm-v1", - ) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -299,14 +282,14 @@ def test_full_cuda_graph( example_prompts, max_tokens, num_logprobs) with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v1_outputs, + outputs_1_lst=vllm_outputs, name_0="hf", - name_1="vllm-v1", + name_1="vllm", ) @@ -340,12 +323,12 @@ def test_fp32_cache_state( with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, **{cache_dtype_param: "float32"}) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v1_outputs, + outputs_1_lst=vllm_outputs, name_0="hf", - name_1="vllm-v1", + name_1="vllm", ) diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index 8336ebc0d59cb..c8a3513ac7ad1 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -209,7 +209,6 @@ def batch_make_video_embeddings( return visual(pixel_values_on_device, grid_thw=video_grid_thw_on_device).cpu() - # V1 Test: this calls a V0 internal. video_embeds = torch.concat(llm.apply_model(get_image_embeds)) # split into original batches diff --git a/tests/models/registry.py b/tests/models/registry.py index e9cc5170ade74..8b62952ad5908 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -312,14 +312,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), - "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501 - trust_remote_code=True, - v0_only=True, - max_model_len=10240), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", - trust_remote_code=True), + max_transformers_version="4.55.4", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 + trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", max_transformers_version="4.53", transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 @@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", - min_transformers_version="4.56.2"), + extras={"tiny-random": "tiny-random/qwen3-next-moe"}, # noqa: E501 + min_transformers_version="4.56.3"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 
trust_remote_code=True, @@ -448,6 +447,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { max_transformers_version="4.48", # noqa: E501 transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 + "DotsOCRForCausalLM": _HfExamplesInfo("rednote-hilab/dots.ocr", + trust_remote_code=True), "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501 trust_remote_code=True), @@ -560,10 +561,12 @@ _MULTIMODAL_EXAMPLE_MODELS = { "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 "Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501 max_model_len=4096, - min_transformers_version="4.57"), # noqa: E501 + min_transformers_version="4.57", + is_available_online=False), "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501 - max_model_len=4096, - min_transformers_version="4.57"), + max_model_len=4096, + min_transformers_version="4.57", + is_available_online=False), "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True), "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", @@ -640,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { trust_remote_code=True, speculative_model="XiaomiMiMo/MiMo-7B-RL"), "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", - min_transformers_version="4.56.2"), + min_transformers_version="4.56.3"), } _TRANSFORMERS_BACKEND_MODELS = { diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 310d3a3719b65..8744bcbd3a2a6 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -1,10 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math import pytest import torch +import torch.multiprocessing as mp -from vllm.model_executor.models.vision import resolve_visual_encoder_outputs +from tests.utils import multi_gpu_test +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) +from vllm.model_executor.models.vision import ( + get_load_balance_assignment, resolve_visual_encoder_outputs, + run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model) +from vllm.platforms import current_platform +from vllm.utils import get_open_port, update_environment_variables @pytest.mark.parametrize( @@ -33,3 +43,415 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers, post_layer_norm=None, max_possible_layers=max_possible_layers) assert torch.equal(torch.tensor(expected_features), output_tensor) + + +class SimpleLinearModel(torch.nn.Module): + """A simple linear vision model for testing.""" + + def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32): + super().__init__() + self.flatten = torch.nn.Flatten() + self.linear = torch.nn.Linear(input_dim, output_dim) + + def forward(self, x: torch.Tensor): + # Flatten the input and apply linear transformation + x = self.flatten(x) + return self.linear(x) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "batch_size", + [ + 1, # Single image + 4, # Small batch + 5, # Odd batch size (for testing padding) + ], +) +def test_run_dp_sharded_vision_model(batch_size: int): + world_size = 2 + # Launch processes + mp.spawn( + 
run_dp_sharded_vision_model_vs_direct, + args=( + world_size, + batch_size, + get_open_port(), + ), + nprocs=world_size, + ) + + +def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, + batch_size: int, master_port: int): + """ + Test that run_dp_sharded_vision_model produces the same results as + calling the model directly. + """ + + # Set random seed for reproducibility + current_platform.seed_everything(0) + + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create a test input tensor + image_input = torch.randn(batch_size, 3, 224, 224) + + # Create a simple linear model + vision_model = SimpleLinearModel() + + # Run the model directly on the full input + with torch.inference_mode(): + direct_output = vision_model(image_input) + + # Run the model through the sharded function + with torch.inference_mode(): + sharded_output = run_dp_sharded_vision_model(image_input, vision_model) + + # Check that the world size is set up correctly + assert get_tensor_model_parallel_world_size() == world_size + + # Check that the outputs have the same shape + assert direct_output.shape == sharded_output.shape + + # Check that the outputs are close (they should be identical) + assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize( + "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts," + "expected_grouped_sizes_per_gpu,test_description", + [ + # Empty input + ([], 2, [], [0, 0], [0, 0], "empty input"), + + # Fewer samples than GPUs + ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0 + ], "fewer samples than GPUs"), + + # Single GPU + ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"), + + # Balanced assignment + ([100, 100, 100, 100 + ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"), + + # Unbalanced sizes - this one is trickier since the algorithm is greedy + ([1000, 100, 200, 50], 2, [0, 2, 1, 3 + ], [1, 3], [1000, 350], "unbalanced sizes"), + ], +) +def test_get_load_balance_assignment_cases(sizes, num_gpus, + expected_shuffle_indices, + expected_gpu_sample_counts, + expected_grouped_sizes_per_gpu, + test_description): + """Test get_load_balance_assignment with various input cases.""" + result = get_load_balance_assignment(sizes, num_gpus=num_gpus) + (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result + + # Common assertions for all cases + assert len(shuffle_indices) == len(sizes) + assert len(gpu_sample_counts) == num_gpus + assert len(grouped_sizes_per_gpu) == num_gpus + assert sum(gpu_sample_counts) == len(sizes) + + assert shuffle_indices == expected_shuffle_indices + + assert gpu_sample_counts == expected_gpu_sample_counts + assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu + + +class SimpleMRopeVisionModel(torch.nn.Module): + """A simple vision model for testing mrope functionality.""" + + def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64): + super().__init__() + self.spatial_merge_size = spatial_merge_size + self.out_hidden_size = out_hidden_size + self.linear = torch.nn.Linear(768, out_hidden_size) + + def forward(self, 
pixel_values: torch.Tensor, + grid_thw_list: list[list[int]]): + """Simple forward pass that simulates spatial merging.""" + # Apply linear transformation + embeddings = self.linear(pixel_values) + + # Simulate spatial merging by reducing the number of patches + merge_factor = self.spatial_merge_size * self.spatial_merge_size + + # Group patches and merge spatially + merged_embeddings = [] + start_idx = 0 + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + end_idx = start_idx + num_patches + + # Get patches for this image + image_patches = embeddings[start_idx:end_idx] + + # Simulate spatial merging by averaging groups of patches + merged_patches = num_patches // merge_factor + if merged_patches > 0: + # Reshape and average to simulate merging + reshaped = image_patches[:merged_patches * merge_factor].view( + merged_patches, merge_factor, -1) + merged = reshaped.mean(dim=1) + merged_embeddings.append(merged) + + start_idx = end_idx + + if merged_embeddings: + return torch.cat(merged_embeddings, dim=0) + else: + return torch.empty((0, self.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "batch_size", + [ + 1, # Single image + 3, # Small batch + 5, # Odd batch size (for testing padding) + ], +) +def test_run_dp_sharded_mrope_vision_model(batch_size: int): + world_size = 2 + # Launch processes + mp.spawn( + run_dp_sharded_mrope_vision_model_vs_direct, + args=( + world_size, + batch_size, + get_open_port(), + ), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, + world_size: int, + batch_size: int, + master_port: int): + """ + Test that run_dp_sharded_mrope_vision_model produces the same results as + calling the model directly. 
+ """ + # Set random seed for reproducibility + current_platform.seed_everything(0) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create test data + grid_thw_list = [] + pixel_values_list = [] + + for i in range(batch_size): + # Varying image sizes for better testing + t, h, w = 1, 4 + i, 4 + i + grid_thw_list.append([t, h, w]) + + num_patches = t * h * w + # Create random pixel values for this image + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + # Concatenate all pixel values + pixel_values = torch.cat(pixel_values_list, dim=0) + + # Create a simple mrope vision model + vision_model = SimpleMRopeVisionModel() + + # Run the model directly on the full input (only on rank 0) + if local_rank == 0: + with torch.inference_mode(): + direct_output = vision_model(pixel_values, grid_thw_list) + + # Run the model through the sharded function + with torch.inference_mode(): + sharded_output = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + sharded_output = torch.cat(sharded_output, dim=0) + + # Check that the world size is set up correctly + assert get_tensor_model_parallel_world_size() == world_size + + # Compare outputs (only on rank 0) + if local_rank == 0: + # Check that the outputs have the same shape + assert direct_output.shape == sharded_output.shape + # Check that the outputs are close (they should be identical) + assert torch.allclose(direct_output, + sharded_output, + rtol=1e-5, + atol=1e-5) + + +@multi_gpu_test(num_gpus=2) +def test_run_dp_sharded_mrope_vision_model_empty_input(): + world_size = 2 + mp.spawn( + run_dp_sharded_mrope_vision_model_empty_input_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_empty_input_worker( + local_rank: int, world_size: int, master_port: int): + """Test run_dp_sharded_mrope_vision_model with empty input.""" + # Set up distributed environment + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create empty inputs + pixel_values = torch.empty((0, 768)) + grid_thw_list: list[list[int]] = [] + + vision_model = SimpleMRopeVisionModel() + + # Should handle empty input gracefully + with torch.inference_mode(): + output = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + + assert len(output) == 0 + + +@multi_gpu_test(num_gpus=4) +def test_run_dp_sharded_mrope_vision_model_uneven_load(): + world_size = 4 + mp.spawn( + run_dp_sharded_mrope_vision_model_uneven_load_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_uneven_load_worker( + local_rank: int, world_size: int, master_port: int): + 
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution.""" + # Set up distributed environment + current_platform.seed_everything(123) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create images with very different sizes + grid_thw_list = [ + [1, 2, 2], # Small: 4 patches + [1, 8, 8], # Large: 64 patches + [1, 3, 3], # Medium: 9 patches + ] + + pixel_values_list = [] + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel() + + # Should handle uneven distribution without errors + with torch.inference_mode(): + output_tuple = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + + # Verify output shape is reasonable + merge_factor = vision_model.spatial_merge_size**2 + expected_output_patches = list( + math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list) + + for i, output in enumerate(output_tuple): + assert output.shape[0] == expected_output_patches[i] + assert output.shape[1] == vision_model.out_hidden_size + + +@pytest.mark.parametrize("spatial_merge_size", [2, 4]) +def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int): + """Test SimpleMRopeVisionModel with different spatial merge sizes.""" + device = current_platform.device_type + + grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images + pixel_values_list = [] + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768, device=device) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel( + spatial_merge_size=spatial_merge_size).to(device) + + with torch.inference_mode(): + output = vision_model(pixel_values, grid_thw_list) + + # Verify output dimensions based on spatial merging + total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list) + merge_factor = spatial_merge_size**2 + expected_output_patches = total_patches // merge_factor + + assert output.shape[0] == expected_output_patches + assert output.shape[1] == vision_model.out_hidden_size diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index e1e8282dd66d4..f36d94ca01551 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -import math import mimetypes import os from tempfile import NamedTemporaryFile, TemporaryDirectory @@ -10,22 +9,11 @@ from typing import TYPE_CHECKING, NamedTuple import numpy as np import pytest -import torch -import torch.multiprocessing as mp from PIL import Image, ImageChops -from tests.utils import multi_gpu_test -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import PlaceholderRange 
-from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions, - get_load_balance_assignment, - run_dp_sharded_mrope_vision_model, - run_dp_sharded_vision_model) -from vllm.platforms import current_platform -from vllm.utils import get_open_port, update_environment_variables +from vllm.multimodal.utils import MediaConnector, argsort_mm_positions if TYPE_CHECKING: from vllm.multimodal.inputs import MultiModalPlaceholderDict @@ -404,415 +392,3 @@ def test_argsort_mm_positions(): modality_idxs = argsort_mm_positions(mm_positions) assert modality_idxs == expected_modality_idxs - - -class SimpleLinearModel(torch.nn.Module): - """A simple linear vision model for testing.""" - - def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32): - super().__init__() - self.flatten = torch.nn.Flatten() - self.linear = torch.nn.Linear(input_dim, output_dim) - - def forward(self, x: torch.Tensor): - # Flatten the input and apply linear transformation - x = self.flatten(x) - return self.linear(x) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize( - "batch_size", - [ - 1, # Single image - 4, # Small batch - 5, # Odd batch size (for testing padding) - ], -) -def test_run_dp_sharded_vision_model(batch_size: int): - world_size = 2 - # Launch processes - mp.spawn( - run_dp_sharded_vision_model_vs_direct, - args=( - world_size, - batch_size, - get_open_port(), - ), - nprocs=world_size, - ) - - -def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, - batch_size: int, master_port: int): - """ - Test that run_dp_sharded_vision_model produces the same results as - calling the model directly. - """ - - # Set random seed for reproducibility - current_platform.seed_everything(0) - - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - # initialize distributed - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create a test input tensor - image_input = torch.randn(batch_size, 3, 224, 224) - - # Create a simple linear model - vision_model = SimpleLinearModel() - - # Run the model directly on the full input - with torch.inference_mode(): - direct_output = vision_model(image_input) - - # Run the model through the sharded function - with torch.inference_mode(): - sharded_output = run_dp_sharded_vision_model(image_input, vision_model) - - # Check that the world size is set up correctly - assert get_tensor_model_parallel_world_size() == world_size - - # Check that the outputs have the same shape - assert direct_output.shape == sharded_output.shape - - # Check that the outputs are close (they should be identical) - assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) - - -@pytest.mark.parametrize( - "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts," - "expected_grouped_sizes_per_gpu,test_description", - [ - # Empty input - ([], 2, [], [0, 0], [0, 0], "empty input"), - - # Fewer samples than GPUs - ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0 - ], "fewer samples than GPUs"), - - # Single GPU - ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"), - - # Balanced assignment - ([100, 100, 100, 100 - ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"), - - # Unbalanced sizes - this one is 
trickier since the algorithm is greedy - ([1000, 100, 200, 50], 2, [0, 2, 1, 3 - ], [1, 3], [1000, 350], "unbalanced sizes"), - ], -) -def test_get_load_balance_assignment_cases(sizes, num_gpus, - expected_shuffle_indices, - expected_gpu_sample_counts, - expected_grouped_sizes_per_gpu, - test_description): - """Test get_load_balance_assignment with various input cases.""" - result = get_load_balance_assignment(sizes, num_gpus=num_gpus) - (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result - - # Common assertions for all cases - assert len(shuffle_indices) == len(sizes) - assert len(gpu_sample_counts) == num_gpus - assert len(grouped_sizes_per_gpu) == num_gpus - assert sum(gpu_sample_counts) == len(sizes) - - assert shuffle_indices == expected_shuffle_indices - - assert gpu_sample_counts == expected_gpu_sample_counts - assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu - - -class SimpleMRopeVisionModel(torch.nn.Module): - """A simple vision model for testing mrope functionality.""" - - def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64): - super().__init__() - self.spatial_merge_size = spatial_merge_size - self.out_hidden_size = out_hidden_size - self.linear = torch.nn.Linear(768, out_hidden_size) - - def forward(self, pixel_values: torch.Tensor, - grid_thw_list: list[list[int]]): - """Simple forward pass that simulates spatial merging.""" - # Apply linear transformation - embeddings = self.linear(pixel_values) - - # Simulate spatial merging by reducing the number of patches - merge_factor = self.spatial_merge_size * self.spatial_merge_size - - # Group patches and merge spatially - merged_embeddings = [] - start_idx = 0 - - for grid_thw in grid_thw_list: - num_patches = math.prod(grid_thw) - end_idx = start_idx + num_patches - - # Get patches for this image - image_patches = embeddings[start_idx:end_idx] - - # Simulate spatial merging by averaging groups of patches - merged_patches = num_patches // merge_factor - if merged_patches > 0: - # Reshape and average to simulate merging - reshaped = image_patches[:merged_patches * merge_factor].view( - merged_patches, merge_factor, -1) - merged = reshaped.mean(dim=1) - merged_embeddings.append(merged) - - start_idx = end_idx - - if merged_embeddings: - return torch.cat(merged_embeddings, dim=0) - else: - return torch.empty((0, self.out_hidden_size), - device=pixel_values.device, - dtype=pixel_values.dtype) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize( - "batch_size", - [ - 1, # Single image - 3, # Small batch - 5, # Odd batch size (for testing padding) - ], -) -def test_run_dp_sharded_mrope_vision_model(batch_size: int): - world_size = 2 - # Launch processes - mp.spawn( - run_dp_sharded_mrope_vision_model_vs_direct, - args=( - world_size, - batch_size, - get_open_port(), - ), - nprocs=world_size, - ) - - -def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, - world_size: int, - batch_size: int, - master_port: int): - """ - Test that run_dp_sharded_mrope_vision_model produces the same results as - calling the model directly. 
- """ - # Set random seed for reproducibility - current_platform.seed_everything(0) - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - # initialize distributed - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create test data - grid_thw_list = [] - pixel_values_list = [] - - for i in range(batch_size): - # Varying image sizes for better testing - t, h, w = 1, 4 + i, 4 + i - grid_thw_list.append([t, h, w]) - - num_patches = t * h * w - # Create random pixel values for this image - image_pixels = torch.randn(num_patches, 768) - pixel_values_list.append(image_pixels) - - # Concatenate all pixel values - pixel_values = torch.cat(pixel_values_list, dim=0) - - # Create a simple mrope vision model - vision_model = SimpleMRopeVisionModel() - - # Run the model directly on the full input (only on rank 0) - if local_rank == 0: - with torch.inference_mode(): - direct_output = vision_model(pixel_values, grid_thw_list) - - # Run the model through the sharded function - with torch.inference_mode(): - sharded_output = run_dp_sharded_mrope_vision_model(vision_model, - pixel_values, - grid_thw_list, - rope_type="rope_3d") - sharded_output = torch.cat(sharded_output, dim=0) - - # Check that the world size is set up correctly - assert get_tensor_model_parallel_world_size() == world_size - - # Compare outputs (only on rank 0) - if local_rank == 0: - # Check that the outputs have the same shape - assert direct_output.shape == sharded_output.shape - # Check that the outputs are close (they should be identical) - assert torch.allclose(direct_output, - sharded_output, - rtol=1e-5, - atol=1e-5) - - -@multi_gpu_test(num_gpus=2) -def test_run_dp_sharded_mrope_vision_model_empty_input(): - world_size = 2 - mp.spawn( - run_dp_sharded_mrope_vision_model_empty_input_worker, - args=(world_size, get_open_port()), - nprocs=world_size, - ) - - -def run_dp_sharded_mrope_vision_model_empty_input_worker( - local_rank: int, world_size: int, master_port: int): - """Test run_dp_sharded_mrope_vision_model with empty input.""" - # Set up distributed environment - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create empty inputs - pixel_values = torch.empty((0, 768)) - grid_thw_list: list[list[int]] = [] - - vision_model = SimpleMRopeVisionModel() - - # Should handle empty input gracefully - with torch.inference_mode(): - output = run_dp_sharded_mrope_vision_model(vision_model, - pixel_values, - grid_thw_list, - rope_type="rope_3d") - - assert len(output) == 0 - - -@multi_gpu_test(num_gpus=4) -def test_run_dp_sharded_mrope_vision_model_uneven_load(): - world_size = 4 - mp.spawn( - run_dp_sharded_mrope_vision_model_uneven_load_worker, - args=(world_size, get_open_port()), - nprocs=world_size, - ) - - -def run_dp_sharded_mrope_vision_model_uneven_load_worker( - local_rank: int, world_size: int, master_port: int): - 
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution.""" - # Set up distributed environment - current_platform.seed_everything(123) - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create images with very different sizes - grid_thw_list = [ - [1, 2, 2], # Small: 4 patches - [1, 8, 8], # Large: 64 patches - [1, 3, 3], # Medium: 9 patches - ] - - pixel_values_list = [] - for grid_thw in grid_thw_list: - num_patches = math.prod(grid_thw) - image_pixels = torch.randn(num_patches, 768) - pixel_values_list.append(image_pixels) - - pixel_values = torch.cat(pixel_values_list, dim=0) - vision_model = SimpleMRopeVisionModel() - - # Should handle uneven distribution without errors - with torch.inference_mode(): - output_tuple = run_dp_sharded_mrope_vision_model(vision_model, - pixel_values, - grid_thw_list, - rope_type="rope_3d") - - # Verify output shape is reasonable - merge_factor = vision_model.spatial_merge_size**2 - expected_output_patches = list( - math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list) - - for i, output in enumerate(output_tuple): - assert output.shape[0] == expected_output_patches[i] - assert output.shape[1] == vision_model.out_hidden_size - - -@pytest.mark.parametrize("spatial_merge_size", [2, 4]) -def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int): - """Test SimpleMRopeVisionModel with different spatial merge sizes.""" - device = current_platform.device_type - - grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images - pixel_values_list = [] - - for grid_thw in grid_thw_list: - num_patches = math.prod(grid_thw) - image_pixels = torch.randn(num_patches, 768, device=device) - pixel_values_list.append(image_pixels) - - pixel_values = torch.cat(pixel_values_list, dim=0) - vision_model = SimpleMRopeVisionModel( - spatial_merge_size=spatial_merge_size).to(device) - - with torch.inference_mode(): - output = vision_model(pixel_values, grid_thw_list) - - # Verify output dimensions based on spatial merging - total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list) - merge_factor = spatial_merge_size**2 - expected_output_patches = total_patches // merge_factor - - assert output.shape[0] == expected_output_patches - assert output.shape[1] == vision_model.out_hidden_size diff --git a/tests/test_envs.py b/tests/test_envs.py new file mode 100644 index 0000000000000..f81a6e2e415cd --- /dev/null +++ b/tests/test_envs.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from unittest.mock import patch + +import pytest + +from vllm.envs import env_list_with_choices, env_with_choices + + +class TestEnvWithChoices: + """Test cases for env_with_choices function.""" + + def test_default_value_returned_when_env_not_set(self): + """Test default is returned when env var is not set.""" + env_func = env_with_choices("NONEXISTENT_ENV", "default", + ["option1", "option2"]) + assert env_func() == "default" + + def test_none_default_returned_when_env_not_set(self): + """Test that None is returned when env not set and default is None.""" + env_func = 
env_with_choices("NONEXISTENT_ENV", None, + ["option1", "option2"]) + assert env_func() is None + + def test_valid_value_returned_case_sensitive(self): + """Test that valid value is returned in case sensitive mode.""" + with patch.dict(os.environ, {"TEST_ENV": "option1"}): + env_func = env_with_choices("TEST_ENV", + "default", ["option1", "option2"], + case_sensitive=True) + assert env_func() == "option1" + + def test_valid_lowercase_value_returned_case_insensitive(self): + """Test that lowercase value is accepted in case insensitive mode.""" + with patch.dict(os.environ, {"TEST_ENV": "option1"}): + env_func = env_with_choices("TEST_ENV", + "default", ["OPTION1", "OPTION2"], + case_sensitive=False) + assert env_func() == "option1" + + def test_valid_uppercase_value_returned_case_insensitive(self): + """Test that uppercase value is accepted in case insensitive mode.""" + with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}): + env_func = env_with_choices("TEST_ENV", + "default", ["option1", "option2"], + case_sensitive=False) + assert env_func() == "OPTION1" + + def test_invalid_value_raises_error_case_sensitive(self): + """Test that invalid value raises ValueError in case sensitive mode.""" + with patch.dict(os.environ, {"TEST_ENV": "invalid"}): + env_func = env_with_choices("TEST_ENV", + "default", ["option1", "option2"], + case_sensitive=True) + with pytest.raises(ValueError, + match="Invalid value 'invalid' for TEST_ENV"): + env_func() + + def test_case_mismatch_raises_error_case_sensitive(self): + """Test that case mismatch raises ValueError in case sensitive mode.""" + with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}): + env_func = env_with_choices("TEST_ENV", + "default", ["option1", "option2"], + case_sensitive=True) + with pytest.raises(ValueError, + match="Invalid value 'OPTION1' for TEST_ENV"): + env_func() + + def test_invalid_value_raises_error_case_insensitive(self): + """Test that invalid value raises ValueError when case insensitive.""" + with patch.dict(os.environ, {"TEST_ENV": "invalid"}): + env_func = env_with_choices("TEST_ENV", + "default", ["option1", "option2"], + case_sensitive=False) + with pytest.raises(ValueError, + match="Invalid value 'invalid' for TEST_ENV"): + env_func() + + def test_callable_choices_resolved_correctly(self): + """Test that callable choices are resolved correctly.""" + + def get_choices(): + return ["dynamic1", "dynamic2"] + + with patch.dict(os.environ, {"TEST_ENV": "dynamic1"}): + env_func = env_with_choices("TEST_ENV", "default", get_choices) + assert env_func() == "dynamic1" + + def test_callable_choices_with_invalid_value(self): + """Test that callable choices raise error for invalid values.""" + + def get_choices(): + return ["dynamic1", "dynamic2"] + + with patch.dict(os.environ, {"TEST_ENV": "invalid"}): + env_func = env_with_choices("TEST_ENV", "default", get_choices) + with pytest.raises(ValueError, + match="Invalid value 'invalid' for TEST_ENV"): + env_func() + + +class TestEnvListWithChoices: + """Test cases for env_list_with_choices function.""" + + def test_default_list_returned_when_env_not_set(self): + """Test that default list is returned when env var is not set.""" + env_func = env_list_with_choices("NONEXISTENT_ENV", + ["default1", "default2"], + ["option1", "option2"]) + assert env_func() == ["default1", "default2"] + + def test_empty_default_list_returned_when_env_not_set(self): + """Test that empty default list is returned when env not set.""" + env_func = env_list_with_choices("NONEXISTENT_ENV", [], + ["option1", 
"option2"]) + assert env_func() == [] + + def test_single_valid_value_parsed_correctly(self): + """Test that single valid value is parsed correctly.""" + with patch.dict(os.environ, {"TEST_ENV": "option1"}): + env_func = env_list_with_choices("TEST_ENV", [], + ["option1", "option2"]) + assert env_func() == ["option1"] + + def test_multiple_valid_values_parsed_correctly(self): + """Test that multiple valid values are parsed correctly.""" + with patch.dict(os.environ, {"TEST_ENV": "option1,option2"}): + env_func = env_list_with_choices("TEST_ENV", [], + ["option1", "option2"]) + assert env_func() == ["option1", "option2"] + + def test_values_with_whitespace_trimmed(self): + """Test that values with whitespace are trimmed correctly.""" + with patch.dict(os.environ, {"TEST_ENV": " option1 , option2 "}): + env_func = env_list_with_choices("TEST_ENV", [], + ["option1", "option2"]) + assert env_func() == ["option1", "option2"] + + def test_empty_values_filtered_out(self): + """Test that empty values are filtered out.""" + with patch.dict(os.environ, {"TEST_ENV": "option1,,option2,"}): + env_func = env_list_with_choices("TEST_ENV", [], + ["option1", "option2"]) + assert env_func() == ["option1", "option2"] + + def test_empty_string_returns_default(self): + """Test that empty string returns default.""" + with patch.dict(os.environ, {"TEST_ENV": ""}): + env_func = env_list_with_choices("TEST_ENV", ["default"], + ["option1", "option2"]) + assert env_func() == ["default"] + + def test_only_commas_returns_default(self): + """Test that string with only commas returns default.""" + with patch.dict(os.environ, {"TEST_ENV": ",,,"}): + env_func = env_list_with_choices("TEST_ENV", ["default"], + ["option1", "option2"]) + assert env_func() == ["default"] + + def test_case_sensitive_validation(self): + """Test case sensitive validation.""" + with patch.dict(os.environ, {"TEST_ENV": "option1,OPTION2"}): + env_func = env_list_with_choices("TEST_ENV", [], + ["option1", "option2"], + case_sensitive=True) + with pytest.raises(ValueError, + match="Invalid value 'OPTION2' in TEST_ENV"): + env_func() + + def test_case_insensitive_validation(self): + """Test case insensitive validation.""" + with patch.dict(os.environ, {"TEST_ENV": "OPTION1,option2"}): + env_func = env_list_with_choices("TEST_ENV", [], + ["option1", "option2"], + case_sensitive=False) + assert env_func() == ["OPTION1", "option2"] + + def test_invalid_value_in_list_raises_error(self): + """Test that invalid value in list raises ValueError.""" + with patch.dict(os.environ, {"TEST_ENV": "option1,invalid,option2"}): + env_func = env_list_with_choices("TEST_ENV", [], + ["option1", "option2"]) + with pytest.raises(ValueError, + match="Invalid value 'invalid' in TEST_ENV"): + env_func() + + def test_callable_choices_resolved_correctly(self): + """Test that callable choices are resolved correctly.""" + + def get_choices(): + return ["dynamic1", "dynamic2"] + + with patch.dict(os.environ, {"TEST_ENV": "dynamic1,dynamic2"}): + env_func = env_list_with_choices("TEST_ENV", [], get_choices) + assert env_func() == ["dynamic1", "dynamic2"] + + def test_callable_choices_with_invalid_value(self): + """Test that callable choices raise error for invalid values.""" + + def get_choices(): + return ["dynamic1", "dynamic2"] + + with patch.dict(os.environ, {"TEST_ENV": "dynamic1,invalid"}): + env_func = env_list_with_choices("TEST_ENV", [], get_choices) + with pytest.raises(ValueError, + match="Invalid value 'invalid' in TEST_ENV"): + env_func() + + def 
test_duplicate_values_preserved(self): + """Test that duplicate values in the list are preserved.""" + with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}): + env_func = env_list_with_choices("TEST_ENV", [], + ["option1", "option2"]) + assert env_func() == ["option1", "option1", "option2"] diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py index f06fb2b9f2f04..57eaf84d36f23 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -13,6 +13,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ToolCall) from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( Qwen3CoderToolParser) +from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import ( + Qwen3XMLToolParser) from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer @@ -29,6 +31,21 @@ def qwen3_tool_parser(qwen3_tokenizer): return Qwen3CoderToolParser(qwen3_tokenizer) +@pytest.fixture +def qwen3_xml_tool_parser(qwen3_tokenizer): + return Qwen3XMLToolParser(qwen3_tokenizer) + + +@pytest.fixture(params=["original", "xml"]) +def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, + request): + """Parameterized fixture that provides both parser types for testing""" + if request.param == "original": + return qwen3_tool_parser + else: + return qwen3_xml_tool_parser + + @pytest.fixture def sample_tools(): return [ @@ -95,7 +112,7 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], def stream_delta_message_generator( - qwen3_tool_parser: Qwen3CoderToolParser, + qwen3_tool_parser, qwen3_tokenizer: AnyTokenizer, model_output: str, request: Optional[ChatCompletionRequest] = None @@ -144,9 +161,9 @@ def stream_delta_message_generator( read_offset = new_read_offset -def test_extract_tool_calls_no_tools(qwen3_tool_parser): +def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized): model_output = "This is a test response without any tool calls" - extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( model_output, request=None) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] @@ -294,12 +311,13 @@ circle ], "Let me calculate that area for you."), ], ) -def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_calls(qwen3_tool_parser_parametrized, sample_tools, + model_output, expected_tool_calls, + expected_content): request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) - extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( model_output, request=request) assert extracted_tool_calls.tools_called @@ -308,7 +326,8 @@ def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output, assert extracted_tool_calls.content == expected_content -def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser, sample_tools): +def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser_parametrized, + sample_tools): """Test fallback parsing when XML tags are missing""" model_output = ''' @@ -322,7 +341,7 @@ TX request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) - extracted_tool_calls = 
qwen3_tool_parser.extract_tool_calls( + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( model_output, request=request) assert extracted_tool_calls.tools_called @@ -331,7 +350,7 @@ TX "get_current_weather") -def test_extract_tool_calls_type_conversion(qwen3_tool_parser): +def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized): """Test parameter type conversion based on tool schema""" tools = [ ChatCompletionToolsParam(type="function", @@ -381,7 +400,7 @@ hello world ''' request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) - extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( model_output, request=request) args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) @@ -536,9 +555,10 @@ circle ], "Let me calculate that area for you."), ], ) -def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, - sample_tools, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized, + qwen3_tokenizer, sample_tools, + model_output, expected_tool_calls, + expected_content): """Test incremental streaming behavior including typed parameters""" request = ChatCompletionRequest(model=MODEL, messages=[], @@ -548,7 +568,8 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, tool_states = {} # Track state per tool index for delta_message in stream_delta_message_generator( - qwen3_tool_parser, qwen3_tokenizer, model_output, request): + qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, + request): # role should never be streamed from tool parser assert not delta_message.role @@ -609,7 +630,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, def test_extract_tool_calls_missing_closing_parameter_tag( - qwen3_tool_parser, sample_tools): + qwen3_tool_parser_parametrized, sample_tools): """Test handling of missing closing tag""" # Using get_current_weather from sample_tools but with malformed XML model_output = '''Let me check the weather for you: @@ -629,7 +650,7 @@ fahrenheit request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) - extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( model_output, request=request) # The parser should handle the malformed XML gracefully @@ -652,7 +673,7 @@ fahrenheit def test_extract_tool_calls_streaming_missing_closing_tag( - qwen3_tool_parser, qwen3_tokenizer, sample_tools): + qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools): """Test streaming with missing closing tag""" # Using get_current_weather from sample_tools but with malformed XML model_output = '''Let me check the weather for you: @@ -677,7 +698,8 @@ fahrenheit tool_states = {} for delta_message in stream_delta_message_generator( - qwen3_tool_parser, qwen3_tokenizer, model_output, request): + qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, + request): if delta_message.content: other_content += delta_message.content @@ -727,9 +749,8 @@ fahrenheit assert args["unit"] == "fahrenheit" -def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser, - qwen3_tokenizer, - sample_tools): +def test_extract_tool_calls_streaming_incremental( + qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools): """Test that streaming is truly incremental""" model_output = '''I'll check the 
weather. @@ -748,7 +769,8 @@ TX chunks = [] for delta_message in stream_delta_message_generator( - qwen3_tool_parser, qwen3_tokenizer, model_output, request): + qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, + request): chunks.append(delta_message) # Should have multiple chunks @@ -784,3 +806,49 @@ TX parsed_args = json.loads(full_args) assert parsed_args["city"] == "Dallas" assert parsed_args["state"] == "TX" + + +def test_extract_tool_calls_complex_type_with_single_quote( + qwen3_tool_parser_parametrized): + """Test parameter type conversion based on tool schema""" + tools = [ + ChatCompletionToolsParam(type="function", + function={ + "name": "test_types", + "parameters": { + "type": "object", + "properties": { + "int_param": { + "type": "integer" + }, + "float_param": { + "type": "float" + }, + "bool_param": { + "type": "boolean" + }, + "str_param": { + "type": "string" + }, + "obj_param": { + "type": "object" + } + } + } + }) + ] + + model_output = ''' + + +{'key': 'value'} + + +''' + + request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( + model_output, request=request) + + args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert args["obj_param"] == {"key": "value"} diff --git a/tests/tpu/test_moe_pallas.py b/tests/tpu/test_moe_pallas.py index 407a824d81748..1e5d9d923d004 100644 --- a/tests/tpu/test_moe_pallas.py +++ b/tests/tpu/test_moe_pallas.py @@ -6,6 +6,7 @@ Run `pytest tests/kernels/moe/test_moe_pallas.py`. """ import pytest import torch +import torch_xla # yapf conflicts with isort for this block # yapf: disable @@ -77,7 +78,7 @@ def test_pallas_moe( expert_map=e_map, renormalize=False, ) - xm.mark_step() + torch_xla.sync(wait=False) # Compare outputs torch.testing.assert_close( diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index 25e01806f4956..1ae9185fafbdd 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -47,7 +47,10 @@ backend_configs = { # FA3 on Hopper "FA3": BackendConfig(name="FA3", - env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, + env_vars={ + "VLLM_FLASH_ATTN_VERSION": "3", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + }, comp_config={ "cudagraph_mode": "FULL", }, @@ -67,6 +70,7 @@ backend_configs = { BackendConfig(name="FlashAttentionMLA", env_vars={ "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", }, comp_config={ "cudagraph_mode": "FULL_DECODE_ONLY", @@ -75,7 +79,10 @@ backend_configs = { # FA2 "FA2": BackendConfig(name="FA2", - env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, + env_vars={ + "VLLM_FLASH_ATTN_VERSION": "2", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + }, comp_config={ "cudagraph_mode": "FULL_AND_PIECEWISE", }), diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index 9322410ec99e9..bc88370791096 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -85,7 +85,10 @@ run_tests_for_model() { echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" # Build the command with or without model-specific args - BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \ + 
UCX_NET_DEVICES=all \ + VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \ + vllm serve $model_name \ --port $PORT \ --enforce-eager \ --gpu-memory-utilization 0.2 \ @@ -117,7 +120,10 @@ run_tests_for_model() { echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" # Build the command with or without model-specific args - BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \ + UCX_NET_DEVICES=all \ + VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \ + vllm serve $model_name \ --port $PORT \ --enforce-eager \ --gpu-memory-utilization 0.2 \ diff --git a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py new file mode 100644 index 0000000000000..fe6296cf12ea0 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501 + SharedStorageConnectorMetadata) +from vllm.distributed.kv_transfer.kv_transfer_state import ( + ensure_kv_transfer_initialized, get_kv_transfer_group) +from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput +from vllm.v1.worker.kv_connector_model_runner_mixin import ( + KVConnectorModelRunnerMixin) + +# Importing utils registers TestSharedStorageConnector with the factory +from .utils import create_vllm_config + + +def _make_empty_scheduler_output(): + return SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[], + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None, + kv_connector_metadata=SharedStorageConnectorMetadata(), + ) + + +def test_kv_connector_mixin_clears_metadata(): + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector" + vllm_config.kv_transfer_config.kv_role = "kv_both" + vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = ("unit") + + # Initialize the global connector instance + ensure_kv_transfer_initialized(vllm_config) + + try: + # Minimal scheduler output with empty metadata; mixin should still + # bind/clear metadata even if no loads happen + scheduler_output = _make_empty_scheduler_output() + + # Invoke the no-forward path which uses the mixin context manager + KVConnectorModelRunnerMixin.kv_connector_no_forward( + scheduler_output, vllm_config) + + # Verify clear_connector_metadata was called on the connector + connector = get_kv_transfer_group() + assert connector._connector_metadata is None + # Test connector wrapper records method calls + assert connector.call_record.get("bind_connector_metadata", 0) == 1 + assert connector.call_record.get("clear_connector_metadata", 0) == 1 + finally: + # Ensure we clean up the global connector between tests + KVConnectorModelRunnerMixin.ensure_kv_transfer_shutdown() diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 6e58d158c3f4b..24cc83c28614b 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -27,6 +27,7 @@ from 
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, NixlConnectorWorker, NixlKVConnectorStats) from vllm.forward_context import ForwardContext +from vllm.platforms.interface import Platform from vllm.sampling_params import SamplingParams from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -56,7 +57,10 @@ class FakeNixlWrapper: def get_reg_descs(self, caches_data, memory_type: str) -> list: return [str(uuid.uuid4()) for _ in caches_data] - def register_memory(self, descs) -> None: + def register_memory(self, descs, backends) -> None: + pass + + def deregister_memory(self, descs) -> None: pass def get_xfer_descs(self, blocks_data, memory_type: str) -> list: @@ -85,6 +89,12 @@ class FakeNixlWrapper: def release_xfer_handle(self, handle: int) -> None: pass + def release_dlist_handle(self, handle: int) -> None: + pass + + def remove_remote_agent(self, agent: str) -> None: + pass + def send_notif(self, agent_name: str, notif_msg: bytes) -> None: pass @@ -855,3 +865,95 @@ def test_register_kv_caches(dist_init): assert block_len == expected_block_len, \ f"Block entry {i}: Expected block len {expected_block_len}, " \ f"got {block_len}" + + +class FakePlatform(Platform): + device_type: str = "oot" + + @classmethod + def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: + """ + Returns a mapping from device_type to a tuple of supported + kv_buffer_device for nixl. + """ + return {'oot': ('oot', )} + + @classmethod + def get_nixl_memory_type(cls) -> Optional[str]: + """ + Returns the nixl memory type for the current platform. + """ + return 'VRAM' + + +@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [ + ("oot", "VRAM"), +]) +def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, + nixl_memory_type): + """ + Test that register_kv_caches() passes the correct memory types from the + config to the nixl_wrapper. 
+ """ + vllm_config = create_vllm_config() + # Override the default memory types in the config + vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + _NIXL_SUPPORTED_DEVICE) + _NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices()) + + with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE): # noqa: E501 + + # Create the connector; NixlWrapper and the platform are patched above + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + + # Verify the worker picked up the expected buffer device and memory type + assert connector.connector_worker.kv_buffer_device == kv_buffer_device + assert connector.connector_worker.nixl_memory_type == nixl_memory_type + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper) +def test_shutdown_cleans_up_resources(dist_init): + """Test that shutdown() properly cleans up all resources.""" + vllm_config = create_vllm_config() + + worker = NixlConnectorWorker(vllm_config, + vllm_config.kv_transfer_config.engine_id) + nixl_wrapper = worker.nixl_wrapper + + with patch.object(worker, '_handshake_initiation_executor') as mock_exec, \ + patch.object(worker, '_nixl_handshake_listener_t') as mock_listener, \ + patch.object(nixl_wrapper, 'release_xfer_handle') as mock_rel_xfer, \ + patch.object(nixl_wrapper, 'release_dlist_handle') as mock_rel_dlist, \ + patch.object(nixl_wrapper, 'remove_remote_agent') as mock_rem_agent, \ + patch.object(nixl_wrapper, 'deregister_memory') as mock_dereg: + + worker._recving_transfers = {"req1": [(123, time.perf_counter())]} + worker.src_xfer_side_handle = 456 + worker.dst_xfer_side_handles = {"engine1": 789} + worker._remote_agents = {"engine1": {0: "agent1"}} + worker._registered_descs = ["desc1", "desc2"] + + worker.shutdown() + + # Test idempotency + worker.shutdown() + worker.shutdown() + + mock_exec.shutdown.assert_called_with(wait=False) + mock_listener.join.assert_called_once_with(timeout=0) + + mock_rel_xfer.assert_called_once_with(123) + assert mock_rel_dlist.call_count == 2 + mock_rel_dlist.assert_any_call(456) # src handle + mock_rel_dlist.assert_any_call(789) # dst handle + mock_rem_agent.assert_called_once_with("agent1") + assert mock_dereg.call_count == 2 + mock_dereg.assert_any_call("desc1") + mock_dereg.assert_any_call("desc2") diff --git a/tests/v1/kv_offload/test_cpu.py b/tests/v1/kv_offload/test_cpu_manager.py similarity index 100% rename from tests/v1/kv_offload/test_cpu.py rename to tests/v1/kv_offload/test_cpu_manager.py diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py new file mode 100644 index 0000000000000..fc8ca09bea3de --- /dev/null +++ b/tests/v1/kv_offload/test_cpu_offloading.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time + +import pytest + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +CPU_BLOCK_SIZES = [16, 48] + +
+@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES) +def test_cpu_offloading(cpu_block_size: int) -> None: + """ + Tests OffloadingConnector with CPUOffloadingSpec. + """ + + # configure OffloadingConnector (spec_name=CPUOffloadingSpec by default) + kv_transfer_config = KVTransferConfig( + kv_connector="OffloadingConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "num_cpu_blocks": 100, + "block_size": cpu_block_size + }, + ) + + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + gpu_memory_utilization=0.5, + kv_transfer_config=kv_transfer_config, + ) + + prompts = ["Hi " * 100] + sampling_params = SamplingParams(temperature=0, max_tokens=20) + + # run generation - this should trigger saving KV cache + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + cold_time = time.time() - start_time + + # run generation again - should hit the GPU prefix cache + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + gpu_hit_time = time.time() - start_time + + # reset prefix cache to avoid GPU hit. + llm.reset_prefix_cache() + + # sleep for a sec to make sure CPU finished storing + time.sleep(1) + + # run generation again - this should trigger loading from CPU + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + cpu_hit_time = time.time() - start_time + + print("Generation times:") + print(f" Cold: {cold_time * 1000:.2f}ms") + print(f" GPU hit: {gpu_hit_time * 1000:.2f}ms") + print(f" CPU hit: {cpu_hit_time * 1000:.2f}ms") diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py index 32da58011be98..cef0f362cff86 100644 --- a/tests/v1/test_async_llm_dp.py +++ b/tests/v1/test_async_llm_dp.py @@ -13,7 +13,6 @@ from vllm import SamplingParams from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import PromptType -from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core_client import DPAsyncMPClient @@ -29,10 +28,6 @@ engine_args = AsyncEngineArgs( data_parallel_size=DP_SIZE, ) -if not current_platform.supports_v1(engine_args.create_model_config()): - pytest.skip(reason="Requires V1-supporting platform.", - allow_module_level=True) - async def generate( engine: AsyncLLM, diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index 05751badc7619..665cf8cd2629e 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -4,6 +4,7 @@ import math import pytest import torch +import torch_xla from vllm.platforms import current_platform from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p @@ -63,7 +64,7 @@ def test_topp_result_sums_past_p(): probs.masked_fill_(logits_masked.isinf(), 0) masked_prob_sum = probs.sum(dim=-1) - xm.mark_step() + torch_xla.sync() # Perform assertion on CPU. assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu())) @@ -82,7 +83,7 @@ def test_topp_basic(): k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79])) - xm.mark_step() + torch_xla.sync() # Expect the smallest elements to be dropped. 
expected_result = logits.clone().cpu() @@ -104,7 +105,7 @@ def test_topp_select_all(): k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0])) - xm.mark_step() + torch_xla.sync() assert torch.allclose(logits.cpu(), result.cpu()) @@ -122,7 +123,7 @@ def test_topp_with_ties(): k=torch.tensor([4]), p=torch.tensor([0.2])) - xm.mark_step() + torch_xla.sync() # All tie values are included in the top-p set. Tie breaking is left # to be done during final sampling (all tie tokens have equal @@ -146,7 +147,7 @@ def test_both_topk_topp(): k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79])) - xm.mark_step() + torch_xla.sync() # Since for the first batch k=1, expect only the largest element gets # selected. diff --git a/tools/mypy.sh b/tools/mypy.sh deleted file mode 100755 index 63e3b9a916634..0000000000000 --- a/tools/mypy.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash - -CI=${1:-0} -PYTHON_VERSION=${2:-local} - -if [ "$CI" -eq 1 ]; then - set -e -fi - -if [ $PYTHON_VERSION == "local" ]; then - PYTHON_VERSION=$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') -fi - -run_mypy() { - echo "Running mypy on $1" - if [ "$CI" -eq 1 ] && [ -z "$1" ]; then - mypy --python-version "${PYTHON_VERSION}" "$@" - return - fi - mypy --follow-imports skip --python-version "${PYTHON_VERSION}" "$@" -} - -run_mypy # Note that this is less strict than CI -run_mypy tests -run_mypy vllm/attention -run_mypy vllm/compilation -run_mypy vllm/distributed -run_mypy vllm/engine -run_mypy vllm/executor -run_mypy vllm/inputs -run_mypy vllm/lora -run_mypy --exclude 'vllm/model_executor/layers/fla/ops' vllm/model_executor -run_mypy vllm/plugins -run_mypy vllm/worker -run_mypy vllm/v1 diff --git a/tools/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py similarity index 61% rename from tools/check_pickle_imports.py rename to tools/pre_commit/check_pickle_imports.py index fe717121db40d..acbbc1f181d69 100644 --- a/tools/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -1,20 +1,10 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import sys import regex as re -try: - import pathspec -except ImportError: - print( - "ERROR: The 'pathspec' library is required. " - "Install it with 'pip install pathspec'.", - file=sys.stderr) - sys.exit(2) - # List of files (relative to repo root) that are allowed to import pickle or # cloudpickle # @@ -25,7 +15,7 @@ except ImportError: # Before adding new uses of pickle/cloudpickle, please consider safer # alternatives like msgpack or pydantic that are already in use in vLLM. Only # add to this list if absolutely necessary and after careful security review. 
-ALLOWED_FILES = set([ +ALLOWED_FILES = { # pickle 'vllm/v1/serial_utils.py', 'vllm/v1/executor/multiproc_executor.py', @@ -36,11 +26,9 @@ ALLOWED_FILES = set([ 'tests/tokenization/test_cached_tokenizer.py', 'vllm/distributed/utils.py', 'vllm/distributed/parallel_state.py', - 'vllm/engine/multiprocessing/client.py', 'vllm/distributed/device_communicators/all_reduce_utils.py', 'vllm/distributed/device_communicators/shm_broadcast.py', 'vllm/distributed/device_communicators/shm_object_storage.py', - 'vllm/engine/multiprocessing/engine.py', 'benchmarks/kernels/graph_machete_bench.py', 'benchmarks/kernels/benchmark_lora.py', 'benchmarks/kernels/benchmark_machete.py', @@ -55,65 +43,30 @@ ALLOWED_FILES = set([ 'tests/utils.py', # pickle and cloudpickle 'vllm/utils/__init__.py', - 'vllm/v1/serial_utils.py', - 'vllm/v1/executor/multiproc_executor.py', - 'vllm/transformers_utils/config.py', - 'vllm/model_executor/models/registry.py', - 'vllm/engine/multiprocessing/client.py', - 'vllm/engine/multiprocessing/engine.py', -]) +} PICKLE_RE = re.compile(r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)" r"|from\s+(pickle|cloudpickle)\s+import\b)") -def is_python_file(path): - return path.endswith('.py') - - -def scan_file(path): +def scan_file(path: str) -> int: with open(path, encoding='utf-8') as f: - for line in f: + for i, line in enumerate(f, 1): if PICKLE_RE.match(line): - return True - return False - - -def load_gitignore(repo_root): - gitignore_path = os.path.join(repo_root, '.gitignore') - patterns = [] - if os.path.exists(gitignore_path): - with open(gitignore_path, encoding='utf-8') as f: - patterns = f.read().splitlines() - # Always ignore .git directory - patterns.append('.git/') - return pathspec.PathSpec.from_lines('gitwildmatch', patterns) + print(f"{path}:{i}: " + "\033[91merror:\033[0m " # red color + "Found pickle/cloudpickle import") + return 1 + return 0 def main(): - repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - spec = load_gitignore(repo_root) - bad_files = [] - for dirpath, _, filenames in os.walk(repo_root): - for filename in filenames: - if not is_python_file(filename): - continue - abs_path = os.path.join(dirpath, filename) - rel_path = os.path.relpath(abs_path, repo_root) - # Skip ignored files - if spec.match_file(rel_path): - continue - if scan_file(abs_path) and rel_path not in ALLOWED_FILES: - bad_files.append(rel_path) - if bad_files: - print("\nERROR: The following files import 'pickle' or 'cloudpickle' " - "but are not in the allowed list:") - for f in bad_files: - print(f" {f}") - print("\nIf this is intentional, update the allowed list in " - "tools/check_pickle_imports.py.") - sys.exit(1) - sys.exit(0) + returncode = 0 + for filename in sys.argv[1:]: + if filename in ALLOWED_FILES: + continue + returncode |= scan_file(filename) + return returncode def test_regex(): @@ -149,4 +102,4 @@ if __name__ == '__main__': if '--test-regex' in sys.argv: test_regex() else: - main() + sys.exit(main()) diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py new file mode 100755 index 0000000000000..039cf6075f631 --- /dev/null +++ b/tools/pre_commit/mypy.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Run mypy on changed files. + +This script is designed to be used as a pre-commit hook. It runs mypy +on files that have been changed. It groups files into different mypy calls +based on their directory to avoid import following issues. 
+ +Usage: + python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...> + +Args: + ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to + "silent" for the main group of files. + python_version: Python version to use (e.g., "3.10") or "local" to use + the local Python version. + changed_files: List of changed files to check. +""" + +import subprocess +import sys +from typing import Optional + +import regex as re + +FILES = [ + "vllm/*.py", + "vllm/assets", + "vllm/entrypoints", + "vllm/inputs", + "vllm/logging_utils", + "vllm/multimodal", + "vllm/platforms", + "vllm/transformers_utils", + "vllm/triton_utils", + "vllm/usage", +] + +# After fixing errors resulting from changing follow_imports +# from "skip" to "silent", move the following directories to FILES +SEPARATE_GROUPS = [ + "tests", + "vllm/attention", + "vllm/compilation", + "vllm/distributed", + "vllm/engine", + "vllm/executor", + "vllm/lora", + "vllm/model_executor", + "vllm/plugins", + "vllm/worker", + "vllm/v1", +] + +# TODO(woosuk): Include the code from Megatron and HuggingFace. +EXCLUDE = [ + "vllm/model_executor/parallel_utils", + "vllm/model_executor/models", + "vllm/model_executor/layers/fla/ops", + # Ignore triton kernels in ops. + "vllm/attention/ops", +] + + +def group_files(changed_files: list[str]) -> dict[str, list[str]]: + """ + Group changed files into different mypy calls. + + Args: + changed_files: List of changed files. + + Returns: + A dictionary mapping file group names to lists of changed files. + """ + exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*") + files_pattern = re.compile(f"^({'|'.join(FILES)}).*") + file_groups = {"": []} + file_groups.update({k: [] for k in SEPARATE_GROUPS}) + for changed_file in changed_files: + # Skip files which should be ignored completely + if exclude_pattern.match(changed_file): + continue + # Group files by mypy call + if files_pattern.match(changed_file): + file_groups[""].append(changed_file) + continue + for directory in SEPARATE_GROUPS: + if re.match(f"^{directory}.*", changed_file): + file_groups[directory].append(changed_file) + break + return file_groups + + +def mypy(targets: list[str], python_version: Optional[str], + follow_imports: Optional[str], file_group: str) -> int: + """ + Run mypy on the given targets. + + Args: + targets: List of files or directories to check. + python_version: Python version to use (e.g., "3.10") or None to use + the default mypy version. + follow_imports: Value for the --follow-imports option or None to use + the default mypy behavior. + file_group: The file group name for logging purposes. + + Returns: + The return code from mypy.
+ """ + args = ["mypy"] + if python_version is not None: + args += ["--python-version", python_version] + if follow_imports is not None: + args += ["--follow-imports", follow_imports] + print(f"$ {' '.join(args)} {file_group}") + return subprocess.run(args + targets, check=False).returncode + + +def main(): + ci = sys.argv[1] == "1" + python_version = sys.argv[2] + file_groups = group_files(sys.argv[3:]) + + if python_version == "local": + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + + returncode = 0 + for file_group, changed_files in file_groups.items(): + follow_imports = None if ci and file_group == "" else "skip" + if changed_files: + returncode |= mypy(changed_files, python_version, follow_imports, + file_group) + return returncode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index ab7ef2112b083..1b392cd7c88d3 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -10,7 +10,6 @@ from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple, import torch from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey -from vllm.multimodal import MultiModalPlaceholderMap class AttentionType: @@ -116,15 +115,6 @@ class AttentionMetadata: # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor - # The index maps that relate multi-modal embeddings to the corresponding - # placeholders. - # - # N.B. These aren't really related to attention and don't belong on this - # type -- this is just a temporary solution to make them available to - # `model_executable`. - multi_modal_placeholder_index_maps: Optional[Dict[ - str, MultiModalPlaceholderMap.IndexMap]] - # Enable/disable KV scales calculation. This is so that we can disable the # calculation until after prefill and cuda graph capture. enable_kv_scales_calculation: bool diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index f82d28938f45c..cddeb2cf39bf0 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from dataclasses import dataclass from itertools import accumulate -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch @@ -12,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder) from vllm.attention.backends.utils import CommonAttentionState -from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that @@ -141,8 +139,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata): num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. 
- multi_modal_placeholder_index_maps, enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, @@ -178,7 +174,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata): num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=seq_lens_tensor, @@ -210,9 +205,6 @@ class PlaceholderAttentionMetadataBuilder( self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -232,12 +224,6 @@ class PlaceholderAttentionMetadataBuilder( self.context_lens.append(context_len) if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -295,12 +281,6 @@ class PlaceholderAttentionMetadataBuilder( seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, device, self.runner.pin_memory) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - # Placeholders slot_mapping_tensor = torch.empty(0) block_tables = torch.empty(0) @@ -308,7 +288,6 @@ class PlaceholderAttentionMetadataBuilder( return PlaceholderAttentionMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, - multi_modal_placeholder_index_maps=placeholder_index_maps, enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 3f15580872c7f..63ee8f50825c5 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention backend utils""" -from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass from itertools import accumulate @@ -15,16 +14,10 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, from vllm.attention.backends.abstract import AttentionType from vllm.config import ModelConfig from vllm.logger import init_logger -from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad logger = init_logger(__name__) -# Error string(s) for encoder/decoder -# unsupported attention scenarios -STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " - "with encoder/decoder models.") - PAD_SLOT_ID = -1 # Switch to numpy implementation of compute_slot_mapping @@ -135,9 +128,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -154,12 +144,6 @@ class 
CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -254,16 +238,10 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): self.runner.pin_memory) seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, device, self.runner.pin_memory) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } return self._metadata_cls( # type: ignore num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, - multi_modal_placeholder_index_maps=placeholder_index_maps, enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, @@ -320,7 +298,6 @@ class CommonAttentionState(AttentionState): num_prefill_tokens=0, num_decode_tokens=batch_size, slot_mapping=self._graph_slot_mapping[:batch_size], - multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 189b57e8e8b82..6253e1e56b0f1 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -134,6 +134,5 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor, cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) - assert out.is_contiguous() out = cp_group.reduce_scatter(out, dim=1) return out diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 7382782f11655..2a042802d0d54 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -531,18 +531,22 @@ async def benchmark( extra_body=extra_body, ) - test_output = await wait_for_endpoint( - request_func, - test_input, - session, - timeout_seconds=ready_check_timeout_sec, - ) - if not test_output.success: - raise ValueError( - "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}") + if ready_check_timeout_sec > 0: + test_output = await wait_for_endpoint( + request_func, + test_input, + session, + timeout_seconds=ready_check_timeout_sec, + ) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark " + "arguments are correctly specified. " + f"Error: {test_output.error}") + else: + print("Initial test run completed. Starting main benchmark run...") else: - print("Initial test run completed. Starting main benchmark run...") + print("Skipping endpoint ready check.") if lora_modules: # For each input request, choose a LoRA module at random. @@ -1151,7 +1155,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=600, help="Maximum time to wait for the endpoint to become ready " - "in seconds (default: 600 seconds / 10 minutes).", + "in seconds (default: 600 seconds / 10 minutes). If set to 0, " + "the ready check will be skipped." 
) diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index f2fbb1200eecc..74462fb37ca97 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -17,7 +17,7 @@ from vllm.platforms import current_platform from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .inductor_pass import enable_fake_mode -from .vllm_inductor_pass import VllmInductorPass +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -152,7 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) -class ActivationQuantFusionPass(VllmInductorPass): +class ActivationQuantFusionPass(VllmPatternMatcherPass): """ This pass fuses a pre-defined set of custom ops into fused ops. It uses the torch pattern matcher to find the patterns and replace them. @@ -176,16 +176,12 @@ class ActivationQuantFusionPass(VllmInductorPass): pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern() pattern_silu_mul_nvfp4.register(self.patterns) + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.Graph): - self.begin() - self.dump_graph(graph, "before_act_quant_fusion") - - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns in ActivationQuantFusionPass", - count) - - self.dump_graph(graph, "after_act_quant_fusion") - self.end_and_log() + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) def uuid(self): return VllmInductorPass.hash_source(self, ActivationQuantPattern, diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 0658b59a2e215..331cd8a873929 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -20,7 +20,7 @@ from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from .inductor_pass import enable_fake_mode -from .vllm_inductor_pass import VllmInductorPass +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass FP8_DTYPE = current_platform.fp8_dtype() @@ -348,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern): pm.fwd_only, pm_pass) -class AsyncTPPass(VllmInductorPass): +class AsyncTPPass(VllmPatternMatcherPass): @enable_fake_mode def __init__(self, config: VllmConfig): @@ -378,18 +378,17 @@ class AsyncTPPass(VllmInductorPass): AllGatherCutlassScaledMMPattern( self.model_dtype, self.device).register(self.patterns) + self.dump_patterns(config, self.patterns) + def is_applicable_for_shape(self, shape: Optional[int]) -> bool: # only do replace for specific shapes tp_size = get_tensor_model_parallel_world_size() return shape is not None and shape % tp_size == 0 + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): - self.begin() - self.dump_graph(graph, "before_async_tp_pass") - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns with async TP pass.", count) - self.dump_graph(graph, "after_async_tp_pass") - self.end_and_log() + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) if flashinfer_comm is not None: @@ -1068,7 +1067,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): pm.fwd_only, pm_pass) -class AllReduceFusionPass(VllmInductorPass): +class AllReduceFusionPass(VllmPatternMatcherPass): def 
__init__(self, config: VllmConfig): super().__init__(config) @@ -1124,6 +1123,7 @@ class AllReduceFusionPass(VllmInductorPass): fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) self.register_patterns() + self.dump_patterns(config, self.patterns) @enable_fake_mode def register_patterns(self): @@ -1172,15 +1172,14 @@ class AllReduceFusionPass(VllmInductorPass): self.disabled = False + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): if self.disabled: + logger.debug("AllReduceFusionPass disabled") return - self.begin() - self.dump_graph(graph, "before_all_reduce_fusion_pass") - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", count) - self.dump_graph(graph, "after_all_reduce_fusion_pass") - self.end_and_log() + + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) def __del__(self): if getattr(self, "disabled", True): diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 6bc721eec3d45..54403c1f7ca3d 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -26,6 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass): To add new nodes to defunctionalize, add to the if-elif chain in __call__. """ + @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.Graph): # XPU does not support auto-functionalization yet. # Will enable this when switch to vllm-xpu-kernels. @@ -34,9 +35,6 @@ class FixFunctionalizationPass(VllmInductorPass): "pass currently.") return - self.begin() - self.dump_graph(graph, "before_fix_functionalization") - self.nodes_to_remove: list[torch.fx.Node] = [] count = 0 for node in graph.nodes: @@ -111,7 +109,7 @@ class FixFunctionalizationPass(VllmInductorPass): count += 1 - self.dump_graph(graph, "before_fix_functionalization_cleanup") + self.dump_graph(graph, "before_cleanup") # Remove the nodes all at once count_removed = len(self.nodes_to_remove) @@ -120,8 +118,7 @@ class FixFunctionalizationPass(VllmInductorPass): logger.debug("De-functionalized %s nodes, removed %s nodes", count, count_removed) - self.dump_graph(graph, "after_fix_functionalization") - self.end_and_log() + self.nodes_to_remove.clear() def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]): diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index afa739c966a5b..3034b6eaeaca1 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, NamedTuple, Optional +from typing import Any, NamedTuple import torch import torch._inductor.pattern_matcher as pm @@ -16,10 +16,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform -from .fx_utils import find_getitem_maybe from .inductor_pass import enable_fake_mode -from .multi_output_match import MultiOutputMatch -from .vllm_inductor_pass import VllmInductorPass +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() @@ -50,8 +48,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[ - 
kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 + QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default class FusedRMSQuantKey(NamedTuple): @@ -80,68 +77,6 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { } -class QuantMultiOutputMatch(MultiOutputMatch): - - def __init__(self, match: pm.Match, quant_op, fused_op): - super().__init__(match) - assert isinstance(quant_op, OpOverload) - assert isinstance(fused_op, OpOverload) - self.QUANT_OP = quant_op # in-place quant op - self.FUSED_OP = fused_op # in-place fused quant op - - def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node, - int]], - **kwargs): - """ - This utility function inserts an auto-functionalized node for FUSED_OP. - It also correctly sets its meta value and rebinds the users of the - unfused nodes to use the fused node instead. - - :param fused_return_mapping: A dictionary, mapping from getitem indices - of the fused node result to a tuple of the old node and a getitem index. - :param kwargs: kwargs that get directly forwarded to the auto_fn node - - Example: - If we want to replace this graph: - _, x1, x2 = auto_fn(op1) - _, y1, y2 = auto_fn(op2) - - with - _, x1, y2, x2 = auto_fn(FUSED_OP) - - we would call: - insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)} - - Note that the 0th element is None for auto-functionalized in-place ops. - Hence, others appear 1-indexed. - """ - fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs) - indices = fused_return_mapping.keys() - getitem_nodes = self.insert_getitems(fused_node, indices) - - # Prepare the meta value, use a list so it's mutable - meta_val = [None] * (max(indices) + 1) - - # Iterate through elements of the tuple produced by fused_node - for idx, getitem_node in zip(indices, getitem_nodes): - old_node, old_idx = fused_return_mapping[idx] - - # If the old value was never used, the old_getitem might not exist - old_getitem = find_getitem_maybe(old_node, old_idx) - if old_getitem is not None: - # Rebind the users of match getitem nodes to use the new nodes. - # The old nodes will be removed by DCE at the end of the pass. 
- old_getitem.replace_all_uses_with(getitem_node) - getitem_node.meta["val"] = old_getitem.meta["val"] - - # Extract the appropriate meta value - # It is present even if the getitem node does not exist - meta_val[idx] = old_node.meta["val"][old_idx] - - # Fix the meta value on the new fused node - fused_node.meta["val"] = tuple(meta_val) - - class RMSNormQuantPattern: def __init__(self, epsilon: float, key: FusedRMSQuantKey): @@ -224,8 +159,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): symmetric=symmetric)) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass, - record_match: Callable[[MultiOutputMatch], bool]): + def register(self, pm_pass: PatternMatcherPass): def pattern(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, @@ -271,36 +205,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match( - self.Match(m, self.QUANT_OP, self.FUSED_OP))) - - class Match(QuantMultiOutputMatch): - - def process(self): - # Find the nodes in the match that we need to rebind - rms_node = self.find_auto_fn(RMS_ADD_OP) - quant_node = self.find_auto_fn(self.QUANT_OP) - - assert len(rms_node.users) == 2 - assert len(quant_node.users) == 1 - - # First, insert a new auto_functionalized node for the fused op, - # as well as getitem nodes to extract the result and residual. - # The auto_fn node returns a tuple of (None, result, residual). - # - # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa - # result_node_new = at[1] - # residual_node_new = at[2] - with self.inserting_after_match(): - # Missing epsilon, scalars cannot be inputs to the pattern - kwargs = self.match.kwargs.copy() - - # 0 is always None - fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} - self.insert_fused_node(fused_return_mapping, - **kwargs, - epsilon=rms_node.kwargs["epsilon"]) + ) class RMSNormDynamicQuantPattern(RMSNormQuantPattern): @@ -317,8 +222,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): symmetric=symmetric)) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass, - record_match: Callable[[MultiOutputMatch], bool]): + def register(self, pm_pass: PatternMatcherPass): def pattern(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -366,39 +270,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match( - self.Match(m, self.QUANT_OP, self.FUSED_OP))) - - class Match(QuantMultiOutputMatch): - - def process(self): - # Find the nodes in the match that we need to rebind - rms_node = self.find_auto_fn(RMS_OP) - quant_node = self.find_auto_fn(self.QUANT_OP) - - assert len(rms_node.users) == 1 - assert len(quant_node.users) == 2 - - # First, insert a new auto_functionalized node for the fused op, - # as well as getitem nodes to extract the result and scale. - # The auto_fn node returns a tuple of (None, result, scale). - # - # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) 
# noqa - # result_node_new = at[1] - # scale_node_new = at[2] - with self.inserting_after_match(): - # Missing epsilon, scalars cannot be inputs to the pattern - kwargs = self.match.kwargs.copy() - del kwargs["result_rms"] # not used in the fused op - - fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)} - self.insert_fused_node( - fused_return_mapping, - epsilon=rms_node.kwargs["epsilon"], - scale_ub=None, # not used but required - residual=None, # not used but required - **kwargs) + ) class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): @@ -415,8 +287,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): symmetric=symmetric)) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass, - record_match: Callable[[MultiOutputMatch], bool]): + def register(self, pm_pass: PatternMatcherPass): def pattern(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, @@ -464,137 +335,49 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match( - self.Match(m, self.QUANT_OP, self.FUSED_OP))) - - class Match(QuantMultiOutputMatch): - - def process(self): - # Find the nodes in the match that we need to rebind - rms_node = self.find_auto_fn(RMS_ADD_OP) - quant_node = self.find_auto_fn(self.QUANT_OP) - - assert len(rms_node.users) == 2 - assert len(quant_node.users) == 2 - - # First, insert a new auto_functionalized node for the fused op, - # as well as getitem nodes to extract result, scale, and residual. - # The auto_fn node returns a tuple (None, result, scale, residual). - # - # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa - # result_node_new = at[1] - # scale_node_new = at[2] - # residual_node_new = at[3] - with self.inserting_after_match(): - # Missing epsilon, scalars cannot be inputs to the pattern - kwargs = self.match.kwargs.copy() - - fused_return_mapping = { - 1: (quant_node, 1), # result - 2: (quant_node, 2), # scale - 3: (rms_node, 2), # residual - } - self.insert_fused_node( - fused_return_mapping, - epsilon=rms_node.kwargs["epsilon"], - scale_ub=None, # not used but required - **kwargs) + ) -class FusionPass(VllmInductorPass): +class RMSNormQuantFusionPass(VllmPatternMatcherPass): """ - This pass fuses a pre-defined set of custom ops into fused ops. - It uses the torch pattern matcher to find the patterns and replace them. - It also manually processes multi-output matches, as those are broken in - the torch pattern matcher. - - Because patterns can only be registered once, the pass is a singleton. - This will be addressed in a future version of PyTorch: - https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. + It also supports fused_add_rms_norm. """ - _instance: 'Optional[FusionPass]' = None - - @classmethod - def instance(cls, config: VllmConfig): - """ - Get the singleton instance of the FusionPass. - If the instance exists, the config is updated but - initialization is not repeated. 
- """ - if cls._instance is None: - cls._instance = FusionPass(config) - else: - cls._instance.pass_config = config.compilation_config.pass_config - return cls._instance - @enable_fake_mode def __init__(self, config: VllmConfig): - assert self.__class__._instance is None, \ - "FusionPass singleton instance already exists" super().__init__(config) - self.matches: list[MultiOutputMatch] = [] self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="fusion_pass") + pass_name="rmsnorm_quant_fusion_pass") for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + static fp8 quant RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) - # Matches for patterns below have 2 or more outputs, - # so we need to process them manually (see process_matches) - - # Fuse rms_norm + static fp8 quant + # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns, self.record_match) + self.patterns) # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns, self.record_match) + RMSNormDynamicQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns, self.record_match) + self.patterns) - # WARNING: This is a hack to clear the pattern matcher cache - # and allow multiple values of epsilon. - torch._inductor.pattern_matcher._seen_patterns.clear() - - def record_match(self, match: MultiOutputMatch) -> bool: - # Hijack the extra_check to record the match and - # save it for post-processing. - self.matches.append(match) - - # Return False to prevent automatic replacement. - return False - - def process_matches(self, graph: fx.Graph): - """ - Manually process multi-output matches and replace them with fused nodes. - See MultiOutputMatch for more details. 
- """ - for match in self.matches: - match.process() - - # Finally, remove matched nodes - graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in self.matches - for node in match.match.nodes) + self.dump_patterns(config, self.patterns) + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): - self.begin() - self.dump_graph(graph, "before_fusion") + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", count) - self.dump_graph(graph, "after_pattern_match") - - # Manually process multi-output matches (and run DCE) - self.process_matches(graph) - logger.debug("Post-processed %s matches", len(self.matches)) - self.dump_graph(graph, "after_fusion") - self.matches.clear() - self.end_and_log() + def uuid(self) -> Any: + return self.hash_source(self, RMSNormQuantPattern, + RMSNormStaticQuantPattern, + RMSNormDynamicQuantPattern, + FusedAddRMSNormStaticQuantPattern, + FusedAddRMSNormDynamicQuantPattern) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index e3677b3dd62d8..2c6cf8f12fdc1 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -18,7 +18,7 @@ from vllm.utils import round_up from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .inductor_pass import enable_fake_mode -from .vllm_inductor_pass import VllmInductorPass +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -245,7 +245,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): pm_pass) -class AttnFusionPass(VllmInductorPass): +class AttnFusionPass(VllmPatternMatcherPass): """ This pass fuses post-attention quantization onto attention if supported. @@ -282,20 +282,12 @@ class AttnFusionPass(VllmInductorPass): "were found in CompilationConfig.static_forward_context " "so no fusion patterns were registered.") + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.graph.Graph) -> None: - self.begin() - self.dump_graph(graph, "before_attn_fusion") - - count = self.patterns.apply(graph) - - # TODO: Move this to pass_manager.py after the fx graph broken issue - # has been resolved. 
- # see https://github.com/vllm-project/vllm/issues/23091 - graph.eliminate_dead_code() - - logger.debug("Fused quantization onto %s attention nodes", count) - self.dump_graph(graph, "after_attn_fusion") - self.end_and_log() + self.matched_count = self.patterns.apply(graph) + logger.debug("Fused quant onto %s attention nodes", self.matched_count) def uuid(self): return VllmInductorPass.hash_source(self, AttentionQuantPattern, diff --git a/vllm/compilation/multi_output_match.py b/vllm/compilation/multi_output_match.py deleted file mode 100644 index 6d1893777cec6..0000000000000 --- a/vllm/compilation/multi_output_match.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import abc -import operator -from abc import abstractmethod -from collections.abc import Iterable - -from torch import fx -from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._inductor import pattern_matcher as pm -from torch._ops import OpOverload -from torch.fx import Node - -from vllm.compilation.fx_utils import find_auto_fn - - -class MultiOutputMatch(abc.ABC): - """ - This class provides utilities to process multi-output matches and - manually insert replacements. - - This is necessary because the automatic replacement for multi-output - matches is broken: https://github.com/pytorch/pytorch/issues/137280 - """ - - def __init__(self, match: pm.Match): - self.match = match - - @abstractmethod - def process(self): - """ - Process a multi-output match and manually insert the replacement. - - This method should: - 1. Insert the replacement nodes after the last node in the match. - 2. Rebind the users of nodes in the match to use the new nodes. - 3. Set meta["val"] for de-functionalization. - - The result of an auto-functionalized node is a tuple of tensors. - The first element is the return value of the function, usually None. - The remaining elements are the mutated args of the function. - - All auto-functionalized nodes must contain a proper meta["val"], - as it is used by de-functionalization. meta["val"] has to contain the - value of the node (tuple of tensors) that would be returned by the - functionalized node during tracing. - - Existing nodes in the graph all have this property set, but we have - to set it manually for new nodes we insert. - - Example: - # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None - at = auto_functionalized(torch.ops._C.foo.default, a, b, c) - # at.meta["val"] = (None, a, c) - """ - raise NotImplementedError - - @property - def nodes(self) -> list[fx.Node]: - return self.match.nodes - - @property - def graph(self) -> fx.Graph: - return self.match.graph - - def find_auto_fn(self, op) -> fx.Node: - """ - Find the first auto_functionalized node with the given op in the match. - """ - return find_auto_fn(self.nodes, op) - - def inserting_after_match(self): - """ - Insert nodes after the last node in the match. - This is done to avoid use-before-definition errors after inserting - replacement nodes. - """ - - # match.nodes is not guaranteed to be sorted. - # Find the last node in the match. 
- for last_node_in_match in reversed(self.graph.nodes): - if last_node_in_match in self.match.nodes: - break - else: - raise ValueError("No nodes in graph") - - return self.graph.inserting_after(last_node_in_match) - - def insert_getitems(self, tuple_node: fx.Node, - indices: Iterable[int]) -> tuple[fx.Node, ...]: - """ - Insert operator.getitem nodes to extract elements from a tuple node. - - :param tuple_node: The tuple node to extract elements from. - :param indices: The indices of the elements to extract. - :return: Tuple of the new getitem nodes, corresponding to the indices. - """ - with self.graph.inserting_after(tuple_node): - return tuple( - self.graph.call_function(operator.getitem, (tuple_node, idx)) - for idx in indices) - - def insert_auto_fn(self, op: OpOverload, kwargs) -> Node: - """ - Insert an auto_functionalized node with the given op and kwargs. - """ - return self.graph.call_function(auto_functionalized, (op, ), - kwargs=kwargs) diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py index 17e85e70218da..2c453daf873d2 100644 --- a/vllm/compilation/noop_elimination.py +++ b/vllm/compilation/noop_elimination.py @@ -64,9 +64,8 @@ class NoOpEliminationPass(VllmInductorPass): out: "f16[s0, 4096]" = at[1] """ + @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.Graph): - self.begin() - self.dump_graph(graph, "before_noop_elimination") count = 0 # Remove no-op reshapes/views: for node in graph.nodes: @@ -121,8 +120,6 @@ class NoOpEliminationPass(VllmInductorPass): count += 1 logger.debug("Removed %s no-op reshapes and slices", count) - self.dump_graph(graph, "after_noop_elimination") - self.end_and_log() # ---------------------- Reshape helpers ---------------------- def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node], diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 1b1cbe4fa12c2..e323fa1f77349 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -1,15 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools from torch import fx as fx +from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import set_env_var + +from .post_cleanup import PostCleanupPass +from .vllm_inductor_pass import VllmInductorPass if current_platform.is_cuda_alike(): from .activation_quant_fusion import ActivationQuantFusionPass - from .fusion import FusionPass + from .fusion import RMSNormQuantFusionPass from .fusion_attn import AttnFusionPass if current_platform.is_cuda(): @@ -19,11 +25,28 @@ from .fix_functionalization import FixFunctionalizationPass from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .noop_elimination import NoOpEliminationPass from .sequence_parallelism import SequenceParallelismPass -from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) +def with_pattern_match_debug(fn): + """ + Function decorator that turns on inductor pattern match debug + for the duration of the call. + Used to avoid logging builtin Inductor pattern matching. 
+ """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None: + # optionally check rank here + with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val): + return fn(*args, **kwargs) + return fn(*args, **kwargs) + + return wrapper + + class PostGradPassManager(CustomGraphPass): """ The pass manager for post-grad passes. @@ -40,16 +63,26 @@ class PostGradPassManager(CustomGraphPass): """ def __init__(self): - self.passes: list[VllmInductorPass] = [] + self.passes: list[InductorPass] = [] + @with_pattern_match_debug def __call__(self, graph: fx.Graph): + VllmInductorPass.dump_prefix = 0 # reset dump index + shape = get_pass_context().runtime_shape for pass_ in self.passes: if pass_.is_applicable_for_shape(shape): pass_(graph) + VllmInductorPass.dump_prefix += 1 + + # post-cleanup goes before fix_functionalization + # because it requires a functional graph + self.post_cleanup(graph) + VllmInductorPass.dump_prefix += 1 # always run fix_functionalization last self.fix_functionalization(graph) + VllmInductorPass.dump_prefix = None # Cleanup index def configure(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config @@ -61,14 +94,18 @@ class PostGradPassManager(CustomGraphPass): if self.pass_config.enable_async_tp: self.passes += [AsyncTPPass(config)] + if self.pass_config.enable_fi_allreduce_fusion: + self.passes += [AllReduceFusionPass(config)] + if self.pass_config.enable_fusion: - self.passes += [FusionPass.instance(config)] + self.passes += [RMSNormQuantFusionPass(config)] self.passes += [ActivationQuantFusionPass(config)] if self.pass_config.enable_attn_fusion: self.passes += [AttnFusionPass(config)] - if self.pass_config.enable_fi_allreduce_fusion: - self.passes += [AllReduceFusionPass(config)] + + # needs a functional graph + self.post_cleanup = PostCleanupPass(config) self.fix_functionalization = FixFunctionalizationPass(config) def add(self, pass_: InductorPass): diff --git a/vllm/compilation/post_cleanup.py b/vllm/compilation/post_cleanup.py new file mode 100644 index 0000000000000..6a31f3935da7c --- /dev/null +++ b/vllm/compilation/post_cleanup.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from torch import fx + +from vllm.compilation.vllm_inductor_pass import VllmInductorPass + + +class PostCleanupPass(VllmInductorPass): + """ + This pass performs cleanup after custom passes. + It topologically sorts the graph and removes unused nodes. + This is needed because the pattern matcher does not guarantee producing + a topologically sorted graph, and there may be unused nodes left around. 
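`with_pattern_match_debug` above enables `TORCHINDUCTOR_PATTERN_MATCH_DEBUG` only while the vLLM pass manager runs, so that Inductor's builtin pattern matching is not also logged. A self-contained sketch of the same env-var-scoping pattern; the `set_env_var` helper below is a stand-in for the one imported from `vllm.utils`:

import functools
import os
from contextlib import contextmanager
from typing import Optional

@contextmanager
def set_env_var(name: str, value: str):
    # Stand-in for vllm.utils.set_env_var: set the variable, restore it on exit.
    old = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = old

def with_debug_env(debug_value: Optional[str]):
    """Turn on TORCHINDUCTOR_PATTERN_MATCH_DEBUG only for the wrapped call."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if debug_value is not None:
                with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG",
                                 debug_value):
                    return fn(*args, **kwargs)
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@with_debug_env(os.environ.get("VLLM_PATTERN_MATCH_DEBUG"))
def run_passes(graph):
    ...   # pattern-matcher passes would run here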
+ """ + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph) -> None: + from torch._inductor.pattern_matcher import stable_topological_sort + stable_topological_sort(graph) + graph.eliminate_dead_code() diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 1758ed4c86d27..a6ca50c925a2a 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -15,7 +15,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode -from .vllm_inductor_pass import VllmInductorPass +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -417,7 +417,7 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): pm.fwd_only, pm_pass) -class SequenceParallelismPass(VllmInductorPass): +class SequenceParallelismPass(VllmPatternMatcherPass): """ This pass enables sequence parallelism for models. It identifies patterns where an AllReduce operation is followed by @@ -466,19 +466,13 @@ class SequenceParallelismPass(VllmInductorPass): LastAllReduceRMSNormPattern(epsilon, self.model_dtype, self.device).register(self.patterns) - - # WARNING: This is a hack to clear the pattern matcher cache - # and allow multiple values of epsilon. - torch._inductor.pattern_matcher._seen_patterns.clear() + self.dump_patterns(config, self.patterns) def is_applicable_for_shape(self, shape: Optional[int]) -> bool: tp_size = get_tensor_model_parallel_world_size() return shape is not None and shape % tp_size == 0 + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): - self.begin() - self.dump_graph(graph, "before_sequence_parallelism_pass") - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns with sequence parallelism", count) - self.dump_graph(graph, "after_sequence_parallelism_pass") - self.end_and_log() + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index b822b05b0f1ec..837770d181993 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import functools +import operator import time +from pathlib import Path +from typing import ClassVar, Optional +import regex as re import torch from torch._dynamo.utils import lazy_format_graph_code +from torch._inductor.pattern_matcher import (PatternMatcherPass, + PatternPrettyPrinter) from vllm.config import VllmConfig from vllm.logger import init_logger @@ -19,6 +25,8 @@ class VllmInductorPass(InductorPass): An inductor pass with access to vLLM PassConfig. It provides timing, logging, and dumping utilities. 
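`VllmInductorPass.time_and_log` above folds the repeated begin / dump_graph / end_and_log calls of each pass's `__call__` into a single decorator. A rough standalone equivalent of the wrapping (graph dumping omitted, names simplified):

import functools
import time

def time_and_log(call_fn):
    """Wrap a pass __call__ to log how long the pass took."""
    @functools.wraps(call_fn)
    def wrapped(self, graph):
        start = time.perf_counter_ns()
        call_fn(self, graph)
        duration_ms = (time.perf_counter_ns() - start) / 1e6
        print(f"{type(self).__name__} completed in {duration_ms:.1f} ms")
    return wrapped

class ToyPass:
    @time_and_log
    def __call__(self, graph):
        pass   # graph transformation would happen here

ToyPass()(graph=None)   # prints "ToyPass completed in ... ms"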
""" + dump_prefix: ClassVar[Optional[int]] = None + """Keep track of pass index for debug dump ordering.""" def __init__(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config @@ -28,8 +36,24 @@ class VllmInductorPass(InductorPass): else None self.pass_name = self.__class__.__name__ + @staticmethod + def time_and_log(call_fn): + + @functools.wraps(call_fn) + def wrapped(self: VllmInductorPass, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before") + call_fn(self, graph) + self.dump_graph(graph, "after") + self.end_and_log() + + return wrapped + def dump_graph(self, graph: torch.fx.Graph, stage: str): - lazy_format_graph_code(stage, graph.owning_module) + i = VllmInductorPass.dump_prefix + i_str = "" if i is None else f".{i}" + lazy_format_graph_code(f"post_grad{i_str}.{self.pass_name}.{stage}", + graph.owning_module) def begin(self): self._start_time = time.perf_counter_ns() @@ -40,6 +64,88 @@ class VllmInductorPass(InductorPass): logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) +class VllmPatternMatcherPass(VllmInductorPass): + """ + A VllmInductorPass that uses the Inductor pattern matcher. + Its main use is providing the dump_patterns utility that dumps the + Inductor pattern matcher patterns into a file, which greatly aids debugging. + + TODO(luka) move more utilities to this pass. + """ + matched_count: int = 0 + """The number of matched patterns in the pass.""" + + _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile( + r"") + + def _replace_op_overloads(self, string: str) -> str: + """Replace with nicer formulations""" + return self._OP_OVERLOAD_PATTERN.sub( + lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}", + string, + ) + + def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): + """ + If debug dumping is enabled, dump the Inductor pattern-matcher patterns + into the debug_dump_path folder next to the dumped fx graphs. + + This method does its best to print something that looks like Python code + for easier debugging and potentially navigation. If any errors appear in + the output, please add to this method. + + TODO(luka): use pattern object to manually produce pattern graph + """ + debug_dump_path = config.compilation_config.debug_dump_path + if not debug_dump_path: + return + + rank = config.parallel_config.rank + debug_dump_path = Path(debug_dump_path) / f"rank_{rank}" + debug_dump_path.mkdir(parents=True, exist_ok=True) + + from vllm.utils import unique_filepath + file_path = unique_filepath( + lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py") + + with file_path.open("w") as f: + print( + f'# This file was produced by VllmPatternMatcherPass.' 
+ f'dump_patterns for {self.pass_name}.\n' + f'# It does its best to produce valid-Python-looking code but' + f' please add to dump_patterns if there are any errors.\n\n' + f'from torch._higher_order_ops.auto_functionalize import ' + f'auto_functionalized as auto_functionalized\n' + f'from torch._inductor.pattern_matcher import *', + file=f) + + for node, patterns in pm_pass.patterns.items(): + # fix the operator.getitem repr + if node[1] == operator.getitem: + node_repr = f"({repr(node[0])}, operator.getitem)" + else: + node_repr = repr(node) + + node_repr = self._replace_op_overloads(node_repr) + + print(f"\n\n# Patterns for op: {node_repr}", file=f) + for i, pattern in enumerate(patterns): + # reserve auto_functionalized ahead of time + pp = PatternPrettyPrinter() + pp.namespace.create_name("auto_functionalized", None) + + # Assemble pattern + out_node = pp.pretty_print(pattern.pattern) + pattern_repr = "\n".join([f"def pattern_{i}():"] + [ + f"{pp.memoized_objs_names[key]} = " + f"{pp.memoized_objs_pp[key]}" + for key in pp.memoized_objs_names + ] + [f"return {out_node}"]).replace("\n", "\n ") + + pattern_repr = self._replace_op_overloads(pattern_repr) + print(f"{pattern_repr}\n", file=f) + + class PrinterInductorPass(VllmInductorPass): def __init__(self, name: str, config: VllmConfig): diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index e31a78ba33baa..92fc68f8927ca 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -503,7 +503,7 @@ class VllmConfig: if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") - if current_platform.is_cuda_alike() or current_platform.is_xpu(): + if current_platform.support_static_graph_mode(): # if cudagraph_mode is not explicitly set by users, set default # value if self.compilation_config.cudagraph_mode is None: @@ -905,10 +905,9 @@ def set_current_vllm_config(vllm_config: VllmConfig, except Exception: raise else: - logger.debug("enabled custom ops: %s", - vllm_config.compilation_config.enabled_custom_ops) - logger.debug("disabled custom ops: %s", - vllm_config.compilation_config.disabled_custom_ops) + if check_compile: + vllm_config.compilation_config.custom_op_log_check() + if check_compile and \ vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ and compilation_counter.num_models_seen == num_models_seen: diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 22b38daf46c39..34fa7fcfe7e87 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -487,6 +487,12 @@ class CompilationConfig: "supported with torch>=2.9.0.dev. 
Set " "use_inductor_graph_partition=False instead.") + for op in self.custom_ops: + if op[0] not in {'+', '-'} and op not in {'all', 'none'}: + raise ValueError(f"Invalid syntax '{op}' for custom op, " + "must be 'all', 'none', '+op' or '-op' " + "(where 'op' is the registered op name)") + def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: raise ValueError("No compilation level is set.") @@ -532,8 +538,8 @@ class CompilationConfig: for x in self.compile_sizes: if isinstance(x, str): assert x == "cudagraph_capture_sizes", \ - "Unrecognized size type in compile_sizes, " \ - f"expect 'cudagraph_capture_sizes', got {x}" + "Unrecognized size type in compile_sizes, " \ + f"expect 'cudagraph_capture_sizes', got {x}" computed_compile_sizes.extend(self.cudagraph_capture_sizes) else: assert isinstance(x, int) @@ -628,3 +634,41 @@ class CompilationConfig: return use_fx_graph_piecewise_compilation or \ use_inductor_piecewise_compilation + + def custom_op_log_check(self): + """ + This method logs the enabled/disabled custom ops and checks that the + passed custom_ops field only contains relevant ops. + It is called at the end of set_current_vllm_config, + after the custom ops have been instantiated. + """ + + if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0: + logger.debug("No custom ops found in model.") + return + + logger.debug("enabled custom ops: %s", self.enabled_custom_ops) + logger.debug("disabled custom ops: %s", self.disabled_custom_ops) + + all_ops_in_model = (self.enabled_custom_ops | self.disabled_custom_ops) + for op in self.custom_ops: + if op in {"all", "none"}: + continue + + assert op[0] in {'+', '-'}, "Invalid custom op syntax " \ + "(should be checked during init)" + + # check if op name exists in model + op_name = op[1:] + if op_name not in all_ops_in_model: + from vllm.model_executor.custom_op import CustomOp + + # Does op exist at all or is it just not present in this model? + # Note: Only imported op classes appear in the registry. + missing_str = "doesn't exist (or wasn't imported/registered)" \ + if op_name not in CustomOp.op_registry \ + else "not present in model" + + enable_str = "enabling" if op[0] == '+' else "disabling" + logger.warning_once("Op '%s' %s, %s with '%s' has no effect", + op_name, missing_str, enable_str, op) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 37a41bf6de71a..a84d882430166 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib +import os from dataclasses import field from typing import TYPE_CHECKING, Any, Literal, Optional, Union @@ -351,6 +352,10 @@ class ParallelConfig: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size + if self.distributed_executor_backend == "external_launcher": + logger.info("Using external launcher for distributed inference.") + self.world_size *= self.data_parallel_size + if self.data_parallel_size_local > self.data_parallel_size: raise ValueError( f"data_parallel_size_local ({self.data_parallel_size_local}) " @@ -358,6 +363,13 @@ class ParallelConfig: if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. 
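The new `CompilationConfig` validation above only accepts `custom_ops` entries of the form `all`, `none`, `+op`, or `-op`; `custom_op_log_check` then warns when a requested op never shows up in the model. A small illustration of the accepted syntax (op names used for illustration only):

def validate_custom_ops(custom_ops: list[str]) -> None:
    """Accept only 'all', 'none', '+op' or '-op' entries."""
    for op in custom_ops:
        if op[0] not in {'+', '-'} and op not in {'all', 'none'}:
            raise ValueError(f"Invalid syntax '{op}' for custom op, "
                             "must be 'all', 'none', '+op' or '-op' "
                             "(where 'op' is the registered op name)")

validate_custom_ops(["none", "+rms_norm", "-silu_and_mul"])   # passes
try:
    validate_custom_ops(["rms_norm"])                         # missing '+'/'-'
except ValueError as e:
    print(e)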
+ if self.distributed_executor_backend == "external_launcher": + # For external launcher, + # we need to set the data parallel rank automatically + self.data_parallel_rank = int(os.environ["RANK"]) \ + // (self.world_size // self.data_parallel_size) + logger.info("Set data_parallel_rank to %d automatically.", + self.data_parallel_rank) if not self._data_parallel_master_port_list: self._data_parallel_master_port_list = get_open_ports_list(5) self.data_parallel_master_port = \ @@ -380,7 +392,6 @@ class ParallelConfig: "be set when data_parallel_size > 1") if self.distributed_executor_backend == "external_launcher": - import os os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" logger.info("Disabling V1 multiprocessing for external launcher.") diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 2c861723c3966..d533930e1c7aa 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -527,7 +527,7 @@ class SpeculativeConfig: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") - eagle3_target_supported = ["llama", "qwen"] + eagle3_target_supported = ["llama", "qwen", "gpt_oss"] if self.method == "eagle3" and self.target_model_config and not any( supported_model in self.target_model_config.hf_text_config.model_type diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index 067315deb773d..b236bae261e03 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -25,6 +25,12 @@ class XpuCommunicator(DeviceCommunicatorBase): super().__init__(cpu_group, device, device_group, unique_name) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if all2all_backend != "naive": + logger.warning( + "`%s` all2all manager is not supported on XPU." + "Falling back to `naive` all2all manager for XPU.", + all2all_backend) + all2all_backend = "naive" if all2all_backend == "naive": from .all2all import NaiveAll2AllManager self.all2all_manager = NaiveAll2AllManager(self.cpu_group) @@ -67,3 +73,16 @@ class XpuCommunicator(DeviceCommunicatorBase): def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: dist.broadcast(input_, src=src, group=self.device_group) + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert self.all2all_manager is not None + hidden_states, router_logits = self.all2all_manager.dispatch( + hidden_states, router_logits) + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + assert self.all2all_manager is not None + hidden_states = self.all2all_manager.combine(hidden_states) + return hidden_states diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index d3a08af088c11..64feddb591c27 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -58,6 +58,12 @@ except ImportError: logger.warning("NIXL is not available") NixlWrapper = None +try: + from nixl._api import nixl_agent_config +except ImportError: + nixl_agent_config = None + logger.warning("NIXL agent config is not available") + # Supported platforms and types of kv transfer buffer. 
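For the external launcher, `ParallelConfig` above multiplies `world_size` by `data_parallel_size` and derives each process's DP rank from the global `RANK` environment variable: every DP group spans `world_size // data_parallel_size` consecutive ranks. A worked example with hypothetical sizes (TP=2, PP=1, DP=4):

import os

# Hypothetical sizes: tp=2, pp=1, dp=4 -> adjusted world_size = 2 * 1 * 4 = 8
tensor_parallel_size, pipeline_parallel_size, data_parallel_size = 2, 1, 4
world_size = (tensor_parallel_size * pipeline_parallel_size *
              data_parallel_size)

ranks_per_dp_group = world_size // data_parallel_size   # 2
for rank in range(world_size):
    os.environ["RANK"] = str(rank)
    data_parallel_rank = int(os.environ["RANK"]) // ranks_per_dp_group
    print(f"global rank {rank} -> data_parallel_rank {data_parallel_rank}")
# ranks 0-1 -> dp 0, ranks 2-3 -> dp 1, ranks 4-5 -> dp 2, ranks 6-7 -> dp 3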
# {device: tuple of supported kv buffer types} _NIXL_SUPPORTED_DEVICE = { @@ -65,6 +71,8 @@ _NIXL_SUPPORTED_DEVICE = { "tpu": ("cpu", ), "xpu": ("cpu", ), } +# support for oot platform by providing mapping in current_platform +_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) class NixlAgentMetadata( @@ -242,6 +250,10 @@ class NixlConnector(KVConnectorBase_V1): self.connector_worker.copy_blocks: self.connector_worker.save_kv_to_host(self._connector_metadata) + def shutdown(self): + if self.connector_worker is not None: + self.connector_worker.shutdown() + class NixlConnectorScheduler: """Implementation of Scheduler side methods""" @@ -448,8 +460,15 @@ class NixlConnectorWorker: self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size + self.nixl_backends = \ + vllm_config.kv_transfer_config.get_from_extra_config( + "backends", ["UCX"]) # Agent. - self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] + config = nixl_agent_config(backends=self.nixl_backends) if len( + non_ucx_backends) > 0 and nixl_agent_config is not None else None + + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) @@ -486,11 +505,15 @@ class NixlConnectorWorker: # used when device memory can not be registered under nixl self.host_xfer_buffers: dict[str, torch.Tensor] = {} self.use_host_buffer = self.kv_buffer_device == "cpu" - if self.kv_buffer_device == "cuda": - self.nixl_memory_type = "VRAM" - elif self.kv_buffer_device == "cpu": - self.nixl_memory_type = "DRAM" - else: + # support for oot platform which can't register nixl memory + # type based on kv_buffer_device + self.nixl_memory_type = current_platform.get_nixl_memory_type() + if self.nixl_memory_type is None: + if self.kv_buffer_device == "cuda": + self.nixl_memory_type = "VRAM" + elif self.kv_buffer_device == "cpu": + self.nixl_memory_type = "DRAM" + if self.nixl_memory_type is None: raise RuntimeError( f"{self.device_type} with {self.kv_buffer_device} kv_buffer " "is not supported.") @@ -567,13 +590,6 @@ class NixlConnectorWorker: self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.xfer_stats = NixlKVConnectorStats() - def __del__(self): - """Cleanup background threads on destruction.""" - if executor := getattr(self, "_handshake_initiation_executor", None): - executor.shutdown(wait=False) - if listener_t := getattr(self, "_nixl_handshake_listener_t", None): - listener_t.join(timeout=0) - @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, ready_event: threading.Event, base_port: int, @@ -766,7 +782,7 @@ class NixlConnectorWorker: descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) - self.nixl_wrapper.register_memory(descs) + self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends) logger.debug("Done registering descs") self._registered_descs.append(descs) @@ -1327,6 +1343,30 @@ class NixlConnectorWorker: return self.xfer_stats.clone_and_reset() return None + def shutdown(self): + """Shutdown the connector worker.""" + self._handshake_initiation_executor.shutdown(wait=False) + if self._nixl_handshake_listener_t is not None: + self._nixl_handshake_listener_t.join(timeout=0) + self._nixl_handshake_listener_t = None + for handles in 
self._recving_transfers.values(): + for handle, _ in handles: + self.nixl_wrapper.release_xfer_handle(handle) + self._recving_transfers.clear() + if self.src_xfer_side_handle: + self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle) + self.src_xfer_side_handle = 0 + for dst_xfer_side_handle in self.dst_xfer_side_handles.values(): + self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) + self.dst_xfer_side_handles.clear() + for remote_agents in self._remote_agents.values(): + for agent_name in remote_agents.values(): + self.nixl_wrapper.remove_remote_agent(agent_name) + self._remote_agents.clear() + for desc in self._registered_descs: + self.nixl_wrapper.deregister_memory(desc) + self._registered_descs.clear() + @contextlib.contextmanager def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index ec72905a0d3ec..3dadfa595ef1e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -178,6 +178,9 @@ class P2pNcclConnector(KVConnectorBase_V1): # Load the KV for each request each layer for request in metadata.requests: + request_id = request.request_id + ip, port = self.parse_request_id(request_id, False) + remote_address = ip + ":" + str(port + self._rank) for layer_name in forward_context.no_compile_layers: layer = forward_context.no_compile_layers[layer_name] @@ -191,7 +194,7 @@ class P2pNcclConnector(KVConnectorBase_V1): layer = kv_cache[forward_context.virtual_engine] kv_cache = self.p2p_nccl_engine.recv_tensor( - request.request_id + "#" + layer_name) + request.request_id + "#" + layer_name, remote_address) if kv_cache is None: logger.warning("🚧kv_cache is None, %s", request.request_id) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index fa7cc66ab654d..959bf0277a3f5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -134,7 +134,6 @@ class P2pNcclEngine: # PUT or PUT_ASYNC # tensor_id: torch.Tensor self.send_queue: deque[SendQueueItem] = deque() - self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} if self.send_type == "PUT_ASYNC": self._send_thread = threading.Thread(target=self.send_async, daemon=True) @@ -143,6 +142,7 @@ class P2pNcclEngine: # tensor_id: torch.Tensor/(addr, dtype, shape) self.recv_store: dict[str, Any] = {} self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {} + self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} self.socks: dict[str, Any] = {} # remote_address: client socket self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank) @@ -223,18 +223,26 @@ class P2pNcclEngine: # GET with self.send_store_cv: tensor_size = tensor.element_size() * tensor.numel() + if tensor_size > self.buffer_size_threshold: + logger.warning( + "❗[GET]tensor_id:%s, tensor_size:%d, is greater than" + "buffer size threshold :%d, skip send to %s, rank:%d", + tensor_id, tensor_size, self.buffer_size_threshold, + remote_address, self.rank) + return False while (self.buffer_size + tensor_size > self.buffer_size_threshold): - oldest_tenser_id = next(iter(self.send_store)) - oldest_tenser = self.send_store.pop(oldest_tenser_id) - oldest_tenser_size = 
oldest_tenser.element_size( - ) * oldest_tenser.numel() - self.buffer_size -= oldest_tenser_size - logger.info( + assert len(self.send_store) > 0 + oldest_tensor_id = next(iter(self.send_store)) + oldest_tensor = self.send_store.pop(oldest_tensor_id) + oldest_tensor_size = oldest_tensor.element_size( + ) * oldest_tensor.numel() + self.buffer_size -= oldest_tensor_size + logger.debug( "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," - " buffer_size:%d, oldest_tenser_size:%d, rank:%d", + " buffer_size:%d, oldest_tensor_size:%d, rank:%d", remote_address, tensor_id, tensor_size, self.buffer_size, - oldest_tenser_size, self.rank) + oldest_tensor_size, self.rank) self.send_store[tensor_id] = tensor self.buffer_size += tensor_size diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 12571afaa4c13..895971893a661 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1032,7 +1032,9 @@ def init_distributed_environment(world_size: int = -1, distributed_init_method, backend) from vllm.config import get_current_vllm_config config = get_current_vllm_config() - if config is not None and config.parallel_config.data_parallel_size > 1: + if config is not None and config.parallel_config.data_parallel_size > 1 \ + and config.parallel_config.distributed_executor_backend \ + != "external_launcher": parallel_config = config.parallel_config # adjust to take into account data parallelism # offset the rank by the data parallel rank diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b09d43f705580..8c7a1b413cdb7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1147,20 +1147,15 @@ class EngineArgs: else: envs.set_vllm_use_v1(use_v1) - # Set default arguments for V0 or V1 Engine. - if use_v1: - self._set_default_args_v1(usage_context, model_config) - # Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1 - if current_platform.is_cpu( - ) and current_platform.get_cpu_architecture() in ( - CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM): - logger.info( - "Chunked prefill is not supported for ARM and POWER " - "and S390X CPUs; " - "disabling it for V1 backend.") - self.enable_chunked_prefill = False - else: - self._set_default_args_v0(model_config) + # Set default arguments for V1 Engine. + self._set_default_args(usage_context, model_config) + # Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1 + if current_platform.is_cpu() and current_platform.get_cpu_architecture( + ) in (CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM): + logger.info("Chunked prefill is not supported for ARM and POWER " + "and S390X CPUs; " + "disabling it for V1 backend.") + self.enable_chunked_prefill = False assert self.enable_chunked_prefill is not None sliding_window: Optional[int] = None @@ -1494,6 +1489,7 @@ class EngineArgs: "FLEX_ATTENTION", "TREE_ATTN", "XFORMERS_VLLM_V1", + "ROCM_ATTN_VLLM_V1", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): @@ -1501,12 +1497,6 @@ class EngineArgs: _raise_or_fallback(feature_name=name, recommend_to_remove=True) return False - # Platforms must decide if they can support v1 for this model - if not current_platform.supports_v1(model_config=model_config): - _raise_or_fallback( - feature_name=f"device type={current_platform.device_type}", - recommend_to_remove=False) - return False ############################################################# # Experimental Features - allow users to opt in. 
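The p2p_nccl_engine change above first rejects tensors larger than the whole buffer threshold, then evicts the oldest `send_store` entries until the new tensor fits. A standalone sketch of that insertion-ordered, byte-budgeted eviction (plain bytes stand in for tensors; a regular dict preserves insertion order, mirroring `send_store`):

class SendBuffer:
    """FIFO-evicting, byte-budgeted store, sketching the send_store logic."""

    def __init__(self, threshold_bytes: int):
        self.threshold = threshold_bytes
        self.size = 0
        self.store: dict[str, bytes] = {}   # insertion-ordered

    def put(self, tensor_id: str, payload: bytes) -> bool:
        if len(payload) > self.threshold:
            # A single item larger than the whole budget is skipped outright.
            return False
        while self.size + len(payload) > self.threshold:
            oldest_id = next(iter(self.store))        # oldest inserted key
            self.size -= len(self.store.pop(oldest_id))
        self.store[tensor_id] = payload
        self.size += len(payload)
        return True

buf = SendBuffer(threshold_bytes=10)
buf.put("a", b"12345")
buf.put("b", b"12345")
buf.put("c", b"123")      # evicts "a" to make room
print(list(buf.store))    # ['b', 'c']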
@@ -1523,12 +1513,6 @@ class EngineArgs: recommend_to_remove=False) return False - # The platform may be supported on V1, but off by default for now. - if not current_platform.default_v1( # noqa: SIM103 - model_config=model_config) and _warn_or_fallback( - current_platform.device_name): - return False - if (current_platform.is_cpu() and model_config.get_sliding_window() is not None): _raise_or_fallback(feature_name="sliding window (CPU backend)", @@ -1539,64 +1523,8 @@ class EngineArgs: return True - def _set_default_args_v0(self, model_config: ModelConfig) -> None: - """Set Default Arguments for V0 Engine.""" - - max_model_len = model_config.max_model_len - use_long_context = max_model_len > 32768 - if self.enable_chunked_prefill is None: - # Chunked prefill not supported for Multimodal or MLA in V0. - if model_config.is_multimodal_model or model_config.use_mla: - self.enable_chunked_prefill = False - - # Enable chunked prefill by default for long context (> 32K) - # models to avoid OOM errors in initial memory profiling phase. - elif use_long_context: - is_gpu = current_platform.is_cuda() - use_sliding_window = (model_config.get_sliding_window() - is not None) - use_spec_decode = self.speculative_config is not None - - if (is_gpu and not use_sliding_window and not use_spec_decode - and not self.enable_lora): - self.enable_chunked_prefill = True - logger.warning( - "Chunked prefill is enabled by default for models " - "with max_model_len > 32K. Chunked prefill might " - "not work with some features or models. If you " - "encounter any issues, please disable by launching " - "with --enable-chunked-prefill=False.") - - if self.enable_chunked_prefill is None: - self.enable_chunked_prefill = False - - if not self.enable_chunked_prefill and use_long_context: - logger.warning( - "The model has a long context length (%s). This may cause" - "OOM during the initial memory profiling phase, or result " - "in low performance due to small KV cache size. Consider " - "setting --max-model-len to a smaller value.", max_model_len) - - # Disable prefix caching for multimodal models for VLLM_V0. - if self.enable_prefix_caching and model_config.is_multimodal_model: - logger.warning( - "--enable-prefix-caching is not supported for multimodal " - "models in V0 and has been disabled.") - self.enable_prefix_caching = False - - if self.enable_prompt_embeds: - logger.warning( - "--enable-prompt-embeds and --enable-prefix-caching " - "are not supported together in V0. Prefix caching has " - "been disabled.") - self.enable_prefix_caching = False - - # Set max_num_seqs to 256 for VLLM_V0. - if self.max_num_seqs is None: - self.max_num_seqs = 256 - - def _set_default_args_v1(self, usage_context: UsageContext, - model_config: ModelConfig) -> None: + def _set_default_args(self, usage_context: UsageContext, + model_config: ModelConfig) -> None: """Set Default Arguments for V1 Engine.""" # V1 always uses chunked prefills and prefix caching @@ -1795,21 +1723,6 @@ def _raise_or_fallback(feature_name: str, recommend_to_remove: bool): logger.warning(msg) -def _warn_or_fallback(feature_name: str) -> bool: - if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: - logger.warning( - "Detected VLLM_USE_V1=1 with %s. Usage should " - "be considered experimental. Please report any " - "issues on Github.", feature_name) - should_exit = False - else: - logger.info( - "%s is experimental on VLLM_USE_V1=1. 
" - "Falling back to V0 Engine.", feature_name) - should_exit = True - return should_exit - - def human_readable_int(value): """Parse human-readable integers like '1k', '2M', etc. Including decimal values with decimal multipliers. diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index 8619452f2445f..ea81fdbcd825e 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -8,6 +8,7 @@ from abc import ABC, abstractmethod from contextlib import AsyncExitStack from typing import TYPE_CHECKING, Optional, Union +from openai.types.responses.tool import Mcp from openai_harmony import Author, Message, Role, StreamState, TextContent from vllm.entrypoints.harmony_utils import ( @@ -21,6 +22,24 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# This is currently needed as the tool type doesn't 1:1 match the +# tool namespace, which is what is used to look up the +# connection to the tool server +_TOOL_NAME_TO_TYPE_MAP = { + "browser": "web_search_preview", + "python": "code_interpreter", + "container": "container", +} + + +def _map_tool_name_to_tool_type(tool_name: str) -> str: + if tool_name not in _TOOL_NAME_TO_TYPE_MAP: + available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys()) + raise ValueError( + f"Built-in tool name '{tool_name}' not defined in mapping. " + f"Available tools: {available_tools}") + return _TOOL_NAME_TO_TYPE_MAP[tool_name] + class TurnTokens: """Tracks token counts for a single conversation turn.""" @@ -59,8 +78,8 @@ class ConversationContext(ABC): @abstractmethod async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack, - request_id: str) -> None: + exit_stack: AsyncExitStack, request_id: str, + mcp_tools: dict[str, Mcp]) -> None: pass @abstractmethod @@ -96,8 +115,8 @@ class SimpleContext(ConversationContext): raise NotImplementedError("Should not be called.") async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack, - request_id: str) -> None: + exit_stack: AsyncExitStack, request_id: str, + mcp_tools: dict[str, Mcp]) -> None: pass async def cleanup_session(self) -> None: @@ -318,13 +337,17 @@ class HarmonyContext(ConversationContext): ] async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack, - request_id: str) -> None: + exit_stack: AsyncExitStack, request_id: str, + mcp_tools: dict[str, Mcp]): if tool_server: for tool_name in self.available_tools: if tool_name not in self._tool_sessions: + tool_type = _map_tool_name_to_tool_type(tool_name) + headers = mcp_tools[ + tool_type].headers if tool_type in mcp_tools else None tool_session = await exit_stack.enter_async_context( - tool_server.new_session(tool_name, request_id)) + tool_server.new_session(tool_name, request_id, + headers)) self._tool_sessions[tool_name] = tool_session exit_stack.push_async_exit(self.cleanup_session) diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index 1364a41be950d..57e4bb1e1da52 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -126,8 +126,10 @@ def get_developer_message( function_tools: list[Union[Tool, ChatCompletionToolsParam]] = [] for tool in tools: if tool.type in ("web_search_preview", "code_interpreter", - "container"): + "container", "mcp"): # These are built-in tools that are added to the system message. 
+ # Adding in MCP for now until we support MCP tools executed + # server side pass elif tool.type == "function": diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 092d3f276d1c5..c41f44aa47187 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1468,7 +1468,7 @@ class LLM: def _validate_and_add_requests( self, - prompts: Union[PromptType, Sequence[PromptType]], + prompts: Union[PromptType, Sequence[PromptType], DataPrompt], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], *, @@ -1478,7 +1478,7 @@ class LLM: ) -> None: if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. - prompts = [prompts] + prompts = [prompts] # type: ignore[list-item] num_requests = len(prompts) if isinstance(params, Sequence) and len(params) != num_requests: diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 6e243671af242..99bb464db1d13 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -460,8 +460,12 @@ class OpenAIServingResponses(OpenAIServing): async with AsyncExitStack() as exit_stack: try: + mcp_tools = { + tool.server_label: tool + for tool in request.tools if tool.type == "mcp" + } await context.init_tool_sessions(self.tool_server, exit_stack, - request.request_id) + request.request_id, mcp_tools) async for _ in result_generator: pass except asyncio.CancelledError: @@ -748,11 +752,16 @@ class OpenAIServingResponses(OpenAIServing): # New conversation. reasoning_effort = (request.reasoning.effort if request.reasoning else None) - # Temporary: OpenAI types doesn't have container tool - # so we used MCP to cover that, up for change tool_types = [tool.type for tool in request.tools] - if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL: - tool_types.append("container") + + # Allow the MCP Tool type to enable built in tools if the + # server_label is allowlisted in + # envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS + if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS: + for tool in request.tools: + if (tool.type == "mcp" and tool.server_label + in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS): + tool_types.append(tool.server_label) enable_browser = ("web_search_preview" in tool_types and self.tool_server is not None and self.tool_server.has_tool("browser")) @@ -1653,8 +1662,12 @@ class OpenAIServingResponses(OpenAIServing): async with AsyncExitStack() as exit_stack: processer = None if self.use_harmony: + mcp_tools = { + tool.server_label: tool + for tool in request.tools if tool.type == "mcp" + } await context.init_tool_sessions(self.tool_server, exit_stack, - request.request_id) + request.request_id, mcp_tools) processer = self._process_harmony_streaming_events else: processer = self._process_simple_streaming_events diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 35096b0461361..5e77c406b8d92 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -20,6 +20,7 @@ from .openai_tool_parser import OpenAIToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser from .qwen3coder_tool_parser import Qwen3CoderToolParser +from .qwen3xml_tool_parser import Qwen3XMLToolParser from .seed_oss_tool_parser import SeedOssToolParser from .step3_tool_parser import Step3ToolParser from .xlam_tool_parser import xLAMToolParser @@ -45,6 +46,7 @@ __all__ = [ 
"HunyuanA13BToolParser", "Glm4MoeModelToolParser", "Qwen3CoderToolParser", + "Qwen3XMLToolParser", "SeedOssToolParser", "Step3ToolParser", "OpenAIToolParser", diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index e74c420da1d3c..87595953da067 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -368,16 +368,32 @@ class Hermes2ProToolParser(ToolParser): # case -- we now have the first info about arguments available from # autocompleting the JSON elif cur_arguments and not prev_arguments: + # extract the content after {"name": ..., "arguments": + # directly from tool_call_portion as cur_arguments_json, + # since cur_arguments may differ from the original text + # due to partial JSON parsing + # for example, tool_call_portion = + # {"name": "search", "arguments": {"search_request": {" + # but cur_arguments = + # {"search_request": {}} + function_name = current_tool_call.get("name") + match = re.search( + r'\{"name":\s*"' + + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)', + tool_call_portion.strip(), re.DOTALL) + if match: + cur_arguments_json = match.group(1) + else: + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False) logger.debug("finding %s in %s", delta_text, cur_arguments_json) - # get the location where previous args differ from current - if (delta_text not in cur_arguments_json[:-2]): + # get the location where previous args differ from current. + if (delta_text not in cur_arguments_json): return None - args_delta_start_loc = cur_arguments_json[:-2]. \ + args_delta_start_loc = cur_arguments_json. \ rindex(delta_text) + \ len(delta_text) @@ -397,8 +413,20 @@ class Hermes2ProToolParser(ToolParser): # last case -- we have an update to existing arguments. 
        elif cur_arguments and prev_arguments:
-            if isinstance(delta_text, str) and len(delta_text.rstrip(
-            )) >= 1 and delta_text.rstrip()[-1] == '}':
+            # check whether the tool_call_portion is a complete JSON
+            try:
+                json.loads(tool_call_portion)
+                is_complete_json = True
+            except Exception:
+                is_complete_json = False
+
+            # if the delta_text ends with a '}' and tool_call_portion is a
+            # complete JSON, then the last '}' does not belong to the
+            # arguments, so we should trim it off
+            if isinstance(delta_text, str) \
+                and len(delta_text.rstrip()) >= 1 \
+                and delta_text.rstrip()[-1] == '}' \
+                and is_complete_json:
                delta_text = delta_text.rstrip()[:-1]

            logger.debug("got diff %s", delta_text)
diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py
new file mode 100644
index 0000000000000..4ab67dfea104c
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py
@@ -0,0 +1,1137 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import ast
+import json
+import uuid
+from collections.abc import Sequence
+from typing import Any, Optional, Union
+from xml.parsers.expat import ParserCreate
+
+import regex as re
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              ChatCompletionToolsParam,
+                                              DeltaFunctionCall, DeltaMessage,
+                                              DeltaToolCall,
+                                              ExtractedToolCallInformation,
+                                              FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+logger = init_logger(__name__)
+
+
+class StreamingXMLToolCallParser:
+    """
+    Simplified streaming XML tool call parser
+    Supports streaming input, parsing, and output
+    """
+
+    def __init__(self):
+        self.reset_streaming_state()
+
+        # Tool configuration information
+        self.tools: Union[list[ChatCompletionToolsParam], None] = None
+        self.tool_call_start_token: str = '<tool_call>'
+        self.tool_call_end_token: str = '</tool_call>'
+        self.function_start_token: str = '<function'
+        self.function_end_token: str = '</function>'
+        self.parameter_start_token: str = '<parameter'
+        self.parameter_end_token: str = '</parameter>'
+
+    def reset_streaming_state(self):
+        """Reset all streaming parsing state"""
+        # Streaming buffer and deltas emitted so far
+        self.streaming_buffer: str = ''
+        self.last_processed_pos: int = 0
+        self.text_content_buffer: str = ''
+        self.deltas: list[DeltaMessage] = []
+        self.tool_call_index: int = 0
+        # State of the tool_call currently being parsed
+        self.current_call_id: Optional[str] = None
+        self.last_completed_call_id: Optional[str] = None
+        self.current_function_name: Optional[str] = None
+        self.current_function_open: bool = False
+        self.parameters: dict[str, Any] = {}
+        self.current_param_name: Optional[str] = None
+        self.current_param_value: str = ''
+        self.current_param_value_converted: str = ''
+        self.current_param_is_first: bool = False
+        self.should_emit_end_newline: bool = False
+        self.start_quote_emitted: bool = False
+        # Preprocessing and deferred-parsing state
+        self._pre_inside_parameter: bool = False
+        self._pre_param_buffer: str = ''
+        self._pre_current_param_name: Optional[str] = None
+        self.defer_current_parameter: bool = False
+        self.deferred_param_raw_value: str = ''
+        # Expat parser
+        self.parser = ParserCreate()
+        self.setup_parser()
+
+    def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage:
+        """
+        Parse single streaming XML chunk and return Delta response
+        This is the actual streaming interface that receives chunks
+        one by one and maintains internal state
+
+        Args:
+            xml_chunk: Single XML chunk string
+        Returns:
+            DeltaMessage: Contains delta information generated by this chunk,
+            returns empty response if no complete elements
+        """
+        # Record delta count before processing
+        initial_delta_count = len(self.deltas)
+
+        self.streaming_buffer += xml_chunk
+
+        found_elements = self._process_complete_xml_elements()
+
+        if found_elements:
+            # If complete elements found, check if end events were missed
+            # some tags may not have been triggered
+            try:
+                new_deltas = self.deltas[initial_delta_count:]
+                # If this chunk contains </function>
+                # but didn't generate '}', then complete it
+                if (self.current_call_id is not None
+                        and self.function_end_token in xml_chunk):
+                    # A function close should already have produced either:
+                    # - Added '}' (non-empty parameter ending)
+                    # - Added '{}' (empty parameter function)
+                    has_function_close = any((td.tool_calls and any(
+                        (tc.function and tc.id == self.current_call_id
+                         and isinstance(tc.function.arguments, str) and
+                         (tc.function.arguments in ('}', '{}')))
+                        for tc in td.tool_calls)) for td in new_deltas)
+                    if not has_function_close:
+                        # Close potentially unclosed element
+                        if self.current_param_name:
+                            self._end_element('parameter')
+                        if self.current_function_name:
+                            self._end_element('function')
+                # If this chunk contains </tool_call>
+                # but didn't generate the final empty delta, then complete it
+                if (self.current_call_id is not None
+                        and self.tool_call_end_token in xml_chunk):
+                    has_toolcall_close = any((td.tool_calls and any(
+                        (tc.type == 'function' and tc.function and tc.function.
+                         arguments == '' and tc.id == self.current_call_id)
+                        for tc in td.tool_calls)) for td in new_deltas)
+                    if not has_toolcall_close:
+                        # Close potentially unclosed element
+                        if self.current_param_name:
+                            self._end_element('parameter')
+                        if self.current_function_name:
+                            self._end_element('function')
+                        self._end_element('tool_call')
+            except Exception as e:
+                logger.warning("Error with fallback parsing: %s", e)
+            # Merge newly generated deltas into single response
+            result_delta = self._merge_new_deltas_to_single_response(
+                initial_delta_count)
+            return result_delta
+        else:
+            # No complete elements; check for text content not yet output
+            if self.text_content_buffer and self.tool_call_index == 0:
+                # Has text content but no tool_call yet, output text content
+                text_delta = DeltaMessage(content=self.text_content_buffer)
+                self._emit_delta(text_delta)
+                # Clear buffer to avoid duplicate output
+                self.text_content_buffer = ''
+                return text_delta
+
+            # If this chunk contains end tags but wasn't triggered by parser,
+            # manually complete end events
+            # Only execute when still on the same call as when entered,
+            # to prevent accidentally closing new calls
+            # in multi <tool_call> scenarios
+            if (self.current_call_id is not None
+                    and (self.function_end_token in xml_chunk
+                         or self.tool_call_end_token in xml_chunk)):
+                # Close potentially unclosed element
+                if self.current_param_name:
+                    self._end_element('parameter')
+                if self.function_end_token in xml_chunk and \
+                        self.current_function_name:
+                    self._end_element('function')
+                if self.tool_call_end_token in xml_chunk:
+                    self._end_element('tool_call')
+                # Return the merged delta result generated by this fallback
+                result_delta = self._merge_new_deltas_to_single_response(
+                    initial_delta_count)
+                return result_delta
+
+        # No complete elements, return empty response
+        return DeltaMessage(content=None)
+
+    def _escape_xml_special_chars(self, text: str) -> str:
+        """
+        Escape XML special characters
+        Args:
+            text: Original text
+        Returns:
+            Escaped text
+        """
+        xml_escapes = {
+            '&': '&amp;',
+            '<': '&lt;',
+            '>': '&gt;',
+            '"': '&quot;',
+            "'": '&apos;'
+        }
+
+        for char, escape in xml_escapes.items():
+            text = text.replace(char, escape)
+
+        return text
+
+    def _process_complete_xml_elements(self) -> bool:
+        """
+        Process complete XML elements in buffer
+
+        Returns:
+            bool: Whether complete elements were found and processed
+        """
+        found_any = False
+
+        while self.last_processed_pos < len(self.streaming_buffer):
+            # Find next complete xml element
+            element, end_pos = self._find_next_complete_element(
+                self.last_processed_pos)
+            if element is None:
+                # No complete element found, wait for more data
+                break
+
+            # Check if this element should be skipped
+            if self._should_skip_element(element):
+                self.last_processed_pos = end_pos
+                continue
+
+            # Found complete XML element, process it
+            try:
+                preprocessed_element = self._preprocess_xml_chunk(element)
+                # Check if this is the start of a new tool_call while a
+                # previous one is still open
+                if ((preprocessed_element.strip().startswith('<tool_call>')
+                     or preprocessed_element.strip().startswith('<function'))
+                        and self.tool_call_index > 0 and self.current_call_id):
+                    # Reset parser state but preserve generated deltas
+                    if self.current_param_name:
+                        self._end_element('parameter')
+                    if self.current_function_open or \
+                            self.current_function_name:
+                        self._end_element('function')
+                    # Output final tool_call tail delta
+                    final_delta = DeltaMessage(
+                        role=None,
+                        content=None,
+                        reasoning_content=None,
+                        tool_calls=[
+                            DeltaToolCall(index=self.tool_call_index - 1,
+                                          id=self.current_call_id,
+                                          type='function',
+                                          function=DeltaFunctionCall(
+                                              name=None, arguments=''))
+                        ])
+                    self._emit_delta(final_delta)
+                    # Reset XML parser and current call state
+                    self._reset_xml_parser_after_tool_call()
+                # Parse preprocessed element
+                self.parser.Parse(preprocessed_element, False)
+                found_any = True
+
+            except Exception as e:
+                logger.warning("Error when parsing XML elements: %s", e)
+
+            # Update processed position
+            self.last_processed_pos = end_pos
+
+        return found_any
+
+    def _should_skip_element(self, element: str) -> bool:
+        """
+        Determine whether an element should be skipped
+
+        Args:
+            element: Element to evaluate
+
+        Returns:
+            bool: True means should skip, False means should process
+        """
+
+        # If it's a tool_call XML tag, don't skip
+        if element.startswith(
+                self.tool_call_start_token) or element.startswith(
+                    self.function_start_token) or element.startswith(
+                        self.parameter_start_token):
+            return False
+
+        # If currently not parsing tool calls and not blank,
+        # collect this text instead of skipping
+        # Only process other XML elements after tool_call appears,
+        # otherwise treat as plain text
+        if self.current_call_id is None and element:
+            # Collect text content to buffer
+            self.text_content_buffer += element
+            return True  # Still skip, but content has been collected
+
+        # If currently parsing tool calls,
+        # this might be parameter value, don't skip
+        if self.current_call_id is not None:
+            return False
+
+        # Skip blank content
+        return not element
+
+    def _find_next_complete_element(
+            self, start_pos: int) -> tuple[Optional[str], int]:
+        """
+        Find next complete XML element from specified position
+
+        Args:
+            start_pos: Position to start searching
+
+        Returns:
+            (Complete element string, element end position),
+            returns (None, start_pos) if no complete element found
+        """
+        buffer = self.streaming_buffer[start_pos:]
+
+        if not buffer:
+            return None, start_pos
+
+        if buffer.startswith('<'):
+            # Need to ensure no new < appears,
+            # find the nearest one between < and >
+            tag_end = buffer.find('<', 1)
+            tag_end2 = buffer.find('>', 1)
+            if tag_end != -1 and tag_end2 != -1:
+                # Next nearest is <
+                if tag_end < tag_end2:
+                    return buffer[:tag_end], start_pos + tag_end
+                # Next nearest is >, means found XML element
+                else:
+                    return buffer[:tag_end2 + 1], start_pos + tag_end2 + 1
+            elif tag_end != -1:
+                return buffer[:tag_end], start_pos + tag_end
+            elif tag_end2 != -1:
+                return buffer[:tag_end2 + 1], start_pos + tag_end2 + 1
+            else:
+                # If currently not parsing tool calls (entering a tool_call),
+                # check if it starts with <tool_call>
+                if self.current_call_id is None:
+                    # Check if it might be the start of <tool_call>
+                    if buffer == '<tool_call>'[:len(buffer)]:
+                        # Might be the start of <tool_call>,
+                        # wait for more data
+                        return None, start_pos
+                    else:
+                        # Not the start of <tool_call>, treat as text
+                        return buffer, start_pos + len(buffer)
+                else:
+                    # When parsing tool calls,
+                    # wait for more data to get complete tag
+                    return None, start_pos
+        else:
+            # Find text content (until next < or buffer end)
+            next_tag_pos = buffer.find('<')
+            if next_tag_pos != -1:
+                # Found text content
+                text_content = buffer[:next_tag_pos]
+                return text_content, start_pos + next_tag_pos
+            else:
+                # Buffer end is all text, process it
+                # (no longer wait for more data)
+                remaining = buffer
+                return remaining, start_pos + len(remaining)
+
+    def _merge_new_deltas_to_single_response(
+            self, initial_count: int) -> DeltaMessage:
+        """
+        Merge newly generated deltas from this processing
+        into a single DeltaMessage
+
+        Args:
+            initial_count: Delta count before processing
+
+        Returns:
+            Merged DeltaMessage containing all newly generated delta information
+        """
+        if len(self.deltas) <= initial_count:
+            return DeltaMessage(content=None)
+
+        # Get newly generated deltas
+        new_deltas = self.deltas[initial_count:]
+
+        if len(new_deltas) == 1:
+            # Only one new delta, return directly
+            return new_deltas[0]
+
+        # Merge multiple new deltas
+        merged_tool_calls: list[DeltaToolCall] = []
+        merged_content: str = ''
+
+        for delta in new_deltas:
+            if delta.content:
+                merged_content += delta.content
+            if delta.tool_calls:
+                # For tool_calls, we need to intelligently merge arguments
+                for tool_call in delta.tool_calls:
+                    # Find if there's already a tool_call with the same call_id
+                    existing_call = None
+                    for existing in merged_tool_calls:
+                        if existing.id == tool_call.id:
+                            existing_call = existing
+                            break
+
+                    if existing_call and existing_call.function:
+                        # Merge to existing tool_call
+                        if tool_call.function and tool_call.function.name:
+                            existing_call.function.name = \
+                                tool_call.function.name
+                        if tool_call.function \
+                                and tool_call.function.arguments is not None:
+                            if existing_call.function.arguments is None:
+                                existing_call.function.arguments = ''
+
+                            # For streaming JSON parameters,
+                            # simply concatenate in order
+                            new_args = tool_call.function.arguments
+                            existing_call.function.arguments += new_args
+                        if tool_call.type:
+                            existing_call.type = tool_call.type
+                    else:
+                        # Add new tool_call
+                        merged_tool_calls.append(tool_call)
+
+        return DeltaMessage(content=merged_content if merged_content else None,
+                            tool_calls=merged_tool_calls)
+
+    def _preprocess_xml_chunk(self, chunk: str) -> str:
+        """
+        Preprocess XML chunk, handle non-standard formats,
+        and escape special characters
+
+        Args:
+            chunk: Original XML chunk
+
+        Returns:
+            Processed XML chunk
+        """
+
+        # Check if this is a tool_call related element
+        is_tool_call = False
+        if chunk.startswith(self.tool_call_start_token) or chunk.startswith(
+                self.tool_call_end_token):
+            is_tool_call = True
+        if chunk.startswith(self.function_start_token) or chunk.startswith(
+                self.function_end_token):
+            is_tool_call = True
+        if chunk.startswith(self.parameter_start_token) or chunk.startswith(
+                self.parameter_end_token):
+            is_tool_call = True
+        # Handle <function=xxx> format -> <function name="xxx">
+        processed = re.sub(r'<function=([^>]+)>', r'<function name="\1">',
+                           chunk)
+        # Handle <parameter=xxx> format -> <parameter name="xxx">
+        processed = re.sub(r'<parameter=([^>]+)>', r'<parameter name="\1">',
+                           processed)
+
+        original_chunk = chunk
+        # If in parameter value accumulation mode
+        if self._pre_inside_parameter:
+            # Parameter end: output accumulated raw text
+            # safely then return
+            if processed.startswith('</parameter>'):
+                body_text = self._pre_param_buffer
+                # Trigger deferred parsing mode
+                # literal_eval+json output in end_element
+                self.defer_current_parameter = True
+                self.deferred_param_raw_value = body_text
+                # Clean up state
+                self._pre_inside_parameter = False
+                self._pre_param_buffer = ""
+                self._pre_current_param_name = None
+                safe_text = self._escape_xml_special_chars(body_text)
+                return f"{safe_text}</parameter>"
+            else:
+                # If this is the first block of content after entering the
+                # parameter, evaluate if deferred parsing is needed;
+                # If not needed, exit accumulation mode
+                # and pass through directly
+                if self._pre_param_buffer == "":
+                    # Get current parameter type
+                    param_type = self._get_param_type(
+                        self._pre_current_param_name
+                    ) if self._pre_current_param_name else 'string'
+                    # Only these types need deferred parsing to
+                    # handle Python literals containing single quotes
+                    is_object_type = param_type in ["object"]
+                    is_complex_type = (param_type
+                                       in ["array", "arr", "sequence"]
+                                       or param_type.startswith("dict")
+                                       or param_type.startswith("list"))
+
+                    # Only delay when contains container symbols
+                    # and has single quotes and is complex type
+                    has_container_hint = ('[' in original_chunk) or (
+                        '{' in original_chunk) or ('(' in original_chunk)
+
+                    # Determine if deferred parsing is needed
+                    need_defer = False
+                    if is_complex_type:
+                        # Complex type, always need deferred parsing
+                        need_defer = True
+                    elif is_object_type and has_container_hint and (
+                            "'" in original_chunk):
+                        # Object type with container symbols
+                        # and single quotes, need deferred parsing
+                        need_defer = True
+
+                    if not need_defer:
+                        # No need for deferred parsing,
+                        # exit parameter mode directly
+                        self._pre_inside_parameter = False
+                        return self._escape_xml_special_chars(original_chunk)
+                self._pre_param_buffer += original_chunk
+                return ""
+
+        # Parameter start: enable accumulation
+        if processed.startswith('<parameter '):
+            m = re.search(r'<parameter\s+name="([^"]*)"', processed)
+            if m:
+                self._pre_current_param_name = m.group(1)
+            self._pre_inside_parameter = True
+            self._pre_param_buffer = ""
+            return processed
+
+        # If the chunk is not a tool_call-related tag, escape it;
+        # otherwise the XML parser errors out on special characters
+        if not is_tool_call:
+            processed = self._escape_xml_special_chars(processed)
+        return processed
+
+    def _emit_delta(self, delta: DeltaMessage):
+        """Emit Delta response (streaming output)"""
+        self.deltas.append(delta)
+
+    def _auto_close_open_parameter_if_needed(self,
+                                             incoming_tag: Optional[str] = None
+                                             ):
+        """Before starting to process new elements,
+        if there are unclosed tags from before,
+        automatically complete their endings to the parser.
+        - If there are unclosed parameters,
+            it's equivalent to feeding `</parameter>`
+        - When about to start a new function or tool_call,
+            if there are unclosed functions, complete `</function>`.
+        - When about to start a new tool_call,
+            if there are unclosed tool_calls, complete `</tool_call>`.
+ """ + # First close unclosed parameters + if self.current_param_name: + self._end_element('parameter') + + # If about to start new function or tool_call, + # and there are unclosed functions, close function first + if incoming_tag in ('function', + 'tool_call') and self.current_function_name: + self._end_element('function') + + # If about to start new tool_call, + # and there are unclosed tool_calls, close tool_call first + if incoming_tag == 'tool_call' and self.current_call_id: + self._end_element('tool_call') + + def _start_element(self, name: str, attrs: dict[str, str]): + """Handle XML start element events""" + + if name == 'root': + return + + if name == 'tool_call': + # Before opening new tool_call, + # automatically complete previous unclosed tags + self._auto_close_open_parameter_if_needed('tool_call') + + self.parameters = {} + self.current_call_id = self._get_next_call_id() + self.current_param_is_first = True + self.tool_call_index += 1 + elif name.startswith('function') or (name == 'function'): + # If missing tool_call, manually complete + if not self.current_call_id: + self._start_element('tool_call', {}) + # Before opening new function, + # automatically complete previous unclosed tags (parameter/function) + self._auto_close_open_parameter_if_needed('function') + function_name = self._extract_function_name(name, attrs) + self.current_function_name = function_name + self.current_function_open = True + if function_name: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.tool_call_index - 1, + id=self.current_call_id, + type='function', + function=DeltaFunctionCall( + name=function_name, arguments='')) + ]) + self._emit_delta(delta) + elif name.startswith('parameter') or (name == 'parameter'): + # If previous parameter hasn't ended normally, + # complete its end first, then start new parameter + self._auto_close_open_parameter_if_needed('parameter') + param_name = self._extract_parameter_name(name, attrs) + self.current_param_name = param_name + self.current_param_value = '' + self.current_param_value_converted = '' + self.start_quote_emitted = False # Reset start quote flag + + # Only output parameter name and colon, + # don't output quotes + # decide after parameter value type is determined + if param_name: + if not self.parameters: + # First parameter + # start JSON, only output parameter name and colon + json_start = f'{{"{param_name}": ' + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.tool_call_index - 1, + id=self.current_call_id, + type='function', + function=DeltaFunctionCall( + name=None, arguments=json_start)) + ]) + self._emit_delta(delta) + self.current_param_is_first = True + else: + # Subsequent parameters + # add comma and parameter name, no quotes + json_continue = f', "{param_name}": ' + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.tool_call_index - 1, + id=self.current_call_id, + type='function', + function=DeltaFunctionCall( + name=None, arguments=json_continue)) + ]) + self._emit_delta(delta) + self.current_param_is_first = False + + def _char_data(self, data: str): + """Handle XML character data events""" + if data and self.current_param_name: + # If preprocessing stage determines deferred parsing is needed, + # only cache character data, no streaming output + if self.defer_current_parameter: + original_data = data + if self.should_emit_end_newline: + original_data = '\n' + original_data + self.should_emit_end_newline = False + if original_data.endswith('\n'): + self.should_emit_end_newline = True + original_data 
= original_data[:-1] + self.current_param_value += original_data + return + + param_type = self._get_param_type(self.current_param_name) + + # Check if this is the first time receiving data for this parameter + # If this is the first packet of data and starts with \n, remove \n + if not self.current_param_value and data.startswith('\n'): + data = data[1:] + + # Output start quote for string type (if not already output) + if (param_type + in ['string', 'str', 'text', 'varchar', 'char', 'enum'] + and not self.start_quote_emitted): + quote_delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.tool_call_index - 1, + id=self.current_call_id, + type='function', + function=DeltaFunctionCall(name=None, + arguments='"')) + ]) + self._emit_delta(quote_delta) + self.start_quote_emitted = True + + if not data: + return + + original_data = data + # Delay output of trailing newline + if self.should_emit_end_newline: + original_data = '\n' + original_data + self.should_emit_end_newline = False + if original_data.endswith('\n'): + self.should_emit_end_newline = True + original_data = original_data[:-1] + self.current_param_value += original_data + + # convert parameter value by param_type + converted_value = self._convert_param_value( + self.current_param_value, param_type) + output_data = self._convert_for_json_streaming( + converted_value, param_type) + + delta_data = output_data[len(self.current_param_value_converted):] + self.current_param_value_converted = output_data + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.tool_call_index - 1, + id=self.current_call_id, + type='function', + function=DeltaFunctionCall(name=None, + arguments=delta_data)) + ]) + self._emit_delta(delta) + + def _end_element(self, name: str): + """Handle XML end element events""" + + if name == 'root': + return + + # If function or tool_call ends and there are still unclosed parameters, + # complete parameter end first + if (name.startswith('function') or name == 'function' + or name == 'tool_call') and self.current_param_name: + self._auto_close_open_parameter_if_needed() + + if (name.startswith('parameter') + or name == 'parameter') and self.current_param_name: + # End current parameter + param_name = self.current_param_name + param_value = self.current_param_value + + # If in deferred parsing mode, + # perform overall parsing on raw content + # accumulated in preprocessing stage and output once + if self.defer_current_parameter: + raw_text = self.deferred_param_raw_value \ + if self.deferred_param_raw_value else param_value + parsed_value = None + output_arguments = None + try: + # If previously delayed trailing newline, + # add it back before parsing + if self.should_emit_end_newline: + raw_for_parse = raw_text + '\n' + else: + raw_for_parse = raw_text + parsed_value = ast.literal_eval(raw_for_parse) + output_arguments = json.dumps(parsed_value, + ensure_ascii=False) + except Exception: + # Fallback: output as string as-is + output_arguments = json.dumps(raw_text, ensure_ascii=False) + parsed_value = raw_text + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.tool_call_index - 1, + id=self.current_call_id, + type='function', + function=DeltaFunctionCall( + name=None, arguments=output_arguments)) + ]) + self._emit_delta(delta) + + # Clean up and store + self.should_emit_end_newline = False + self.parameters[param_name] = parsed_value + self.current_param_name = None + self.current_param_value = "" + self.current_param_value_converted = "" + self.start_quote_emitted = False + 
self.defer_current_parameter = False + self.deferred_param_raw_value = "" + return + + param_type = self._get_param_type(param_name) + + # convert complete parameter value by param_type + converted_value = self._convert_param_value( + param_value, param_type) + + # Decide whether to add end quote based on parameter type + if param_type in [ + 'string', 'str', 'text', 'varchar', 'char', 'enum' + ]: + # For empty string parameters, need special handling + if not param_value and not self.start_quote_emitted: + # No start quote output, + # directly output complete empty string + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.tool_call_index - 1, + id=self.current_call_id, + type='function', + function=DeltaFunctionCall( + name=None, arguments='""')) + ]) + self._emit_delta(delta) + else: + # Non-empty parameter value, output end quote + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.tool_call_index - 1, + id=self.current_call_id, + type='function', + function=DeltaFunctionCall( + name=None, arguments='"')) + ]) + self._emit_delta(delta) + + self.should_emit_end_newline = False + # Store converted value + self.parameters[param_name] = converted_value + self.current_param_name = None + self.current_param_value = '' + self.current_param_value_converted = '' + self.start_quote_emitted = False + + elif name.startswith('function') or name == 'function': + # if there are parameters, close JSON object + if self.parameters: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.tool_call_index - 1, + id=self.current_call_id, + type='function', + function=DeltaFunctionCall(name=None, + arguments='}')) + ]) + self._emit_delta(delta) + # return empty object + else: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.tool_call_index - 1, + id=self.current_call_id, + type='function', + function=DeltaFunctionCall(name=None, + arguments='{}')) + ]) + self._emit_delta(delta) + self.current_function_open = False + + elif name == 'tool_call': + # Before ending tool_call, + # ensure function is closed to complete missing right brace + if self.current_function_open: + # If there are still unclosed parameters, close them first + if self.current_param_name: + self._end_element('parameter') + # Close function, ensure output '}' or '{}' + self._end_element('function') + # Final Delta + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.tool_call_index - 1, + id=self.current_call_id, + type='function', + function=DeltaFunctionCall(name=None, + arguments='')) + ]) + self._emit_delta(delta) + + # Check if there's text content to output (between tool_calls) + if self.text_content_buffer.strip(): + text_delta = DeltaMessage(content=self.text_content_buffer) + self._emit_delta(text_delta) + + self._reset_xml_parser_after_tool_call() + + def setup_parser(self): + """Set up XML parser event handlers""" + self.parser.buffer_text = True + self.parser.StartElementHandler = self._start_element + self.parser.EndElementHandler = self._end_element + self.parser.CharacterDataHandler = self._char_data + + def set_tools(self, tools: Union[list[ChatCompletionToolsParam], None]): + """Set tool configuration information""" + self.tools = tools + + def _get_next_call_id(self): + """Generate unique call ID""" + return f'call_{uuid.uuid4().hex[:24]}' + + def _extract_function_name(self, name: str, + attrs: dict[str, str]) -> Optional[str]: + """Extract function name from various formats""" + if attrs and 'name' in attrs: + return attrs['name'] + + if '=' in name: + parts = 
name.split('=', 1)
+            if len(parts) == 2 and parts[0] == 'function':
+                return parts[1]
+
+        return None
+
+    def _extract_parameter_name(self, name: str,
+                                attrs: dict[str, str]) -> Optional[str]:
+        """Extract parameter name from various formats"""
+        if attrs and 'name' in attrs:
+            return attrs['name']
+
+        if '=' in name:
+            parts = name.split('=', 1)
+            if len(parts) == 2 and parts[0] == 'parameter':
+                return parts[1]
+
+        return None
+
+    def _get_param_type(self, param_name: str) -> str:
+        """Get parameter type based on tool configuration, defaults to string
+        Args:
+            param_name: Parameter name
+
+        Returns:
+            Parameter type
+        """
+        if not self.tools or not self.current_function_name:
+            return 'string'
+
+        for tool in self.tools:
+            if not hasattr(tool, 'type') or not (hasattr(
+                    tool, 'function') and hasattr(tool.function, 'name')):
+                continue
+            if tool.type == 'function' and \
+                    tool.function.name == self.current_function_name:
+                if not hasattr(tool.function, 'parameters'):
+                    return 'string'
+                params = tool.function.parameters
+                if isinstance(params, dict) and 'properties' in params:
+                    properties = params['properties']
+                    if param_name in properties and isinstance(
+                            properties[param_name], dict):
+                        return self.repair_param_type(
+                            str(properties[param_name].get('type', 'string')))
+                elif isinstance(params, dict) and param_name in params:
+                    param_config = params[param_name]
+                    if isinstance(param_config, dict):
+                        return self.repair_param_type(
+                            str(param_config.get('type', 'string')))
+                break
+        return 'string'
+
+    def repair_param_type(self, param_type: str) -> str:
+        """Repair unknown parameter types by treating them as string
+        Args:
+            param_type: Parameter type
+
+        Returns:
+            Repaired parameter type
+        """
+        if param_type in [
+                'string', 'str', 'text', 'varchar', 'char', 'enum'
+        ] or param_type.startswith('int') or param_type.startswith(
+                'uint'
+        ) or param_type.startswith('long') or param_type.startswith(
+                'short'
+        ) or param_type.startswith('unsigned') or param_type.startswith(
+                'num') or param_type.startswith('float') or param_type in [
+                    'boolean', 'bool', 'binary'
+                ] or (param_type in ["object", "array", "arr", "sequence"]
+                      or param_type.startswith("dict")
+                      or param_type.startswith("list")):
+            return param_type
+        else:
+            return 'string'
+
+    def _convert_param_value(self, param_value: str, param_type: str) -> Any:
+        """Convert value based on parameter type
+        Args:
+            param_value: Parameter value
+            param_type: Parameter type
+
+        Returns:
+            Converted value
+        """
+        if param_value.lower() == 'null':
+            return None
+
+        param_type = param_type.strip().lower()
+        if param_type in ['string', 'str', 'text', 'varchar', 'char', 'enum']:
+            return param_value
+        elif (param_type.startswith('int') or param_type.startswith('uint')
+              or param_type.startswith('long')
+              or param_type.startswith('short')
+              or param_type.startswith('unsigned')):
+            try:
+                return int(param_value)
+            except (ValueError, TypeError):
+                logger.warning(
+                    "Parsed value '%s' of parameter '%s' is not an integer "
+                    "in tool '%s', degenerating to string.", param_value,
+                    self.current_param_name, self.current_function_name)
+                return param_value
+        elif param_type.startswith('num') or param_type.startswith('float'):
+            try:
+                float_param_value: float = float(param_value)
+                return float_param_value if float_param_value - int(
+                    float_param_value) != 0 else int(float_param_value)
+            except (ValueError, TypeError):
+                logger.warning(
+                    "Parsed value '%s' of parameter '%s' is not a float "
+                    "in tool '%s', degenerating to string.", param_value,
+                    self.current_param_name, self.current_function_name)
+                return param_value
+        elif param_type in 
['boolean', 'bool', 'binary']: + param_value = param_value.lower() + return param_value == 'true' + else: + return param_value + + def _convert_for_json_streaming(self, converted_value: Any, + param_type: str) -> str: + """Convert converted_value based on + whether it's empty and if type is string + Args: + converted_value: Converted value + param_type: Parameter type + + Returns: + Converted string for streaming output + """ + # Check if value is empty, but exclude numeric 0 + if converted_value is None or converted_value == '': + return '' + + if param_type in ['string', 'str', 'text', 'varchar', 'char', 'enum']: + # String type, remove double quotes + return json.dumps(converted_value, ensure_ascii=False)[1:-1] + else: + # Non-string type, return complete JSON string + if not isinstance(converted_value, str): + return json.dumps(converted_value, ensure_ascii=False) + else: + return converted_value + + def _reset_xml_parser_after_tool_call(self): + """ + Each tool_call is treated as a separate XML document, + so we need to reset the parser after each tool_call. + """ + + # recreate XML parser + self.parser = ParserCreate() + self.setup_parser() + + # Reset current tool_call state + if self.current_call_id: + self.last_completed_call_id = self.current_call_id + self.current_call_id = None + self.current_function_name = None + self.current_function_open = False + self.parameters = {} + self.current_param_name = None + self.current_param_value = '' + self.current_param_value_converted = '' + self.current_param_is_first = False + self.should_emit_end_newline = False + self.start_quote_emitted = False + self.text_content_buffer = '' + + # Reset preprocessing and deferred parsing state + self._pre_inside_parameter = False + self._pre_param_buffer = "" + self._pre_current_param_name = None + self.defer_current_parameter = False + self.deferred_param_raw_value = "" + + +@ToolParserManager.register_module("qwen3_xml") +class Qwen3XMLToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.parser = StreamingXMLToolCallParser() + + logger.info("vLLM Successfully import tool parser %s !", + self.__class__.__name__) + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + self.parser.reset_streaming_state() + if request: + self.parser.set_tools(request.tools) + result = self.parser.parse_single_streaming_chunks(model_output) + if not result.tool_calls: + return ExtractedToolCallInformation( + tool_calls=[], + tools_called=False, + content=result.content, + ) + else: + tool_calls = [] + for tool_call in result.tool_calls: + if tool_call.function and tool_call.function.name: + tool_calls.append( + ToolCall( + id=tool_call.id, + type=tool_call.type, + function=FunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + )) + return ExtractedToolCallInformation( + tool_calls=tool_calls, + tools_called=len(tool_calls) > 0, + content=result.content, + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + if not previous_text: + self.parser.reset_streaming_state() + if request: + self.parser.set_tools(request.tools) + + # Model sometimes outputs separately causing delta_text to be empty. 
+ # If there were tool_calls before and all current tool_calls have ended, + # return an empty tool_call for outer streaming output + # to correctly output tool_call field + if not delta_text and delta_token_ids: + open_calls = current_text.count( + self.parser.tool_call_start_token) - current_text.count( + self.parser.tool_call_end_token) + if open_calls == 0 and self.parser.tool_call_index > 0: + # If current_call_id is None, use last_completed_call_id + call_id = self.parser.current_call_id or \ + self.parser.last_completed_call_id + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.parser.tool_call_index - 1, + id=call_id, + function=DeltaFunctionCall(arguments=''), + type='function', + ) + ]) + + return self.parser.parse_single_streaming_chunks(delta_text) diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py index fb859d57be9fe..d7ce57c728ba6 100644 --- a/vllm/entrypoints/renderer.py +++ b/vllm/entrypoints/renderer.py @@ -280,7 +280,7 @@ class CompletionRenderer(BaseRenderer): if truncate_prompt_tokens < 0: truncate_prompt_tokens = self.model_config.max_model_len - if max_length is not None and truncate_prompt_tokens > max_length: + if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator] raise ValueError( f"truncate_prompt_tokens ({truncate_prompt_tokens}) " f"cannot be greater than max_length ({max_length}). " diff --git a/vllm/entrypoints/tool_server.py b/vllm/entrypoints/tool_server.py index 056a571fb2fd1..4c627b865ef92 100644 --- a/vllm/entrypoints/tool_server.py +++ b/vllm/entrypoints/tool_server.py @@ -18,7 +18,6 @@ if TYPE_CHECKING: async def list_server_and_tools(server_url: str): from mcp import ClientSession from mcp.client.sse import sse_client - async with sse_client(url=server_url) as streams, ClientSession( *streams) as session: initialize_response = await session.initialize() @@ -86,8 +85,12 @@ class ToolServer(ABC): pass @abstractmethod - def new_session(self, tool_name: str, - session_id: str) -> AbstractAsyncContextManager[Any]: + def new_session( + self, + tool_name: str, + session_id: str, + headers: Optional[dict[str, str]] = None + ) -> AbstractAsyncContextManager[Any]: """ Create a session for the tool. 
""" @@ -144,16 +147,21 @@ class MCPToolServer(ToolServer): return self.harmony_tool_descriptions.get(tool_name) @asynccontextmanager - async def new_session(self, tool_name: str, session_id: str): + async def new_session(self, + tool_name: str, + session_id: str, + headers: Optional[dict[str, str]] = None): from mcp import ClientSession from mcp.client.sse import sse_client url = self.urls.get(tool_name) - headers = {"x-session-id": session_id} + request_headers = {"x-session-id": session_id} + if headers is not None: + request_headers.update(headers) if not url: raise KeyError(f"Tool '{tool_name}' is not supported") - async with sse_client(url=url, - headers=headers) as streams, ClientSession( - *streams) as session: + async with sse_client( + url=url, headers=request_headers) as streams, ClientSession( + *streams) as session: await session.initialize() yield session @@ -189,7 +197,10 @@ class DemoToolServer(ToolServer): raise ValueError(f"Unknown tool {tool_name}") @asynccontextmanager - async def new_session(self, tool_name: str, session_id: str): + async def new_session(self, + tool_name: str, + session_id: str, + headers: Optional[dict[str, str]] = None): if tool_name not in self.tools: raise KeyError(f"Tool '{tool_name}' is not supported") yield self.tools[tool_name] diff --git a/vllm/envs.py b/vllm/envs.py index cbd1d5474e60f..ee5efff8bcd92 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -119,12 +119,14 @@ if TYPE_CHECKING: VLLM_SERVER_DEV_MODE: bool = False VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_MLA_DISABLE: bool = False + VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16 VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_CUDART_SO_PATH: Optional[str] = None VLLM_DP_RANK: int = 0 VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_SIZE: int = 1 + VLLM_USE_STANDALONE_COMPILE: bool = False VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 VLLM_MOE_DP_CHUNK_SIZE: int = 256 @@ -183,11 +185,12 @@ if TYPE_CHECKING: VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False - VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" + GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = [] + VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None def get_default_cache_root(): @@ -258,6 +261,58 @@ def env_with_choices( return _get_validated_env +def env_list_with_choices( + env_name: str, + default: list[str], + choices: Union[list[str], Callable[[], list[str]]], + case_sensitive: bool = True) -> Callable[[], list[str]]: + """ + Create a lambda that validates environment variable + containing comma-separated values against allowed choices + + Args: + env_name: Name of the environment variable + default: Default list of values if not set + choices: List of valid string options or callable that returns list + case_sensitive: Whether validation should be case sensitive + + Returns: + Lambda function for environment_variables + dict that returns list of strings + """ + + def _get_validated_env_list() -> list[str]: + value = os.getenv(env_name) + if value is None: + return default + + # Split comma-separated values and strip whitespace + values = [v.strip() for v in value.split(",") if v.strip()] + + if not values: + return default + + # Resolve choices if it's a callable (for lazy loading) + 
actual_choices = choices() if callable(choices) else choices + + # Validate each value + for val in values: + if not case_sensitive: + check_value = val.lower() + check_choices = [choice.lower() for choice in actual_choices] + else: + check_value = val + check_choices = actual_choices + + if check_value not in check_choices: + raise ValueError(f"Invalid value '{val}' in {env_name}. " + f"Valid options: {actual_choices}.") + + return values + + return _get_validated_env_list + + def get_vllm_port() -> Optional[int]: """Get the port from VLLM_PORT environment variable. @@ -436,9 +491,14 @@ environment_variables: dict[str, Callable[[], Any]] = { # Feature flag to enable/disable Inductor standalone compile. # In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is - # enabled by default. + # disabled by default. "VLLM_USE_STANDALONE_COMPILE": - lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1", + lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "0") == "1", + + # Debug pattern matching inside custom passes. + # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). + "VLLM_PATTERN_MATCH_DEBUG": + lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None), # local rank of the process in the distributed setting, used to determine # the GPU device id @@ -946,6 +1006,12 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_MLA_DISABLE": lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), + # If set, vLLM will pick up the provided Flash Attention MLA + # max number splits for cuda graph decode + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": + lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", + "16")), + # Number of GPUs per worker in Ray, if it is set to be a fraction, # it allows ray to schedule multiple actors on a single GPU, # so that users can colocate other actors on the same GPUs as vLLM. 
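A minimal usage sketch for the env_list_with_choices helper added above, assuming it is importable from vllm.envs as defined in this hunk; the EXAMPLE_LABELS variable name is hypothetical and only illustrates how a comma-separated value is split, whitespace-stripped, and validated against the allowed choices:

    import os
    from vllm.envs import env_list_with_choices  # assumes the module-level helper defined above

    os.environ["EXAMPLE_LABELS"] = "container, code_interpreter"  # hypothetical variable, for illustration only
    getter = env_list_with_choices("EXAMPLE_LABELS", default=[],
                                   choices=["container", "code_interpreter",
                                            "web_search_preview"])
    print(getter())  # ['container', 'code_interpreter']; any value outside choices raises ValueError
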
@@ -1306,10 +1372,6 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), - # Allows vllm use container tool - "VLLM_GPT_OSS_USE_CONTAINER_TOOL": - lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))), - # Allows harmony instructions to be injected on system messages "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool( @@ -1329,6 +1391,14 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", "VLLM_OBJECT_STORAGE_SHM_BUFFER"), + + # Valid values are container,code_interpreter,web_search_preview + # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter + "GPT_OSS_SYSTEM_TOOL_MCP_LABELS": + env_list_with_choices("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", [], + ["container", + "code_interpreter", + "web_search_preview"]), } # --8<-- [end:env-vars-definition] @@ -1379,6 +1449,7 @@ def compute_hash() -> str: environment_variables_to_hash = [ "VLLM_PP_LAYER_PARTITION", "VLLM_MLA_DISABLE", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "VLLM_USE_TRITON_FLASH_ATTN", "VLLM_USE_TRITON_AWQ", "VLLM_DP_RANK", diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index 85a1f86ce6bf2..6cf5815ef12da 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -121,18 +121,18 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): lora_bias = self.slice_bias(lora_bias) self.lora_a_stacked[0][index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) + 0, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) self.lora_b_stacked[0][index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) + 0, :lora_b.shape[0], :lora_b.shape[1]].copy_( + lora_b, non_blocking=True) if lora_bias is not None: self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) assert len(self.lora_bias_stacked) self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( - lora_bias.T, non_blocking=True) + lora_bias, non_blocking=True) def apply(self, x: torch.Tensor, diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index 658fd23165da0..fa4eb272a69fe 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -99,13 +99,13 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): if self.is_merged_col_linear: tp_rank = get_tensor_model_parallel_rank() shard_size = self.output_size // 2 - offset = lora_b.shape[-1] // 2 + offset = lora_b.shape[0] // 2 - left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * - shard_size] - right_weight = lora_b[:, offset + tp_rank * shard_size:offset + - (tp_rank + 1) * shard_size] - lora_b = torch.cat([left_weight, right_weight], dim=1) + left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) * + shard_size, :] + right_weight = lora_b[offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size, :] + lora_b = torch.cat([left_weight, right_weight], dim=0) # Applicable to cases where the base_layer is # ColumnParallelLinear. 
else: @@ -113,7 +113,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): shard_size = self.output_size start_idx = tensor_model_parallel_rank * shard_size end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] + lora_b = lora_b[start_idx:end_idx, :] return lora_b def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: @@ -251,9 +251,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): for i, (shard_id, shard_size) in enumerate( zip(self.output_ids, self.output_slices)): if (lora_b_i := lora_b[i]) is not None: - sliced_lora_b[i] = lora_b_i[:, - shard_size * shard_id:shard_size * - (shard_id + 1)] + sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size * + (shard_id + 1), :] return sliced_lora_b def slice_bias( @@ -285,12 +284,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): for i in range(self.n_slices): if (lora_a_i := lora_a[i]) is not None: self.lora_a_stacked[i][ - index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( - lora_a_i.T, non_blocking=True) + index, 0, :lora_a_i.shape[0], :lora_a_i.shape[1]].copy_( + lora_a_i, non_blocking=True) if (lora_b_i := lora_b[i]) is not None: self.lora_b_stacked[i][ - index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( - lora_b_i.T, non_blocking=True) + index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_( + lora_b_i, non_blocking=True) if lora_bias is not None: self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], @@ -299,7 +298,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): if (lora_bias_i := lora_bias[i]) is not None: self.lora_bias_stacked[i][index, 0, :lora_bias_i.shape[0]].copy_( - lora_bias_i.T, + lora_bias_i, non_blocking=True) @classmethod @@ -345,18 +344,18 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): tp_rank = get_tensor_model_parallel_rank() self.q_shard_id = tp_rank self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas - lora_b_q = lora_b[:, self.q_proj_shard_size * + lora_b_q = lora_b[self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] + (self.q_shard_id + 1), :] k_offset = self.q_proj_total_size - lora_b_k = lora_b[:, k_offset + + lora_b_k = lora_b[k_offset + self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] + self.kv_proj_shard_size * (self.kv_shard_id + 1), :] v_offset = k_offset + self.kv_proj_total_size - lora_b_v = lora_b[:, v_offset + + lora_b_v = lora_b[v_offset + self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + self.kv_proj_shard_size * (self.kv_shard_id + 1), :] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0) return lora_b def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: @@ -465,7 +464,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked[0].shape[2] start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] + lora_a = lora_a[start_idx:start_idx + shard_size, :] return lora_a def apply(self, @@ -508,10 +507,10 @@ class MergedColumnParallelLinearWithShardedLoRA( output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size lora_a = [ - lora_a[0][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[0] is not None else None, 
- lora_a[1][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[1] is not None else None, + lora_a[0][output_start_idx:output_start_idx + + output_shard_size, :] if lora_a[0] is not None else None, + lora_a[1][output_start_idx:output_start_idx + + output_shard_size, :] if lora_a[1] is not None else None, ] return lora_a @@ -551,7 +550,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked[0].shape[2] start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] + lora_a = lora_a[start_idx:start_idx + shard_size, :] return lora_a def apply(self, @@ -589,12 +588,12 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] start_idx = [self.tp_rank * shard_size[i] for i in range(3)] lora_a = [ - lora_a[0][:, start_idx[0]:start_idx[0] + - shard_size[0]] if lora_a[0] is not None else None, - lora_a[1][:, start_idx[1]:start_idx[1] + - shard_size[1]] if lora_a[1] is not None else None, - lora_a[2][:, start_idx[2]:start_idx[2] + - shard_size[2]] if lora_a[2] is not None else None, + lora_a[0][start_idx[0]:start_idx[0] + + shard_size[0], :] if lora_a[0] is not None else None, + lora_a[1][start_idx[1]:start_idx[1] + + shard_size[1], :] if lora_a[1] is not None else None, + lora_a[2][start_idx[2]:start_idx[2] + + shard_size[2], :] if lora_a[2] is not None else None, ] return lora_a diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py index a50dcfa748f2f..b8fbad3a4af01 100644 --- a/vllm/lora/layers/logits_processor.py +++ b/vllm/lora/layers/logits_processor.py @@ -140,11 +140,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ): self.reset_lora(index) self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) + 0, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) + 0, :lora_b.shape[0], :lora_b.shape[1]].copy_( + lora_b, non_blocking=True) if embeddings_tensor is not None: self.embeddings_tensors[ index, diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py index 18ef6fd1ddd78..cac2c92136dca 100644 --- a/vllm/lora/layers/row_parallel_linear.py +++ b/vllm/lora/layers/row_parallel_linear.py @@ -39,7 +39,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): shard_size = self.input_size start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size - lora_a = lora_a[start_idx:end_idx, :] + lora_a = lora_a[:,start_idx:end_idx] return lora_a def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: @@ -122,7 +122,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): shard_size = self.lora_b_stacked[0].shape[2] start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] + lora_b = lora_b[ start_idx:end_idx,:] return lora_b def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py index 4d6218d970977..ca01c7e17fff4 100644 --- a/vllm/lora/layers/vocal_parallel_embedding.py +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -95,11 +95,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): bias: Optional[torch.Tensor] = 
None, ): self.reset_lora(index) - self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( - lora_a, non_blocking=True) + # NOTE self.lora_a_stacked is row-major, and lora_a is col-major, + # so we need transpose here + self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) + 0, :lora_b.shape[0], :lora_b.shape[1]].copy_( + lora_b, non_blocking=True) if embeddings_tensor is not None: self.embeddings_tensors[ index, diff --git a/vllm/lora/lora_weights.py b/vllm/lora/lora_weights.py index 958364fca592f..e3198fb3d3ae4 100644 --- a/vllm/lora/lora_weights.py +++ b/vllm/lora/lora_weights.py @@ -86,11 +86,11 @@ class LoRALayerWeights: embeddings_tensor_dim: Optional[int] = None, bias_enabled: Optional[bool] = False) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() - lora_a = torch.zeros([input_dim, rank], + lora_a = torch.zeros([rank, input_dim], dtype=dtype, device=device, pin_memory=pin_memory) - lora_b = torch.zeros([rank, output_dim], + lora_b = torch.zeros([output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 9ea46be65cff3..cc64cc78affa7 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -152,30 +152,29 @@ class LoRAModel: module_name, peft_helper, lora_embeddings_tensor) if is_bias: - loras[module_name].bias = tensor.to(device=device, - dtype=dtype).t() - bias = tensor.to(device=device, dtype=dtype).t() + loras[module_name].bias = tensor.to(device=device, dtype=dtype) + bias = tensor.to(device=device, dtype=dtype) if pin_memory: bias = bias.pin_memory() loras[module_name].bias = bias elif is_lora_a: loras[module_name].lora_a = tensor.to(device=device, - dtype=dtype).t() + dtype=dtype) if pin_memory: loras[module_name].lora_a = loras[ module_name].lora_a.pin_memory() else: loras[module_name].lora_b = tensor.to(device=device, - dtype=dtype).t() + dtype=dtype) assert embedding_padding_modules is not None if any(name in module_name for name in embedding_padding_modules ) and target_embedding_padding is not None: lora_b = loras[module_name].lora_b - assert target_embedding_padding >= lora_b.shape[1] - addition = target_embedding_padding - lora_b.shape[1] + assert target_embedding_padding >= lora_b.shape[0] + addition = target_embedding_padding - lora_b.shape[0] loras[module_name].lora_b = torch.nn.functional.pad( - lora_b, (0, addition)) + lora_b, (0, 0, 0, addition)) if pin_memory: loras[module_name].lora_b = loras[ module_name].lora_b.pin_memory() @@ -585,7 +584,6 @@ class LoRAModelManager: "cpu", bias_enabled=bias_enabled, ) - lora.optimize() else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] @@ -600,7 +598,6 @@ class LoRAModelManager: "cpu", bias_enabled=bias_enabled, ) - lora.optimize() subloras.append(lora) lora = PackedLoRALayerWeights.pack(subloras) model.loras[module_name] = lora diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py index 39e647b9b88a4..e27604728ed06 100644 --- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -83,8 +83,8 @@ class LoRAKernelMeta: Prepare kernel metadata tensors for the current forward pass. Args: - token_lora_tensor (torch.Tensor): Tensor containing lora indices - for each input token. 
+ token_lora_mapping (torch.Tensor): Tensor containing lora indices + for each input token. """ self._reset() @@ -136,7 +136,7 @@ class LoRAKernelMeta: Args: token_nums (int): Number of input tokens in the current forward - pass. + pass of the kernel. """ return ( self.token_lora_mapping[:token_nums], diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 9118f3351ef0a..29bfd5753a588 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -93,7 +93,6 @@ def bgmv_shrink( inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. lora_b_weights (torch.Tensor): LoRA weights of shape [num_loras, lora_rank, hidden_size]. - output_tensor (torch.Tensor): (Unused) output tensor (placeholder). lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] indicating which LoRA matrix to use for each token. scaling (float, optional): Scalar multiplier applied to the output. diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 07dc337a1cc87..5896da516540b 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union import torch import torch.nn.functional as F -import torch_xla.core.xla_model as xm +import torch_xla from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink from vllm.lora.punica_wrapper.utils import convert_mapping @@ -323,7 +323,7 @@ class PunicaWrapperTPU(PunicaWrapperBase): extra_vocab_size: int, ): # Make sure we don't accidentally collect outside operations - xm.mark_step() + torch_xla.sync() # Pad the prompt mapping to avoid running into recompiles on the TPU # TODO: Should this happen inside mapping internally? If so how can we diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index b14bc06e913cf..34bfe1c16aac7 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -14,11 +14,16 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.utils import cdiv, has_triton_kernels from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -if has_triton_kernels(): - from triton_kernels.matmul_ogs import PrecisionConfig - logger = init_logger(__name__) +if has_triton_kernels(): + try: + from triton_kernels.matmul_ogs import PrecisionConfig + except ImportError: + logger.error( + "Failed to import Triton kernels. Please make sure your triton " + "version is compatible.") + def _get_config_dtype_str( dtype: torch.dtype, @@ -288,7 +293,11 @@ class FusedMoEQuantConfig: @property def use_mxfp4_w4a4(self) -> bool: - return self.quant_dtype == "mxfp4" + return (self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4") + + @property + def use_mxfp4_w4a16(self) -> bool: + return (self._a1.dtype is None and self._w1.dtype == "mxfp4") @property def use_nvfp4_w4a4(self) -> bool: @@ -453,6 +462,22 @@ def int8_w8a8_moe_quant_config( ) +def mxfp4_w4a16_moe_quant_config( + w1_scale: Union[torch.Tensor, "PrecisionConfig"], + w2_scale: Union[torch.Tensor, "PrecisionConfig"], + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> FusedMoEQuantConfig: + """ + Construct a quant config for unquantized activations and mxfp4 weights. 
+ """ + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(), + _a2=FusedMoEQuantDesc(), + _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias), + _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias), + ) + + def mxfp4_w4a4_moe_quant_config( w1_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"], diff --git a/vllm/model_executor/layers/fused_moe/configs/E=62,N=128,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=62,N=128,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..40d86ff8ba324 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=62,N=128,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + 
"waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=62,N=256,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=62,N=256,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..6014d827d7417 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=62,N=256,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=62,N=512,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=62,N=512,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..3622659f3e915 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=62,N=512,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 2, + 
"num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=72,N=192,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=72,N=192,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..311d2e829a050 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=72,N=192,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + 
"32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=72,N=384,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=72,N=384,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..91c4b916b8649 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=72,N=384,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + 
"num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=72,N=768,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=72,N=768,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..8fee30ec70660 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=72,N=768,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + 
"kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git 
a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index f390f0a25875e..a250a6218715e 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) +from vllm.utils import round_up class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): @@ -18,6 +19,23 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): Prepare/Finalize using DeepEP High-Throughput kernels. """ + @staticmethod + def maybe_roundup_layer_hidden_size(hidden_size: int, + dtype: torch.dtype) -> int: + # Round up hidden size so it is compatible with DeepEP High Throughput + # kernels. + # DeepEP intranode kernels make copies in units of, + # 32(warp-size) int4 elements. Round up hidden size to respect this. + # For example, an input hidden size of 2880 with dtype torch.bfloat16 + # will be rounded up to 3072. + hidden_size_bytes = hidden_size * dtype.itemsize + xfer_atom_size = 512 # 32 * 16 (size(int4)) + if hidden_size_bytes % xfer_atom_size == 0: + return hidden_size + + hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size) + return hidden_size_bytes // dtype.itemsize + def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, dp_size: int, rank_expert_offset: int): super().__init__() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6c2a5bda7cbaa..0e334fdf24045 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1017,6 +1017,79 @@ def grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def eplb_map_to_physical_and_record( + topk_ids: torch.Tensor, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + indices_type: Optional[torch.dtype] = None) -> torch.Tensor: + ''' + Map the logical expert ids to physical expert ids + and record the expert load metrics. + + This will select a pseudo-random replica for each logical expert. + Only used for EPLB. + + Args: + topk_ids: The logical expert ids. + expert_load_view: The expert load view. + logical_to_physical_map: The logical to physical map. + logical_replica_count: The logical replica count. + indices_type: The indices type. + + Returns: + The physical expert ids. + ''' + + # 1. Convert the logical expert ids to physical expert ids + # Directly select a random replica for each logical expert + + # In case `indices_type` is not `torch.long` or `torch.int`, + # e.g. 
`torch.uint32` as required by dispatch/combine kernels + topk_ids_long = topk_ids.long() + # Use (token position) modulo (replica count) + # to deterministically choose a replica + replica_count = logical_replica_count[topk_ids_long] + # Flatten-position based index, reshaped back to `topk_ids` shape + pos_indices = torch.arange(topk_ids.numel(), + device=topk_ids.device, + dtype=torch.long).reshape_as(topk_ids) + # Compute pseudo-random indices by modulo + replica_indices = (pos_indices % replica_count).unsqueeze(-1) + physical_ids = logical_to_physical_map[topk_ids_long].gather( + -1, replica_indices).squeeze(-1) + + topk_ids = physical_ids + + # 2. Record expert load metrics. + + # TODO(bowen): When using `FusedMoEModularKernel`, this + # can be done in a more unified way, since + # `FusedMoEPrepareAndFinalize` will return the expert + # token count, in some cases directly from the kernel. + # However, now there are many code paths not using + # the modular kernel, e.g. calling `fused_experts`, + # so we decide to keep the logic here. + # + # If later refactor moved all the MoE kernel calls + # to the modular kernel, we can move this logic there + # to achieve better efficiency. + + # `expert_load_view`: (num_physical_experts,) + + # `torch.bincount` is not compilable, so use `scatter_add_` instead. + topk_ids_flatten = topk_ids.flatten() + expert_load_view.scatter_add_( + dim=0, + index=topk_ids_flatten.long(), + src=torch.ones_like(topk_ids_flatten).to(expert_load_view)) + + if indices_type is not None: + topk_ids = topk_ids.to(dtype=indices_type) + return topk_ids + + def fused_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index f12d3807517ff..0e84a9241e905 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -9,7 +9,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceNoOP) +from vllm.triton_utils import tl, triton from vllm.utils import has_triton_kernels logger = init_logger(__name__) @@ -19,13 +20,55 @@ if has_triton_kernels(): import triton_kernels.swiglu from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation, matmul_ogs) - from triton_kernels.routing import routing + from triton_kernels.routing import (RoutingData, routing, + routing_from_bitmatrix) + from triton_kernels.tensor import Bitmatrix except (ModuleNotFoundError, AttributeError) as e: logger.error( "Failed to import Triton kernels. Please make sure your triton " "version is compatible. Error: %s", e) +@triton.jit +def pack_bitmatrix( + bitmatrix, + topk_ids, + n_rows, # n_rows in bitmatrix / topk_ids + bm_cols: tl.constexpr, # n int32_t bitpacks in bitmatrix + n_expts_act, # num_topk + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """ + Packs topk_ids into a bitmatrix. 
+ code reference: + https://github.com/triton-lang/triton/blob/dd1bbc52b34d202dfe5ffea1e04fb16166c5c04e/python/triton_kernels/bench/distributed.py#L264 + """ + pid_m = tl.program_id(0) + offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offsets_k = tl.arange(0, BLOCK_SIZE_K) + offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :] + mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :] + indices = tl.load(topk_ids + offsets, mask=mask, other=-1) + div = indices // 32 + rem = indices % 32 + one = tl.cast(1, tl.uint32) + + # Iterate through all the relevant bitmatrix columns. + for i in range(bm_cols): + # When BLOCK_SIZE_K=32, offs is just the column index. + offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32) + # All topks that need to go into this column has the correct bit set. + # Other bits are 0. x is a 2D tensor. + x = tl.where(div[:, :, None] == offs[None, None, :], + (one << rem)[:, :, None], 0) + # Reduce x to get a single int32_t bitpack. + y = tl.reduce_or(x, axis=1) + bitmatrix_ptrs = bitmatrix + offsets_m[:, + None] * bm_cols + offs[None, :] + tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows) + + def triton_kernel_moe_forward( hidden_states: torch.Tensor, w1, # Tensor or triton_kernels.Tensor @@ -124,34 +167,88 @@ def triton_kernel_fused_experts( return intermediate_cache3 -class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): +def make_routing_data( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + num_local_experts: int, +) -> tuple["RoutingData", torch.Tensor, torch.Tensor]: - def __init__( - self, - max_num_tokens: int, - num_dispatchers: int, - quant_config: FusedMoEQuantConfig, - ): + topk_ids = topk_ids.to(torch.int16) + topk_weights = topk_weights.to(torch.bfloat16) + + n_rows, num_topk = topk_ids.size() + + BLOCK_SIZE_M = 512 + BLOCK_SIZE_K = 32 + + bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K) # n_bitpacks + bitmatrix = torch.zeros((n_rows, bm_cols), + dtype=torch.uint32, + device=topk_ids.device) + + grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), ) + pack_bitmatrix[grid]( + bitmatrix, + topk_ids, + n_rows, + bm_cols, + num_topk, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + + bitmatrix_shape = [n_rows, bm_cols * 32] + bitmatrix_shape_max = [n_rows, None] + bitmatrix = Bitmatrix(bitmatrix, + shape=bitmatrix_shape, + shape_max=bitmatrix_shape_max, + scratchpad=None) + + # matmul_ogs expects invalid topk_weights to be -1s + topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights) + routing_data, gather_indx, scatter_indx = routing_from_bitmatrix( + bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk) + + return routing_data, gather_indx, scatter_indx + + +class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__(self, quant_config: FusedMoEQuantConfig): + super().__init__(quant_config) + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Weight application and reduction happens in the fused_experts kernel. 
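The pack_bitmatrix kernel and make_routing_data above build, for every token (row), a bitmask over the local experts: bit e of row r is set iff expert e appears in token r's top-k, and the bits are packed into 32-bit words (div = e // 32 selects the word, rem = e % 32 the bit). A minimal non-Triton reference of that packing, handy for unit-testing the kernel, could look like the following; the function name and the int64 storage are illustrative (the kernel itself writes uint32 words).

import torch

def pack_bitmatrix_reference(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    # Pure-PyTorch sketch of the packing performed by the Triton kernel above.
    n_rows, _ = topk_ids.shape
    bm_cols = (num_experts + 31) // 32          # number of 32-bit bitpacks per row
    out = torch.zeros(n_rows, bm_cols, dtype=torch.int64)
    for row in range(n_rows):
        for expert in topk_ids[row].tolist():
            if expert >= 0:                     # -1 marks invalid/padded slots
                out[row, expert // 32] |= 1 << (expert % 32)
    return out

# e.g. topk_ids = [[1, 34]] with 64 local experts sets bit 1 of word 0 and bit 2 of word 1.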
+ return TopKWeightAndReduceNoOP() + + def _make_routing_data( + self, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + num_local_experts: int, + ) -> tuple["RoutingData", torch.Tensor, torch.Tensor]: + return make_routing_data(topk_ids, topk_weights, num_local_experts) + + +class OAITritonExperts(BaseOAITritonExperts): + + def __init__(self, quant_config: FusedMoEQuantConfig): + # TODO (varun) : Enable activation quantization + assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" super().__init__(quant_config) - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers @property def activation_formats( self ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) def supports_chunking(self) -> bool: - return False - - def supports_expert_map(self) -> bool: - return False - - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. - return TopKWeightAndReduceDelegate() + return True def workspace_shapes( self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, @@ -159,13 +256,10 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: Optional[mk.ExpertTokensMetadata] ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # workspace are allocated inside the kernel - assert a.dim() == 2 - num_dp = self.num_dispatchers - num_experts = local_num_experts - max_num_tokens = self.max_num_tokens - workspace2 = (0, 0, 0) - output = (num_experts, max_num_tokens * num_dp, N) - return (output, workspace2, output, a.dtype) + workspace1 = (M, K) + workspace2 = (0, 0) + output = (M, K) + return (workspace1, workspace2, output, a.dtype) def apply( self, @@ -185,17 +279,29 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - return triton_kernel_fused_experts( - output, + if expert_map is not None: + topk_ids = expert_map[topk_ids] + + local_num_experts = w1.size(0) + if global_num_experts == -1: + global_num_experts = local_num_experts + + routing_data, gather_indx, scatter_indx = self._make_routing_data( + topk_ids, topk_weights, local_num_experts) + + experts_output = triton_kernel_fused_experts( + None, hidden_states, w1, w2, - routing_data=None, - gather_indx=None, - scatter_indx=None, + routing_data, + gather_indx, + scatter_indx, activation=activation, quant_config=self.quant_config, apply_router_weight_on_input=False, - global_num_experts=global_num_experts, - expert_map=expert_map, + global_num_experts=local_num_experts, + expert_map=None, # applied already a1q_scale=a1q_scale) + + output.copy_(experts_output, non_blocking=True) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index da513d75da4da..71cc2bcf174dd 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -43,7 +43,8 @@ from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts - from .fused_moe import TritonExperts, fused_experts + from .fused_moe import (TritonExperts, eplb_map_to_physical_and_record, + fused_experts) if has_pplx(): from 
.pplx_prepare_finalize import (PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes) @@ -55,6 +56,16 @@ else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore FusedMoEPrepareAndFinalize = None # type: ignore + + def eplb_map_to_physical_and_record( + topk_ids: torch.Tensor, expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + indices_type: Optional[torch.dtype]) -> torch.Tensor: + # CPU fallback: no EPLB so just return as is + return topk_ids + + if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_grouped_topk as grouped_topk) @@ -789,6 +800,49 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str: for local_index, global_index in zip(local_indices, global_indices)) +def maybe_roundup_hidden_size( + hidden_size: int, act_dtype: torch.dtype, + quant_config: Optional[QuantizationConfig], + moe_parallel_config: FusedMoEParallelConfig) -> int: + """ + Given layer hidden size and MoE configurations, round up hidden_size + if necessary. + + Args: + hidden_size(int): Layer hidden-size + act_dtype: Data type of the layer activations. + quant_config(FusedMoEQuantConfig): Fused MoE quantization configuration. + moe_parallel_config(FusedMoEParallelConfig): Fused MoE parallelization + strategy configuration. + + Return: + Rounded up hidden_size if rounding up is required based on the configs. + Original hidden size otherwise. + """ + + if (moe_parallel_config.use_deepep_ht_kernels): + hidden_size = ( + DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size( + hidden_size, act_dtype)) + + # we are padding globally so EP buffer allocation works + if quant_config and quant_config.get_name() == "mxfp4": + + from vllm.model_executor.layers.quantization.mxfp4 import ( + Mxfp4Backend, get_mxfp4_backend) + current_mxfp4_backend = get_mxfp4_backend() + if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + or current_mxfp4_backend + == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS): + hidden_size = round_up(hidden_size, 128) + elif (current_platform.is_rocm() or current_mxfp4_backend + == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): + hidden_size = round_up(hidden_size, 256) + + return hidden_size + + @CustomOp.register("fused_moe") class FusedMoE(CustomOp): """FusedMoE layer for MoE models. @@ -845,6 +899,18 @@ class FusedMoE(CustomOp): params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype + vllm_config = get_current_vllm_config() + + # FIXME (varun): We should have a better way of inferring the activation + # datatype. This works for now as the tensor datatype entering the MoE + # operation is typically unquantized (i.e. float16/bfloat16). + if vllm_config.model_config is not None: + moe_in_dtype = vllm_config.model_config.dtype + else: + # TODO (bnell): This is a hack to get test_mixtral_moe to work + # since model_config is not set in the pytest test. 
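To make the hidden-size rounding above concrete: maybe_roundup_hidden_size defers to DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size when the DeepEP high-throughput kernels are in use (those kernels copy 32 int4 elements, i.e. 512 bytes, at a time) and rounds to a multiple of 128 or 256 for the mxfp4 backends. The standalone function below just replays that byte arithmetic for the example quoted in the diff; it is a sketch, not the helper itself.

def roundup_for_deepep_ht(hidden_size: int, itemsize: int) -> int:
    xfer_atom_size = 512                            # 32 (warp size) * 16 bytes (int4)
    hidden_bytes = hidden_size * itemsize
    hidden_bytes = -(-hidden_bytes // xfer_atom_size) * xfer_atom_size  # round up
    return hidden_bytes // itemsize

assert roundup_for_deepep_ht(2880, 2) == 3072       # bfloat16: 5760 B -> 6144 B -> 3072 elems
# The mxfp4 branches round the element count directly, e.g. round_up(2880, 256) == 3072.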
+ moe_in_dtype = params_dtype + tp_size_ = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()) dp_size_ = (dp_size @@ -854,7 +920,6 @@ class FusedMoE(CustomOp): if self.is_sequence_parallel: self.sp_size = tp_size_ - vllm_config = get_current_vllm_config() self.moe_parallel_config: FusedMoEParallelConfig = ( FusedMoEParallelConfig.make( tp_size_=tp_size_, @@ -863,19 +928,10 @@ class FusedMoE(CustomOp): self.global_num_experts = num_experts + num_redundant_experts - # we are padding globally so EP buffer allocation works - if quant_config and quant_config.get_name() == "mxfp4": - from vllm.model_executor.layers.quantization.mxfp4 import ( - Mxfp4Backend, get_mxfp4_backend) - current_mxfp4_backend = get_mxfp4_backend() - if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 - or current_mxfp4_backend - == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS): - hidden_size = round_up(hidden_size, 128) - elif (current_platform.is_rocm() or current_mxfp4_backend - == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or - current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): - hidden_size = round_up(hidden_size, 256) + # Round up hidden size if needed. + hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype, + quant_config, + self.moe_parallel_config) # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config @@ -916,12 +972,15 @@ class FusedMoE(CustomOp): "experts. Falling back to linear expert placement.") expert_placement_strategy = "linear" - self.local_num_experts, self.expert_map = determine_expert_map( + self.expert_map: Optional[torch.Tensor] + local_num_experts, expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts, expert_placement_strategy=expert_placement_strategy, ) + self.local_num_experts = local_num_experts + self.register_buffer("expert_map", expert_map) logger.info_once( "[EP Rank %s/%s] Expert parallelism is enabled. Expert " "placement strategy: %s. Local/global" @@ -956,20 +1015,13 @@ class FusedMoE(CustomOp): raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") - if vllm_config.model_config is not None: - model_dtype = vllm_config.model_config.dtype - else: - # TODO (bnell): This is a hack to get test_mixtral_moe to work - # since model_config is not set in the pytest test. - model_dtype = params_dtype - moe = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, hidden_dim=hidden_size, num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, - in_dtype=model_dtype, + in_dtype=moe_in_dtype, max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, has_bias=has_bias, ) @@ -1105,10 +1157,12 @@ class FusedMoE(CustomOp): # ep_size and ep_rank should already be updated assert self.expert_map is not None with self.expert_map.device: - self.local_num_experts, self.expert_map = determine_expert_map( + local_num_experts, expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts) + self.local_num_experts = local_num_experts + self.register_buffer("expert_map", expert_map) def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, @@ -1616,55 +1670,13 @@ class FusedMoE(CustomOp): assert logical_to_physical_map is not None assert logical_replica_count is not None - # 1. 
Convert the logical expert ids to physical expert ids - # Directly select a random replica for each logical expert - - # TODO: maybe optimize this by using specified kernels, - # or compute pseudo-random indices by modulo - - # In case `indices_type` is not `torch.long` or `torch.int`, - # e.g. `torch.uint32` as required by dispatch/combine kernels - topk_ids_long = topk_ids.long() - replica_indices = ( - torch.rand_like(topk_ids, dtype=torch.float) * - logical_replica_count[topk_ids_long]).long().unsqueeze(-1) - physical_ids = logical_to_physical_map[topk_ids_long].gather( - -1, replica_indices).squeeze(-1) - - topk_ids = physical_ids - - # 2. Record expert load metrics. - - # TODO(bowen): When using `FusedMoEModularKernel`, this - # can be done in a more unified way, since - # `FusedMoEPrepareAndFinalize` will return the expert - # token count, in some cases directly from the kernel. - # However, now there are many code paths not using - # the modular kernel, e.g. calling `fused_experts`, - # so we decide to keep the logic here. - # - # If later refactor moved all the MoE kernel calls - # to the modular kernel, we can move this logic there - # to achieve better efficiency. - - # `expert_load_view`: (num_physical_experts,) - - topk_ids_flatten = topk_ids.flatten() - - # Performance optimization: - # `masked_fill` is significantly faster than `masked_select` - invalid_mask = topk_ids_flatten < 0 - # Replace invalid expert ids with 0 (just a dummy position) - # to avoid out-of-bounds errors in scatter_add_ - index = topk_ids_flatten.masked_fill_(invalid_mask, 0) - # `src` is the valid mask, which is 1 for valid and 0 for invalid - src = ~invalid_mask - - expert_load_view.scatter_add_(dim=0, - index=index.long(), - src=src.to(expert_load_view)) - - topk_ids = topk_ids.to(dtype=indices_type) + topk_ids = eplb_map_to_physical_and_record( + topk_ids=topk_ids, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + indices_type=indices_type, + ) assert topk_ids.dtype == indices_type or indices_type is None diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a16c254fadf66..5fce24018e647 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -76,7 +76,7 @@ def _moe_problem_size( """ assert w1.dim() == 3 and w2.dim() == 3 E, N, _ = w1.size() - K = w2.size(1) + K = a1.size(-1) if a1.dim() == 2: # Make sure we are using the correct a1 (pre-permute). diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index a524e13405807..6da62b5426bb6 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase): # Contains the KV cache (mamba state) for the layer # in the shape specified by `self.get_state_shape`. - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. - kv_cache: list[Iterable[torch.Tensor]] + kv_cache: tuple[torch.Tensor, ...] 
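The eplb_map_to_physical_and_record helper introduced in fused_moe.py, now called from FusedMoE above in place of the old inline logic, chooses a replica from each token's flattened position (position modulo replica count) and records per-expert load with scatter_add_. A small worked example with made-up maps, purely for illustration:

import torch

# Logical expert 0 has one replica (physical 5); logical expert 1 has two (7 and 12).
logical_to_physical_map = torch.tensor([[5, -1], [7, 12]])
logical_replica_count = torch.tensor([1, 2])

topk_ids = torch.tensor([[1, 0], [1, 1]])            # logical ids, shape (tokens, topk)
pos = torch.arange(topk_ids.numel()).reshape_as(topk_ids)
replica_idx = (pos % logical_replica_count[topk_ids]).unsqueeze(-1)
physical = logical_to_physical_map[topk_ids].gather(-1, replica_idx).squeeze(-1)
# pos = [[0, 1], [2, 3]] -> replica_idx = [[0, 0], [0, 1]] -> physical = [[7, 5], [7, 12]]

expert_load_view = torch.zeros(16)
expert_load_view.scatter_add_(0, physical.flatten(), torch.ones(physical.numel()))
# Physical expert 7 accumulates a load of 2; experts 5 and 12 accumulate 1 each.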
@abstractmethod def get_state_shape(self) -> Iterable[tuple[int, ...]]: diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 5fe37a6289e01..6a901b47b8b63 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -15,7 +15,6 @@ import torch.nn.functional as F from einops import rearrange from torch import nn -from vllm import envs from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.communication_op import tensor_model_parallel_all_reduce @@ -42,8 +41,6 @@ if TYPE_CHECKING: import torch import torch.distributed -from vllm.model_executor.models.minimax_cache import MinimaxCacheParams - class MiniMaxText01RMSNormTP(CustomOp): name = "MiniMaxText01RMSNormTP" @@ -225,11 +222,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): self.tp_heads:(self.tp_rank + 1) * self.tp_heads].contiguous() - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self @staticmethod def weight_direct_load(param: torch.Tensor, @@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): break if _prefill_idx >= len(state_indices_tensor): break - # prefills are packed at end of batch in V1 - offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0 + offset = attn_metadata.num_decode_tokens _start = attn_metadata.query_start_loc[offset + _prefill_idx] _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] slot_id = state_indices_tensor[offset + _prefill_idx] @@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): hidden_decode = self._decode_infer(q, k, v, kv_cache, state_indices_tensor, attn_metadata) - if envs.VLLM_USE_V1: - hidden.insert(0, hidden_decode) - else: - hidden.append(hidden_decode) + hidden.insert(0, hidden_decode) if not hidden: return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) @@ -304,40 +296,28 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): - if not envs.VLLM_USE_V1: - q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - num_prefills = getattr(attn_metadata, "num_prefills", 0) - slot_id = state_indices_tensor[num_prefills:] - else: - q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - slot_id = state_indices_tensor[:attn_metadata.num_decodes] + q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[:attn_metadata.num_decodes] hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, 
slot_id, 32) return hidden def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, - positions: torch.Tensor, - kv_caches: MinimaxCacheParams) -> None: - if not envs.VLLM_USE_V1: - self._forward(hidden_states, output, positions, kv_caches) - else: - torch.ops.vllm.linear_attention( - hidden_states, - output, - positions, - self.prefix, - ) + positions: torch.Tensor) -> None: + torch.ops.vllm.linear_attention( + hidden_states, + output, + positions, + self.prefix, + ) def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[MinimaxCacheParams]) -> None: + positions: torch.Tensor) -> None: forward_context = get_forward_context() attn_metadata: AttentionMetadata = forward_context.attn_metadata - if envs.VLLM_USE_V1 and attn_metadata is not None: + if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, LinearAttentionMetadata) @@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): qkvact = torch.nn.functional.silu(qkv32) qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) - if envs.VLLM_USE_V1: - if attn_metadata is not None: - kv_cache = self.kv_cache[forward_context.virtual_engine][0] - state_indices_tensor = attn_metadata.state_indices_tensor + if attn_metadata is not None: + kv_cache = self.kv_cache[forward_context.virtual_engine][0] + state_indices_tensor = attn_metadata.state_indices_tensor - num_prefills = getattr(attn_metadata, "num_prefills", 0) - if num_prefills > 0: - num_decode_tokens = getattr(attn_metadata, - "num_decode_tokens", 0) - for prefill_idx in range(num_prefills): - q_start = attn_metadata.query_start_loc[ - num_decode_tokens + prefill_idx] - q_end = attn_metadata.query_start_loc[num_decode_tokens - + prefill_idx + - 1] - query_len = q_end - q_start - context_len = attn_metadata.seq_lens[ - num_decode_tokens + prefill_idx] - query_len - if context_len == 0: - block_to_clear = state_indices_tensor[ - num_decode_tokens + prefill_idx] - kv_cache[block_to_clear, ...] = 0 - else: - assert kv_caches is not None - kv_cache = kv_caches.minimax_cache - state_indices_tensor = kv_caches.state_indices_tensor + num_prefills = getattr(attn_metadata, "num_prefills", 0) + if num_prefills > 0: + num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", + 0) + for prefill_idx in range(num_prefills): + q_start = attn_metadata.query_start_loc[num_decode_tokens + + prefill_idx] + q_end = attn_metadata.query_start_loc[num_decode_tokens + + prefill_idx + 1] + query_len = q_end - q_start + context_len = attn_metadata.seq_lens[ + num_decode_tokens + prefill_idx] - query_len + if context_len == 0: + block_to_clear = state_indices_tensor[num_decode_tokens + + prefill_idx] + kv_cache[block_to_clear, ...] 
= 0 decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 if attn_metadata is None: @@ -410,8 +384,7 @@ def linear_attention( self = forward_context.no_compile_layers[layer_name] self._forward(hidden_states=hidden_states, output=output, - positions=positions, - kv_caches=None) + positions=positions) def linear_attention_fake( diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py deleted file mode 100644 index 7f376b70a7ae0..0000000000000 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ /dev/null @@ -1,177 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass -from typing import Optional, Union - -import numpy as np -import torch - -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.backends.placeholder_attn import ( - PlaceholderAttentionMetadata) -from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.platforms import current_platform -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata -from vllm.v1.attention.backends.mamba2_attn import ( - Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets) - - -@dataclass -class Mamba2Metadata: - prep_initial_states: bool - chunk_size: int - - has_initial_states_p: torch.Tensor - seq_idx_p: torch.Tensor - chunk_indices_p: torch.Tensor - chunk_offsets_p: torch.Tensor - """ - With continuous batching layout of `x` in vLLM, to enable a Triton program - to handle a request in parallel, two supporting tensors are used - (batch_ptr, token_chunk_offset_ptr) - BLOCK_M = the # tokens to be handled by a Triton program - (can be customized for different hardware) - - nums_dict: - tracks the data associated with a given value of BLOCK_M - BLOCK_M = #tokens handled by a Triton program - cu_seqlen: total tokens per batch - (used as flag to update other data at each new input) - batch_ptr: tracks batch-id handled by the Triton program - token_chunk_offset_ptr: tracks token group_idx handled by the Triton program - (Triton implementation of causal_conv1d handles parallelism in 3-axes - - feature-axis - - batch-axis - - sequence-axis) - """ - nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.Tensor] = None - token_chunk_offset_ptr: Optional[torch.Tensor] = None - - -def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: - """Returns the appropriate metadata classes for the current platform.""" - if current_platform.is_rocm(): - from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) - from vllm.v1.attention.backends.triton_attn import ( - TritonAttentionMetadata) - return (AiterFlashAttentionMetadata, TritonAttentionMetadata, - PlaceholderAttentionMetadata) - if current_platform.is_cuda(): - from vllm.v1.attention.backends.flash_attn import ( - FlashAttentionMetadata) - from vllm.v1.attention.backends.xformers import ( - XFormersAttentionMetadata) - return (FlashAttentionMetadata, XFormersAttentionMetadata, - PlaceholderAttentionMetadata) - raise ValueError( - f"Unsupported platform for Mamba2: {current_platform.device_type}") - - -def prepare_mamba2_metadata( - chunk_size: int, - attn_metadata: AttentionMetadata, -) -> Mamba2Metadata: - - # compute number of prefill and decode requests - # NOTE: in V0 we assume prefills are before decodes - num_prefills = attn_metadata.num_prefills - num_prefill_tokens = 
attn_metadata.num_prefill_tokens - - seq_idx_p = None - chunk_indices_p, chunk_offsets_p = None, None - # Need flags to indicate if there are initial states - # currently we really only support the FlashAttention backend - has_initial_states_p = None - prep_initial_states = False - - # Compute seq_idx, chunk_indices and chunk_offsets for prefill only - if num_prefills > 0: - attn_metadata_instances = get_platform_metadata_classes() - if (isinstance(attn_metadata, attn_metadata_instances) - and attn_metadata.context_lens_tensor is not None): - # precompute flag to avoid device syncs later in mamba2 layer - # forwards - # prep is only needed for mamba2 ssd prefill processing - has_initial_states_p = ( - attn_metadata.context_lens_tensor[:num_prefills] > 0) - prep_initial_states = torch.any(has_initial_states_p).item() - query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1] - seq_idx_p = torch.repeat_interleave(torch.arange( - num_prefills, dtype=torch.int32, device=query_start_loc_p.device), - query_start_loc_p.diff(), - output_size=num_prefill_tokens) - seq_idx_p.unsqueeze_(0) - - # We compute metadata for chunked prefill once at the top level model - # forward and reuse them in mamba layers. If not needed, they will be - # ignored inside mamba kernels. - if prep_initial_states: - chunk_indices_p, chunk_offsets_p = \ - _query_start_loc_to_chunk_indices_offsets( - query_start_loc_p, chunk_size, num_prefill_tokens) - - return Mamba2Metadata(has_initial_states_p=has_initial_states_p, - prep_initial_states=prep_initial_states, - chunk_size=chunk_size, - seq_idx_p=seq_idx_p, - chunk_indices_p=chunk_indices_p, - chunk_offsets_p=chunk_offsets_p) - - -def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor, - mamba2_metadata: Union[Mamba2Metadata, - Mamba2AttentionMetadata, - GDNAttentionMetadata]): - """ - this is triggered upon handling a new input at the first layer - """ - dim, cu_seqlen = x.shape - mamba2_metadata.cu_seqlen = cu_seqlen - seqlens = np.diff(query_start_loc.to('cpu')) - nums_dict = {} # type: ignore - for BLOCK_M in [8]: # cover all BLOCK_M values - nums = -(-seqlens // BLOCK_M) - nums_dict[BLOCK_M] = {} - nums_dict[BLOCK_M]['nums'] = nums - nums_dict[BLOCK_M]['tot'] = nums.sum().item() - mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums)) - nums_dict[BLOCK_M]['mlist'] = mlist - mlist_len = len(nums_dict[BLOCK_M]['mlist']) - nums_dict[BLOCK_M]['mlist_len'] = mlist_len - MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2 - offsetlist = [] # type: ignore - for idx, num in enumerate(nums): - offsetlist.extend(range(num)) - offsetlist = torch.tensor(offsetlist, dtype=torch.int32) - nums_dict[BLOCK_M]['offsetlist'] = offsetlist - - if mamba2_metadata.batch_ptr is None: - # Update default value after class definition - #mamba2_metadata.MAX_NUM_PROGRAMS *= 2 - mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device='cuda') - mamba2_metadata.token_chunk_offset_ptr = torch.full( - (MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device='cuda') - else: - if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS: - mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_( - PAD_SLOT_ID) - mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore - MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) - - mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist) - mamba2_metadata.token_chunk_offset_ptr[ # type: ignore - 0:mlist_len].copy_(offsetlist) - nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr - 
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = ( - mamba2_metadata.token_chunk_offset_ptr) # type: ignore - mamba2_metadata.nums_dict = nums_dict - return mamba2_metadata diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index e704bfd451bce..a56ee13a63804 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -10,8 +10,6 @@ import torch from torch import nn from torch.nn.parameter import Parameter -from vllm import envs -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) -from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp): has_weight=rms_norm_has_weight, ) if use_rms_norm else None - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. - # The inner tuple is (conv_state, ssm_state) - self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The inner tuple is (conv_state, ssm_state) + self.kv_cache = (torch.tensor([]), torch.tensor([])) self.model_config = model_config self.cache_config = cache_config @@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp): discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) return discrete_time_step, B, C - def forward(self, - hidden_states: torch.Tensor, - output: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None): - if not envs.VLLM_USE_V1: - CustomOp.forward(self, hidden_states, output, mamba_cache_params) - else: - torch.ops.vllm.mamba_mixer( - hidden_states, - output, - self.prefix, - ) + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor): + torch.ops.vllm.mamba_mixer( + hidden_states, + output, + self.prefix, + ) - def forward_native(self, - hidden_states: torch.Tensor, - output: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None): + def forward_native(self, hidden_states: torch.Tensor, + output: torch.Tensor): pass - def forward_cuda(self, - hidden_states: torch.Tensor, - output: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None): + def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): """ Run the Mamba-1 SSM pipeline. 
@@ -234,31 +216,18 @@ class MambaMixer(MambaBase, CustomOp): forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata - if envs.VLLM_USE_V1: - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - mamba1_metadata = attn_metadata - assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) - query_start_loc = mamba1_metadata.query_start_loc - state_indices_tensor = mamba1_metadata.state_indices_tensor - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - has_initial_states = mamba1_metadata.has_initial_states - num_padded_decodes = mamba1_metadata.num_padded_decodes - else: - assert isinstance(attn_metadata, AttentionMetadata) - assert mamba_cache_params is not None - conv_state = mamba_cache_params.conv_state - ssm_state = mamba_cache_params.ssm_state - state_indices_tensor = mamba_cache_params.state_indices_tensor - query_start_loc = attn_metadata.query_start_loc - context_lens_tensor = attn_metadata.context_lens_tensor - has_initial_states = None - if context_lens_tensor is not None: - has_initial_states = context_lens_tensor > 0 - num_padded_decodes = attn_metadata.num_decode_tokens + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + mamba1_metadata = attn_metadata + assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) + query_start_loc = mamba1_metadata.query_start_loc + state_indices_tensor = mamba1_metadata.state_indices_tensor + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + has_initial_states = mamba1_metadata.has_initial_states + num_padded_decodes = mamba1_metadata.num_padded_decodes # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp): conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if envs.VLLM_USE_V1 and attn_metadata is None: + if attn_metadata is None: # V1 profile run hidden_states_BC = hidden_states_BC.contiguous() return self.out_proj(hidden_states_BC.transpose(-2, -1))[0] @@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp): out=scan_outputs_d) scan_outputs_d = scan_outputs_d.transpose(0, 1) - if envs.VLLM_USE_V1: - ssm_outputs.insert(0, scan_outputs_d) - else: - ssm_outputs.append(scan_outputs_d) + ssm_outputs.insert(0, scan_outputs_d) scan_outputs_combined = ssm_outputs[0] if len( ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) @@ -441,40 +407,27 @@ def split_batch_to_prefill_and_decode( num_decodes: int, num_padded_decodes: int, ) -> PrefillDecodeSplit: + num_actual_tokens = num_prefill_tokens + num_padded_decodes - if envs.VLLM_USE_V1: - # In v1, decode tokens come first, then prefill tokens. - hidden_states_BC_d, hidden_states_BC_p = torch.split( - hidden_states_BC[..., :num_actual_tokens], - [num_padded_decodes, num_prefill_tokens], - dim=-1) - gate_d, gate_p = torch.split(gate[..., :num_actual_tokens], - [num_padded_decodes, num_prefill_tokens], - dim=-1) + # In v1, decode tokens come first, then prefill tokens. 
+ hidden_states_BC_d, hidden_states_BC_p = torch.split( + hidden_states_BC[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], + dim=-1) + gate_d, gate_p = torch.split(gate[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], + dim=-1) - # num_padded_decodes accounts for CUDA graph padding when applicable - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor[:num_padded_decodes + num_prefills], - [num_padded_decodes, num_prefills], - dim=0) - query_start_loc_p = (query_start_loc[-num_prefills - 1:] - - num_padded_decodes if num_prefills > 0 else None) - has_initial_states_p = has_initial_states[-num_prefills:] if ( - has_initial_states is not None and num_prefills > 0) else None - else: - # In v0, prefill tokens come first, then decode tokens. - hidden_states_BC_p, hidden_states_BC_d = torch.split( - hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1) - gate_p, gate_d = torch.split(gate, - [num_prefill_tokens, num_decode_tokens], - dim=-1) - state_indices_tensor_p, state_indices_tensor_d = torch.split( - state_indices_tensor, [num_prefills, num_decodes], dim=0) - query_start_loc_p = (query_start_loc[:num_prefills + - 1] if num_prefills > 0 else None) - has_initial_states_p = has_initial_states[:num_prefills] if ( - has_initial_states is not None and num_prefills > 0) else None + # num_padded_decodes accounts for CUDA graph padding when applicable + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor[:num_padded_decodes + num_prefills], + [num_padded_decodes, num_prefills], + dim=0) + query_start_loc_p = (query_start_loc[-num_prefills - 1:] - + num_padded_decodes if num_prefills > 0 else None) + has_initial_states_p = has_initial_states[-num_prefills:] if ( + has_initial_states is not None and num_prefills > 0) else None return PrefillDecodeSplit( hidden_states_BC_p=hidden_states_BC_p, @@ -495,9 +448,7 @@ def mamba_mixer( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, - output=output, - mamba_cache_params=None) + self.forward_cuda(hidden_states=hidden_states, output=output) def mamba_mixer_fake( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 02e6a9138c05f..047ce4c4c43d0 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: import torch from torch import nn -from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, - update_metadata) from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( @@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import ( LoaderFunction, 
composed_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp): self.use_rms_norm, eps=rms_norm_eps) - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. - # The inner tuple is (conv_state, ssm_state) - self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The tuple is (conv_state, ssm_state) + self.kv_cache = (torch.tensor([]), torch.tensor([])) self.model_config = model_config self.cache_config = cache_config @@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, ): pass @@ -478,59 +468,43 @@ class MambaMixer2(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, ): - if not envs.VLLM_USE_V1: - CustomOp.forward(self, hidden_states, output, mamba_cache_params, - mamba2_metadata, mup_vector) - else: - torch.ops.vllm.mamba_mixer2( - hidden_states, - output, - self.prefix, - mup_vector, - ) + torch.ops.vllm.mamba_mixer2( + hidden_states, + output, + self.prefix, + mup_vector, + ) def forward_cuda( self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, ): forward_context = get_forward_context() - # mamba2_metadata contains metadata necessary for the mamba2 triton + # attn_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = forward_context.attn_metadata - if envs.VLLM_USE_V1: - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - mamba2_metadata = attn_metadata - assert isinstance(attn_metadata, Mamba2AttentionMetadata) - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - # conv_state = (..., dim, width-1) yet contiguous along 'dim' - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - state_indices_tensor = attn_metadata.state_indices_tensor - else: - conv_state = mamba_cache_params.conv_state - ssm_state = mamba_cache_params.ssm_state - state_indices_tensor = mamba_cache_params.state_indices_tensor - # Common members between V1 metadata and V0 metadata - if mamba2_metadata is not None: - has_initial_states_p = mamba2_metadata.has_initial_states_p - 
prep_initial_states = mamba2_metadata.prep_initial_states - chunk_size = mamba2_metadata.chunk_size - seq_idx_p = mamba2_metadata.seq_idx_p - chunk_indices_p = mamba2_metadata.chunk_indices_p - chunk_offsets_p = mamba2_metadata.chunk_offsets_p + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # conv_state = (..., dim, width-1) yet contiguous along 'dim' + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states_p + prep_initial_states = attn_metadata.prep_initial_states + chunk_size = attn_metadata.chunk_size + seq_idx_p = attn_metadata.seq_idx_p + chunk_indices_p = attn_metadata.chunk_indices_p + chunk_offsets_p = attn_metadata.chunk_offsets_p # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp): dim=-1, ) - if envs.VLLM_USE_V1 and attn_metadata is None: - # V1 profile run + if attn_metadata is None: + # profile run hidden_states_B_C = (hidden_states_B_C.transpose( 0, 1).clone().transpose(0, 1)).contiguous() hidden_states, _B, _C = split_hidden_states_B_C_fn( @@ -579,49 +553,27 @@ class MambaMixer2(MambaBase, CustomOp): has_decode = num_decodes > 0 num_actual_tokens = num_prefill_tokens + num_decodes - # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension - if envs.VLLM_USE_V1: - hidden_states_B_C_d, hidden_states_B_C_p = torch.split( - hidden_states_B_C[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - dim=0, - ) - dt_d, dt_p = torch.split( - dt[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - dim=0, - ) - # Split along batch dimension - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor[:num_actual_tokens], - [num_decodes, num_prefills], - dim=0, - ) - query_start_loc_p = ( - attn_metadata.query_start_loc[-num_prefills - 1:] - - num_decodes if has_prefill else None) - else: - hidden_states_B_C_p, hidden_states_B_C_d = torch.split( - hidden_states_B_C, - [num_prefill_tokens, num_decodes], - dim=0, - ) - dt_p, dt_d = torch.split( - dt, - [num_prefill_tokens, num_decodes], - dim=0, - ) - # Split along batch dimension - state_indices_tensor_p, state_indices_tensor_d = torch.split( - state_indices_tensor, - [num_prefills, num_decodes], - dim=0, - ) - query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + - 1] - if has_prefill else None) + hidden_states_B_C_d, hidden_states_B_C_p = torch.split( + hidden_states_B_C[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + dt_d, dt_p = torch.split( + dt[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor[:num_actual_tokens], + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs @@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp): dtype=hidden_states.dtype, device=hidden_states.device, ) - if envs.VLLM_USE_V1: - 
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( - preallocated_ssm_out, - [num_decodes, num_prefill_tokens], - dim=0, - ) - else: - preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( - preallocated_ssm_out, - [num_prefill_tokens, num_decodes], - dim=0, - ) + preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( + preallocated_ssm_out, + [num_decodes, num_prefill_tokens], + dim=0, + ) # Process prefill requests if has_prefill: @@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp): # pointed to by "state_indices_tensor" x = hidden_states_B_C_p.transpose( 0, 1) # this is the form that causal-conv see - if mamba2_metadata.cu_seqlen is None: - mamba2_metadata = update_metadata(x, query_start_loc_p, - mamba2_metadata) hidden_states_B_C_p = causal_conv1d_fn( x, conv_weights, @@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp): conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, - metadata=mamba2_metadata, + metadata=attn_metadata, query_start_loc=query_start_loc_p).transpose( 0, 1)[:num_prefill_tokens] @@ -806,8 +748,6 @@ def mamba_mixer2( self = forward_context.no_compile_layers[layer_name] self.forward_cuda(hidden_states=hidden_states, output=output, - mamba_cache_params=None, - mamba2_metadata=None, mup_vector=mup_vector) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index a6c1af91de421..677a4b9d87fc7 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -100,7 +100,6 @@ class MambaStateShapeCalculator: intermediate_size: int, state_size: int, conv_kernel: int, - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int]]: conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1) @@ -108,11 +107,7 @@ class MambaStateShapeCalculator: temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) - # In V0, the conv_state shape was swapped during allocation in - # MambaCacheManager, but in V1 it needs to be determined here at the - # calculation level - if use_v1: - conv_state_shape = conv_state_shape[1], conv_state_shape[0] + conv_state_shape = conv_state_shape[1], conv_state_shape[0] return conv_state_shape, temporal_state_shape @@ -126,7 +121,6 @@ class MambaStateShapeCalculator: head_dim: int, state_size: int, conv_kernel: int, - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: # if n_groups is not divisible by world_size, need to extend the shards # to ensure all groups needed by a head is sharded along with it @@ -137,8 +131,6 @@ class MambaStateShapeCalculator: # contiguous along 'dim' axis conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) - if not use_v1: - conv_state_shape = conv_state_shape[1], conv_state_shape[0] # These are not TP-ed as they depend on A, dt_bias, D # - they are typically small @@ -153,12 +145,9 @@ class MambaStateShapeCalculator: tp_world_size: int, intermediate_size: int, conv_kernel: int, - use_v1: bool = True, ) -> tuple[tuple[int, int]]: conv_dim = divide(intermediate_size, tp_world_size) conv_state_shape = (conv_kernel - 1, conv_dim) - if not use_v1: - conv_state_shape = conv_state_shape[1], conv_state_shape[0] return (conv_state_shape, ) @classmethod @@ -183,7 +172,6 @@ class MambaStateShapeCalculator: head_v_dim: int, conv_kernel_size: int, num_spec: int = 0, - use_v1: bool = True, ): conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads) 
conv_state_shape = ( @@ -191,11 +179,7 @@ class MambaStateShapeCalculator: conv_kernel_size - 1 + num_spec, ) - # In V0, the conv_state shape was swapped during allocation in - # MambaCacheManager, but in V1 it needs to be determined here at the - # calculation level - if use_v1: - conv_state_shape = conv_state_shape[1], conv_state_shape[0] + conv_state_shape = conv_state_shape[1], conv_state_shape[0] temporal_state_shape = (divide(num_v_heads, tp_world_size), head_k_dim, head_v_dim) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 8cfd0962c5bfe..010fcdda156c2 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -420,9 +420,7 @@ def causal_conv1d_fn( x = x.to(conv_states.dtype) out = torch.empty_like(x) if metadata is not None: - cu_seqlen = metadata.cu_seqlen nums_dict = metadata.nums_dict - #x = metadata.x args = nums_dict batch_ptr = metadata.batch_ptr token_chunk_offset_ptr = metadata.token_chunk_offset_ptr @@ -926,7 +924,6 @@ def causal_conv1d_update( query_start_loc: Optional[torch.Tensor] = None, max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, - metadata=None, validate_data=False, ): """ diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index 335191a5c82c1..ffdcd702aab40 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: import torch -from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size @@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( @@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp): prefix=f"{prefix}.out_proj", ) - assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1") compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. 
- self.kv_cache = [(torch.tensor([]), )] + self.kv_cache = (torch.tensor([]), ) self.model_config = model_config self.cache_config = cache_config @@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - conv_metadata: ShortConvAttentionMetadata, ): return @@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - conv_metadata: ShortConvAttentionMetadata, ): torch.ops.vllm.short_conv( hidden_states, @@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - conv_metadata: ShortConvAttentionMetadata, ): forward_context = get_forward_context() # ShortConvAttentionMetadata contains metadata necessary for the @@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp): if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] - conv_metadata = attn_metadata assert isinstance(attn_metadata, ShortConvAttentionMetadata) self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) @@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp): if has_prefill: Bx_p = (B_p * x_p).transpose(0, 1) - if conv_metadata.cu_seqlen is None: - conv_metadata = update_metadata(Bx_p, query_start_loc_p, - conv_metadata) Bx = causal_conv1d_fn(Bx_p, conv_weights, self.conv.bias, @@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp): conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, - metadata=conv_metadata, + metadata=attn_metadata, query_start_loc=query_start_loc_p).transpose( 0, 1)[:num_prefill_tokens] @@ -248,9 +235,7 @@ def short_conv( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, - output=output, - conv_metadata=None) + self.forward_cuda(hidden_states=hidden_states, output=output) def short_conv_fake( diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 5c3f8a891276b..a71c8d32a22c7 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -13,7 +13,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config) + FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config, + mxfp4_w4a16_moe_quant_config) +from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( + OAITritonExperts) from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) @@ -578,9 +581,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.w13_bias = Parameter(w13_bias, requires_grad=False) layer.w2_bias = Parameter(w2_bias, requires_grad=False) - # FIXME warp need to be adjusted based on batch size - # only apply to batched mode - if self.moe.use_ep: + # Ideally we'd use FusedMoEModularKernel.prepare_finalize object + # (stored in self.fused_experts) to determine if the MoE has a + # batched activation format. As self.fused_experts is not + # initialized at this point, we resort to checking the MoE config + # directly. 
+ is_batched_moe = (self.moe.use_pplx_kernels + or self.moe.use_deepep_ll_kernels) + if is_batched_moe: num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 else: num_warps = 8 @@ -640,16 +648,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): if self.mxfp4_backend == Mxfp4Backend.TRITON: w1_scale = self.w13_precision_config w2_scale = self.w2_precision_config + return mxfp4_w4a16_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) else: w1_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale - - return mxfp4_w4a4_moe_quant_config( - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - w1_scale=w1_scale, - w2_scale=w2_scale, - ) + return mxfp4_w4a4_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) def select_gemm_impl( self, @@ -661,6 +674,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): raise NotImplementedError( "Mxfp4 does not support batched experts format for EP") else: + assert self.moe_quant_config is not None if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): # B200 code-path @@ -671,13 +685,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): # TODO(bnell): part of quant_config "max_capture_size": self.max_capture_size, } - assert self.moe_quant_config is not None return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) else: - # Use matmul_ogs from triton_kernels here! - raise NotImplementedError( - "Mxfp4 does not support non-batched experts format for EP") + return OAITritonExperts(self.moe_quant_config) def _route_and_experts( self, @@ -722,10 +733,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): logical_to_physical_map=logical_to_physical_map, logical_replica_count=logical_replica_count) + w13_weight = (self.w13_weight_triton_tensor + if layer.w13_weight is None else layer.w13_weight) + w2_weight = (self.w2_weight_triton_tensor + if layer.w2_weight is None else layer.w2_weight) + assert all([w is not None for w in [w13_weight, w2_weight]]) + return self.fused_experts( hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, + w1=w13_weight, + w2=w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index c9653aa9e4405..3576368981c7c 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -153,11 +153,23 @@ def get_rope( if k in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow") } - rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, - original_max_position, - base, is_neox_style, - scaling_factor, dtype, - **extra_kwargs) + if "mrope_section" in rope_scaling: + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling.get("mrope_interleaved", + False), + scaling_factor=scaling_factor, + **extra_kwargs) + else: + rotary_emb = YaRNScalingRotaryEmbedding( + head_size, rotary_dim, original_max_position, base, + is_neox_style, scaling_factor, dtype, **extra_kwargs) elif scaling_type == "deepseek_yarn": scaling_factor = rope_scaling["factor"] original_max_position = rope_scaling[ diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py 
b/vllm/model_executor/layers/rotary_embedding/mrope.py index 17d04a1ad715c..9bf0d6bd15e74 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -12,6 +12,7 @@ from vllm.triton_utils import tl, triton from .base import RotaryEmbedding from .common import apply_rotary_emb_dispatch +from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale @triton.jit @@ -213,7 +214,27 @@ class MRotaryEmbedding(RotaryEmbedding): dtype: torch.dtype, mrope_section: Optional[list[int]] = None, mrope_interleaved: bool = False, + # YaRN parameters. + *, + scaling_factor: Optional[float] = None, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, ) -> None: + + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + if self.scaling_factor is not None: + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + yarn_get_mscale(self.scaling_factor) * attn_factor) + else: + self.mscale = 1.0 + # In Qwen2.5-VL, the maximum index value is related to the duration of # the input video. We enlarge max_position_embeddings to 4 times to get # a larger the cos and sin cache. @@ -226,6 +247,16 @@ class MRotaryEmbedding(RotaryEmbedding): if self.mrope_section: assert sum(self.mrope_section) == rotary_dim // 2 + def _compute_inv_freq(self, base: float) -> torch.Tensor: + if self.scaling_factor is None: + return super()._compute_inv_freq(base) + return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + if self.scaling_factor is None: + return super()._compute_cos_sin_cache() + return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self) + def forward_native( self, positions: torch.Tensor, diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index d1bdec21fd974..4b7bcd37d4bc2 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -211,16 +211,15 @@ class DefaultModelLoader(BaseModelLoader): from vllm.platforms.tpu import USE_TPU_COMMONS if not USE_TPU_COMMONS: - # In PyTorch XLA, we should call `xm.mark_step` + # In PyTorch XLA, we should call `torch_xla.sync` # frequently so that not too many ops are accumulated - # in the XLA program. import torch_xla.core.xla_model - # as xm - import torch_xla.core.xla_model as xm + # in the XLA program. 
+ import torch_xla def _xla_weights_iterator(iterator: Generator): for weights in iterator: yield weights - xm.mark_step() + torch_xla.sync(wait=False) weights_iterator = _xla_weights_iterator(weights_iterator) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 584981ef3ebfd..4a6154dc548aa 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -9,21 +9,17 @@ import torch from torch import nn from transformers import BambaConfig -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -32,10 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant) @@ -115,8 +108,6 @@ class BambaMixerDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -127,7 +118,7 @@ class BambaMixerDecoderLayer(nn.Module): hidden_states, residual) output = torch.empty_like(hidden_states) - self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mamba(hidden_states, output) # Fully Connected hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) @@ -315,22 +306,10 @@ class BambaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -343,23 +322,11 @@ class BambaModel(nn.Module): residual = intermediate_tensors["residual"] residual = None - num_attn = 0 for i, layer in enumerate(self.layers): - if isinstance(layer, BambaAttentionDecoderLayer): - num_attn += 1 - - 
layer_mamba_cache_params = None - if isinstance(layer, - BambaMixerDecoderLayer) and mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - i - num_attn) - hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank: @@ -457,13 +424,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -482,7 +447,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, head_dim=hf_config.mamba_d_head, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -515,8 +479,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, if not lora_config else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -534,39 +496,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = \ - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba - ) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/constant_size_cache.py b/vllm/model_executor/models/constant_size_cache.py deleted file mode 100644 index f03c58a12932f..0000000000000 --- a/vllm/model_executor/models/constant_size_cache.py +++ /dev/null @@ -1,137 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod -from typing import Any - -import torch - -from vllm.attention.backends.utils import PAD_SLOT_ID - - -class ConstantSizeCache(ABC): - """ - Abstract base class for managing constant size caches - like Mamba and Minimax. 
- """ - - def __init__(self, max_batch_size: int): - # Maps between the request id and a dict that maps between the seq_id - # and its index inside the cache - self.cache_indices_mapping: dict[str, dict[int, int]] = {} - self.free_cache_indices = list(range(max_batch_size)) - - @property - @abstractmethod - def cache(self) -> Any: - """Return the underlying cache tensor(s)""" - pass - - @abstractmethod - def _copy_cache(self, from_index: int, to_index: int): - """Copy cache data from one index to another""" - pass - - def current_run_tensors(self, **kwargs) -> tuple: - """ - Return the tensors for the current run's conv and ssm state. - """ - if "seqlen_agnostic_capture_inputs" not in kwargs: - # We get here only on Prefill/Eager mode runs - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - finished_requests_ids = kwargs["finished_requests_ids"] - - self._release_finished_requests(finished_requests_ids) - state_indices = self._prepare_current_run_cache( - request_ids_to_seq_ids, finished_requests_ids) - - state_indices_tensor = torch.as_tensor(state_indices, - dtype=torch.int32, - device="cuda") - cache_tensors = self.cache - else: - # CUDA graph capturing runs - cache_tensors, state_indices_tensor = kwargs[ - "seqlen_agnostic_capture_inputs"] - - return (cache_tensors, state_indices_tensor) - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant state_indices into the CUDA graph input buffer - """ - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - finished_requests_ids = kwargs["finished_requests_ids"] - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - assert "seqlen_agnostic_capture_inputs" in input_buffers - _, input_state_indices_buffer = input_buffers[ - "seqlen_agnostic_capture_inputs"] - - self._release_finished_requests(finished_requests_ids) - state_indices = self._prepare_current_run_cache( - request_ids_to_seq_ids, finished_requests_ids) - cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len( - state_indices) - state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) - - input_state_indices_buffer.copy_( - torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - """ - Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Cache during the CUDA graph replay - runs. - """ - state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, - dtype=torch.int32, - device="cuda") - return (self.cache, state_indices_tensor) - - def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, - finished_requests_ids) -> int: - """ - Assign (req_id,seq_id) pair to a `destination_index` index, if - already occupied, move the occupying index to a free index. 
- """ - if cur_rid in finished_requests_ids: - # set as pad, do not allocate destination index - return PAD_SLOT_ID - elif cur_rid not in self.cache_indices_mapping: - destination_index = self.free_cache_indices.pop() - self.cache_indices_mapping[cur_rid] = {seq_id: destination_index} - return destination_index - elif seq_id not in (seq_ids2indices := - self.cache_indices_mapping[cur_rid]): - # parallel sampling , where n > 1, assume prefill have - # already happened, so we copy the - # existing cache into the siblings seq_ids caches - index_exists = next(iter(seq_ids2indices.values())) - # case of decoding n>1, copy prefill cache to decoding indices - destination_index = self.free_cache_indices.pop() - self._copy_cache(from_index=index_exists, - to_index=destination_index) - self.cache_indices_mapping[cur_rid][seq_id] = destination_index - return destination_index - else: - return self.cache_indices_mapping[cur_rid][seq_id] - - def _prepare_current_run_cache( - self, request_ids_to_seq_ids: dict[str, list[int]], - finished_requests_ids: list[str]) -> list[int]: - return [ - self._assign_seq_id_to_cache_index(req_id, seq_id, - finished_requests_ids) - for req_id, seq_ids in request_ids_to_seq_ids.items() - for seq_id in seq_ids - ] - - def _release_finished_requests(self, - finished_seq_groups_req_ids: list[str]): - for req_id in finished_seq_groups_req_ids: - if req_id in self.cache_indices_mapping: - for seq_id in self.cache_indices_mapping[req_id]: - self.free_cache_indices.append( - self.cache_indices_mapping[req_id][seq_id]) - self.cache_indices_mapping.pop(req_id) diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py new file mode 100644 index 0000000000000..04fa5584199a3 --- /dev/null +++ b/vllm/model_executor/models/dots_ocr.py @@ -0,0 +1,824 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable, Mapping +from typing import Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm +from transformers.modeling_utils import PreTrainedModel +from transformers.models.qwen2_vl import Qwen2VLProcessor + +from vllm.attention.layer import check_upstream_fa_availability +from vllm.config import VllmConfig +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, + SupportsMultiModal, + SupportsPP) +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder, + Qwen2VLMultiModalProcessor, + Qwen2VLProcessingInfo) +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + init_vllm_registered_model, + maybe_prefix, + merge_multimodal_embeddings) +from vllm.model_executor.models.vision import get_vit_attn_backend +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict +from vllm.platforms import _Backend +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig, + DotsVisionConfig) + +IMAGE_TOKEN = "<|imgpad|>" + + +class 
DotsOCRImagePixelInputs(TypedDict): + type: Literal["pixel_values", "image_grid_thw"] + + pixel_values: torch.Tensor + image_grid_thw: torch.Tensor + + +class DotsOCRImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds", "image_grid_thw"] + image_embeds: torch.Tensor + """Supported types: + - List[`torch.Tensor`]: A list of tensors holding all images' features. + Each tensor holds an image's features. + - `torch.Tensor`: A tensor holding all images' features + (concatenation of all images' feature tensors). + Tensor shape: `(num_image_features, hidden_size)` + - `num_image_features` varies based on + the number and resolution of the images. + - `hidden_size` must match the hidden size of language model backbone. + """ + + image_grid_thw: torch.Tensor + + +DotsOCRImageInputs = Union[DotsOCRImagePixelInputs, + DotsOCRImageEmbeddingInputs] + + +class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + return IMAGE_TOKEN * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_image_size_with_most_features( # noqa: E501 + ) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + + +class DotsOCRProcessingInfo(Qwen2VLProcessingInfo): + + def get_hf_config(self) -> DotsOCRConfig: + config = self.ctx.get_hf_config() + if not config.__class__.__name__ == 'DotsOCRConfig': + raise TypeError(f"Expected DotsOCRConfig, got {type(config)}") + + if hasattr(config, "vision_config") and isinstance( + config.vision_config, dict): + config.vision_config = DotsVisionConfig(**config.vision_config) + + return config + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + max_image_tokens = self.get_max_image_tokens() + return {"image": max_image_tokens} + + def get_hf_processor( + self, + **kwargs: object, + ) -> Qwen2VLProcessor: + self.get_tokenizer( + ).image_token = IMAGE_TOKEN # Ensure image token is set + processor = self.ctx.get_hf_processor( + Qwen2VLProcessor, + **kwargs, + ) + processor.image_token = IMAGE_TOKEN + processor.video_token = "<|video_pad|>" + return processor + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, + freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + + cos = freqs.cos() + sin = freqs.sin() + + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + + output = (tensor * cos) + (rotate_half(tensor) * sin) + + output = output.to(orig_dtype) + + return output + + +class VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = 
torch.outer(seq, self.inv_freq) + return freqs + + +class PatchMerger(nn.Module): + + def __init__( + self, + dim: int, + context_dim: int, + spatial_merge_size: int = 2, + pre_norm="layernorm", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.pre_norm = pre_norm + if self.pre_norm == "layernorm": + self.ln_q = LayerNorm(context_dim, eps=1e-6) + elif self.pre_norm == "rmsnorm": + self.ln_q = RMSNorm(context_dim, eps=1e-6) + else: + print("no norm in patch merger") + + self.mlp = nn.Sequential( + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + return_bias=False, + disable_tp=True), + nn.GELU(), + RowParallelLinear(self.hidden_size, + dim, + bias=True, + return_bias=False, + disable_tp=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.pre_norm: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + else: + x = self.mlp(x.view(-1, self.hidden_size)) + return x + + +class DotsVisionAttention(nn.Module): + + def __init__(self, + config, + dim: int, + num_heads: int = 16, + bias: bool = True, + *, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + from vllm.distributed import (parallel_state, + tensor_model_parallel_all_gather) + from vllm.distributed import utils as dist_utils + + self.embed_dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.num_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size) + + # qkv/proj follow Qwen2-VL style; bias controlled by arg + self.qkv = QKVParallelLinear(hidden_size=dim, + head_size=dim // num_heads, + total_num_heads=num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + self.proj = RowParallelLinear(input_size=dim, + output_size=dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.proj") + self._all_gather = tensor_model_parallel_all_gather + self._split_last = dist_utils.split_tensor_along_last_dim + + # Select attention backend + self.attn_backend = get_vit_attn_backend(self.head_dim, + torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA + }: + raise RuntimeError( + f"Unsupported vision attention backend: {self.attn_backend}") + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + + def _split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # qkv: [S, B, 3*dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = self._all_gather(qkv) + q, k, v = qkv.chunk(3, dim=2) + if self.tp_size > 1: + q = self._split_last(q, num_partitions=self.tp_size)[self.tp_rank] + k = self._split_last(k, num_partitions=self.tp_size)[self.tp_rank] + v = self._split_last(v, num_partitions=self.tp_size)[self.tp_rank] + new_shape = (seq_len, bs, self.num_heads_per_partition, self.head_dim) + return (q.view(*new_shape), k.view(*new_shape), v.view(*new_shape)) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + *, + max_seqlen: Optional[int] = None, + 
seqlens: Optional[list[int]] = None, + ) -> torch.Tensor: + # [S, C] -> [S, B=1, C] + x = hidden_states.unsqueeze(1) + x, _ = self.qkv(x) + q, k, v = self._split_qkv(x) + bs = q.shape[1] + # [S,B,H,D] -> [B,S,H,D] + q = q.permute(1, 0, 2, 3).contiguous() + k = k.permute(1, 0, 2, 3).contiguous() + v = v.permute(1, 0, 2, 3).contiguous() + + if rotary_pos_emb is not None: + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) + + if self.is_flash_attn_backend: + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func + q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3]) + k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3]) + v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3]) + output = flash_attn_varlen_func(q_, + k_, + v_, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) + context_layer = output.view(bs, -1, self.num_heads_per_partition, + self.head_dim) + elif self.attn_backend == _Backend.TORCH_SDPA: + outputs = [] + for i in range(1, len(cu_seqlens)): + s = int(cu_seqlens[i - 1]) + e = int(cu_seqlens[i]) + q_i = q[:, s:e].permute(0, 2, 1, 3) + k_i = k[:, s:e].permute(0, 2, 1, 3) + v_i = v[:, s:e].permute(0, 2, 1, 3) + out_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + out_i = out_i.permute(0, 2, 1, 3) + outputs.append(out_i) + context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None, + device=q.device) + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + else: + raise RuntimeError("Unsupported attention backend") + + # [B,S,H,D] -> [S,B,H*D] -> [S, C] + context_layer = context_layer.permute(1, 0, 2, 3).contiguous() + context_layer = context_layer.view(context_layer.shape[0], bs, -1) + out, _ = self.proj(context_layer) + return out.squeeze(1) + + +class DotsSwiGLUFFN(nn.Module): + + def __init__(self, + config, + *, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + hidden_features = config.intermediate_size + in_features = config.embed_dim + bias = config.use_bias + + # Referenced aimv2.py AIMv2SwiGLUFFN + self.fc13 = MergedColumnParallelLinear(in_features, + [hidden_features] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc13", + disable_tp=True) + self.fc2 = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=True) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc13(x) + x = self.act_fn(x) + x, _ = self.fc2(x) + return x + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params = dict(self.named_parameters()) + loaded: set[str] = set() + for name, w in weights: + # Map fc1 -> fc13 (shard 0) + if name.startswith("fc1."): + tgt = name.replace("fc1.", "fc13.") + if tgt in params: + params[tgt].weight_loader(params[tgt], w, 0) + loaded.add(tgt) + continue + # Map fc3 -> 
fc13 (shard 1) + if name.startswith("fc3."): + tgt = name.replace("fc3.", "fc13.") + if tgt in params: + params[tgt].weight_loader(params[tgt], w, 1) + loaded.add(tgt) + continue + # Pass-through for fc2 and others + if name in params: + params[name].weight_loader(params[name], w) + loaded.add(name) + return loaded + + +class DotsPatchEmbed(nn.Module): + + def __init__(self, config): + super().__init__() + self.num_channels = config.num_channels + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.embed_dim = config.embed_dim + self.config = config + self.proj = nn.Conv2d( + config.num_channels, + config.embed_dim, + kernel_size=(config.patch_size, config.patch_size), + stride=(config.patch_size, config.patch_size), + ) + self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: + x = x.view(-1, self.num_channels, self.temporal_patch_size, + self.patch_size, self.patch_size)[:, :, 0] + x = self.proj(x).view(-1, self.embed_dim) + x = self.norm(x) + return x + + +class DotsViTPreprocessor(nn.Module): + + def __init__(self, config): + super().__init__() + self.patch_h = config.patch_size + self.patch_w = config.patch_size + self.embed_dim = config.embed_dim + self.config = config + self.patchifier = DotsPatchEmbed(config) + + def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: + tokens = self.patchifier(x, grid_thw) + return tokens + + +class DotsVisionBlock(nn.Module): + + def __init__(self, + config, + *, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.attn = DotsVisionAttention( + config, + config.embed_dim, + num_heads=config.num_attention_heads, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + self.mlp = DotsSwiGLUFFN(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + + def forward(self, + hidden_states: torch.Tensor, + *, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, + seqlens: Optional[list[int]] = None) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class DotsVisionTransformer(PreTrainedModel): + + def __init__( + self, + config: DotsVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__(config) + self.config = config + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = DotsViTPreprocessor(config) + + head_dim = config.embed_dim // config.num_attention_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + + # Keep blocks for compatibility with other vision towers + num_layers = (config.num_hidden_layers if num_hidden_layers_override + is None else 
num_hidden_layers_override) + self.blocks = nn.ModuleList([ + DotsVisionBlock(config, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}") + for i in range(num_layers) + ]) + if require_post_norm is None: + require_post_norm = (len(self.blocks) == config.num_hidden_layers) + if require_post_norm and self.config.post_norm: + self.post_trunk_norm = RMSNorm(config.embed_dim, + eps=config.rms_norm_eps) + else: + self.post_trunk_norm = None + + self.merger = PatchMerger( + dim=config.hidden_size, + context_dim=config.embed_dim, + spatial_merge_size=config.spatial_merge_size, + ) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.patchifier.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.patchifier.proj.weight.device + + def get_pos_ids_by_grid(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + return pos_ids + + def rot_pos_emb(self, grid_thw): + pos_ids = self.get_pos_ids_by_grid(grid_thw) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def compute_attn_mask_seqlen( + self, cu_seqlens: torch.Tensor + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward(self, hidden_states: torch.Tensor, + grid_thw: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.to(self.dtype) + hidden_states = self.patch_embed(hidden_states, grid_thw) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + dtype=grid_thw.dtype + if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + for blk in self.blocks: + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens) + + if self.post_trunk_norm is not None: + hidden_states = self.post_trunk_norm(hidden_states) + + hidden_states = self.merger(hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2VLMultiModalProcessor, + info=DotsOCRProcessingInfo, + dummy_inputs=DotsOCRDummyInputsBuilder, +) +class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".attn.qkv_proj.": ".attn.qkv.", + ".attn.out_proj.": ".attn.proj.", + }, + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + 
"model.": "language_model.model.", + }, + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|img|><|imgpad|><|endofimg|>" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.config: DotsOCRConfig = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.multimodal_config = vllm_config.model_config.multimodal_config + + if isinstance(self.config.vision_config, dict): + vision_config = DotsVisionConfig(**self.config.vision_config) + self.config.vision_config = vision_config + else: + vision_config = self.config.vision_config + + self.vision_tower = DotsVisionTransformer( + vision_config, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), + ) + self.language_model: Qwen2ForCausalLM = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=self.config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[DotsOCRImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return DotsOCRImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return DotsOCRImageEmbeddingInputs(type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + def _process_image_input( + self, image_input: DotsOCRImageInputs) -> tuple[torch.Tensor, ...]: + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type( + self.vision_tower.dtype) + else: + pixel_values = image_input["pixel_values"].type( + self.vision_tower.dtype) + image_embeds = self.vision_tower( + pixel_values, grid_thw)[:, :self.config.hidden_size] + + # Split concatenated embeddings for each image item. 
+ merge_size = self.vision_tower.spatial_merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + + return image_embeds.split(sizes) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.image_token_id, + ) + + return inputs_embeds + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + elif inputs_embeds is None and kwargs.get("pixel_values") is not None: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + inputs_embeds = None + else: + assert input_ids is not None + inputs_embeds = self.get_multimodal_embeddings( + input_ids, + image_input=image_input, + ) + input_ids = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 83efdd2e433fc..f382018e2222c 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -8,21 +8,17 @@ import torch from torch import nn from transformers import FalconH1Config -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -31,8 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from 
vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP @@ -179,16 +173,12 @@ class FalconH1SSMDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): output = torch.empty_like(hidden_states) self.mamba( hidden_states, output, - mamba_cache_params, - mamba2_metadata=mamba2_metadata, mup_vector=self.mup_vector, ) return output, residual @@ -364,8 +354,6 @@ class FalconH1ParallelHybrid(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): residual = hidden_states @@ -382,12 +370,10 @@ class FalconH1ParallelHybrid(nn.Module): # Process input through the SSM branch. # FalconH1SSMDecoderLayer expects hidden_states, attn_metadata, - # residual, mamba_cache_params, and sequence_idx. + # residual, and sequence_idx. ssm_hidden, _ = self.mamba( hidden_states=hidden_states * self.ssm_in_multiplier, residual=residual, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, **kwargs, ) # Sum the outputs from both branches. @@ -464,25 +450,10 @@ class FalconH1Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # pass a sequence index tensor, that is required for - # proper continuous batching computation including - # chunked prefill - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds * self.embedding_multiplier @@ -495,14 +466,9 @@ class FalconH1Model(nn.Module): for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - layer_mamba_cache_params = None - if mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i) hidden_states = layer( positions=positions, hidden_states=hidden_states, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -541,13 +507,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. 
Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -570,7 +534,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, head_dim=hf_config.mamba_d_head, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -592,7 +555,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, prefix=maybe_prefix(prefix, "model")) self.tie_word_embeddings = config.tie_word_embeddings self.unpadded_vocab_size = config.vocab_size - self.mamba_cache: Optional[MambaCacheManager] = None if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size if get_pp_group().is_last_rank: @@ -637,40 +599,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, **kwargs, ): - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager( - self.vllm_config, - self.config.num_hidden_layers, - *mamba_state_shape, - *mamba_state_dtype, - ) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - hidden_states = self.model( input_ids, positions, - mamba_cache_params, intermediate_tensors, inputs_embeds, ) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 56ec634386909..b088e0c0dd241 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -69,7 +69,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -83,7 +82,7 @@ from .qwen2_vl import (_create_qwen2vl_field_factory, from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import get_vit_attn_backend +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 4fe59f91124dd..7c755a00e1c98 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.utils import cdiv -from .interfaces import SupportsPP +from .interfaces import SupportsEagle3, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -238,6 +238,7 @@ class GptOssModel(nn.Module): 
self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], self.config.hidden_size)) + self.aux_hidden_state_layers = tuple[int, ...]() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embedding(input_ids) @@ -261,8 +262,12 @@ class GptOssModel(nn.Module): x = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + aux_hidden_states = [] for i in range(self.start_layer, self.end_layer): layer = self.layers[i] + if i in self.aux_hidden_state_layers: + aux_hidden_states.append(x if residual is None else x + + residual) x, residual = layer(x, positions, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -270,6 +275,9 @@ class GptOssModel(nn.Module): "residual": residual }) x, _ = self.norm(x, residual) + + if len(aux_hidden_states) > 0: + return x, aux_hidden_states return x def _load_weights_mxfp4( @@ -610,7 +618,7 @@ class GptOssModel(nn.Module): weights, stacked_params_mapping) -class GptOssForCausalLM(nn.Module, SupportsPP): +class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3): packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} hf_to_vllm_mapper = WeightsMapper( @@ -658,6 +666,13 @@ class GptOssForCausalLM(nn.Module, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index e89a1a4a0f7d3..f5751fe47bb8b 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -9,19 +9,15 @@ import torch from torch import nn from transformers import GraniteMoeHybridConfig -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -30,10 +26,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .granitemoe import GraniteMoeMoE from 
.granitemoeshared import GraniteMoeSharedMLP @@ -102,14 +95,12 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) output = torch.empty_like(hidden_states) - self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mamba(hidden_states, output) hidden_states = residual + output * self.residual_multiplier residual = hidden_states @@ -182,8 +173,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -366,22 +355,10 @@ class GraniteMoeHybridModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -399,20 +376,9 @@ class GraniteMoeHybridModel(nn.Module): for i, layer in enumerate(self.layers): if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer): num_attn += 1 - - layer_mamba_cache_params = None - if isinstance( - layer, - GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - i - num_attn) - - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata) + hidden_states, residual = layer(positions=positions, + hidden_states=hidden_states, + residual=residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -552,13 +518,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -577,7 +541,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, head_dim=hf_config.mamba_d_head, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -620,9 +583,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, scale=1 / self.config.logits_scaling) - # Used to track and store by the Mamba cache between steps. 
- self.mamba_cache: Optional[MambaCacheManager] = None - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -636,38 +596,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = ( - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba)) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 76737a4428232..2f0c4240413be 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -34,7 +34,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.utils import run_dp_sharded_vision_model + +from .vision import run_dp_sharded_vision_model class Idefics2VisionEmbeddings(nn.Module): diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 18446d126b51a..79e130119ae83 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -608,7 +608,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, prefix=maybe_prefix(prefix, "lm_head"), ) if self.config.text_config.tie_word_embeddings: - self.lm_head.weight = self.model.text_model.wte.weight + self.lm_head.weight = self.model.text_model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.text_config.vocab_size) def _parse_and_validate_image_input( diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 892188c047228..2c341d2839719 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -28,7 +28,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.utils import run_dp_sharded_vision_model + +from .vision import run_dp_sharded_vision_model NORM2FN = { 'rms_norm': RMSNorm, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 12a49029195ff..e8277e259bc5b 
100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -9,7 +9,6 @@ import torch from torch import nn from transformers import JambaConfig -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig @@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaMLP as JambaMLP -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, @@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: @@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module): hidden_states, residual) output = torch.empty_like(hidden_states) - self.mamba(hidden_states, output, mamba_cache_params) + self.mamba(hidden_states, output) # Fully Connected hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) @@ -333,7 +328,6 @@ class JambaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -348,24 +342,11 @@ class JambaModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - kv_cache_index = 0 - mamba_cache_index = 0 for layer in islice(self.layers, self.start_layer, self.end_layer): - layer_mamba_cache_params = None - if isinstance(layer, JambaAttentionDecoderLayer): - kv_cache_index += 1 - if isinstance(layer, - JambaMambaDecoderLayer) and mamba_cache_params: - current_state_layer = mamba_cache_index - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - current_state_layer) - mamba_cache_index += 1 + hidden_states, residual = layer(positions=positions, + hidden_states=hidden_states, + residual=residual) - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=layer_mamba_cache_params) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, if not lora_config else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. 
- self.mamba_cache: Optional[MambaCacheManager] = None self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - # NOTE: mamba_cache_params is not needed for v1 - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - state_shape = self.get_mamba_state_shape_from_config( - self.vllm_config) - state_dtype = self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_layers, *state_shape, - *state_dtype) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): @@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, intermediate_size=hf_config.mamba_expand * hidden_size, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=envs.VLLM_USE_V1, ) def compute_logits( diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index f554077935bf3..503627865c4a5 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -76,13 +76,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config from vllm.utils.tensor_schema import TensorSchema, TensorShape from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix +from .vision import run_dp_sharded_mrope_vision_model # For dummy input only diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index dd97afbeb668a..53c36e4e52d81 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn from transformers import Lfm2Config -from vllm import envs from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig @@ -297,7 +296,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module): self.conv( hidden_states, output, - conv_metadata=None, ) hidden_states, residual = self.ffn_norm(output, residual) hidden_states = self.feed_forward(hidden_states) @@ -459,13 +457,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int]]: """ Calculate shapes for LFM2's convolutional cache. 
Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -478,7 +474,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, tp_world_size=parallel_config.tensor_parallel_size, intermediate_size=hf_config.conv_dim, conv_kernel=hf_config.conv_L_cache, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: @@ -489,8 +484,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, scheduler_config = vllm_config.scheduler_config assert (not cache_config.enable_prefix_caching ), "Lfm2 currently does not support prefix caching" - assert envs.VLLM_USE_V1, ( - "Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1") super().__init__() self.config = config diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 36141a5d50641..5bd268291c7d9 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -8,7 +8,6 @@ import torch from torch import nn from transformers import MambaConfig -from vllm import envs from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group @@ -24,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, IsAttentionFree, SupportsPP) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -72,7 +68,6 @@ class MambaDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: @@ -82,7 +77,7 @@ class MambaDecoderLayer(nn.Module): hidden_states, residual = self.norm(hidden_states, residual) output = torch.empty_like(hidden_states) - self.mixer(hidden_states, output, mamba_cache_params) + self.mixer(hidden_states, output) return output, residual @@ -134,7 +129,6 @@ class MambaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -151,17 +145,9 @@ class MambaModel(nn.Module): for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - - layer_cache_params = None - if mamba_cache_params is not None: - layer_cache_params = mamba_cache_params.at_layer_idx( - i - self.start_layer) - - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=layer_cache_params) + hidden_states, residual = layer(positions=positions, + hidden_states=hidden_states, + residual=residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -225,9 +211,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. 
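# --- Illustrative sketch, not from the patch: the (hidden_states, residual)
# convention the decoder layers above follow. vLLM's fused RMSNorm returns both
# the normalized activations and the updated residual stream; this plain-PyTorch
# stand-in only shows the dataflow, not the actual kernel.
import torch

def rms_norm_with_residual(x, residual, weight, eps=1e-6):
    if residual is not None:
        x = x + residual                      # fold in the previous residual
    residual = x                              # carried forward to the next layer
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
    return normed, residual

h, w = torch.randn(4, 16), torch.ones(16)
h, res = rms_norm_with_residual(h, None, w)   # first layer: no residual yet
h, res = rms_norm_with_residual(h, res, w)    # later layers reuse `res`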
- self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -244,22 +227,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - state_shape = self.get_mamba_state_shape_from_config( - self.vllm_config) - state_dtype = self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_layers, *state_shape, - *state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.backbone(input_ids, positions, mamba_cache_params, + hidden_states = self.backbone(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states @@ -288,8 +256,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): tp_world_size=parallel_config.tensor_parallel_size, intermediate_size=hf_config.intermediate_size, state_size=hf_config.state_size, - conv_kernel=hf_config.conv_kernel, - use_v1=envs.VLLM_USE_V1) + conv_kernel=hf_config.conv_kernel) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.mamba_cache.copy_inputs_before_cuda_graphs( diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 9c3108146d2e5..97e9c5785e726 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -8,16 +8,11 @@ import torch from torch import nn from transformers import MambaConfig -from vllm import envs -from vllm.attention.backends.abstract import AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, IsAttentionFree) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module): hidden_states, residual = self.norm(hidden_states, residual) output = torch.empty_like(hidden_states) - self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mixer(hidden_states, 
output) return output, residual @@ -137,7 +127,6 @@ class Mamba2Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -152,25 +141,10 @@ class Mamba2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - for i, layer in enumerate(self.layers): - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=mamba_cache_params.at_layer_idx( - i - self.start_layer) if mamba_cache_params else None, - mamba2_metadata=mamba2_metadata) + hidden_states, residual = layer(positions=positions, + hidden_states=hidden_states, + residual=residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): head_dim=hf_config.head_dim, state_size=hf_config.state_size, conv_kernel=hf_config.conv_kernel, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) - # Used to track and store by the Mamba cache between steps. 
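# --- Illustrative sketch, not from the patch: rough per-layer cache shapes a
# Mamba2 mixer keeps, written in plain Python rather than vLLM's
# MambaStateShapeCalculator. The conv cache holds the last conv_kernel-1 inputs
# of the causal conv1d; the SSM ("temporal") cache holds one head_dim x
# state_size matrix per head. TP sharding and grouped B/C channels are
# simplified, and the example numbers are made up.
def mamba2_cache_shapes(intermediate_size, num_heads, head_dim, state_size,
                        conv_kernel, n_groups=1, tp_size=1):
    conv_dim = (intermediate_size + 2 * n_groups * state_size) // tp_size
    conv_state = (conv_dim, conv_kernel - 1)
    temporal_state = (num_heads // tp_size, head_dim, state_size)
    return conv_state, temporal_state

print(mamba2_cache_shapes(intermediate_size=4096, num_heads=128, head_dim=64,
                          state_size=128, conv_kernel=4))
# ((4352, 3), (128, 64, 128))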
- self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = ( - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba)) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - else: - # NOTE: mamba_cache_params is not needed for v1 - mamba_cache_params = None - - hidden_states = self.backbone(input_ids, positions, mamba_cache_params, + hidden_states = self.backbone(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py deleted file mode 100644 index 6b16e3ce7d984..0000000000000 --- a/vllm/model_executor/models/mamba_cache.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass - -import torch - -from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.config import VllmConfig -from vllm.model_executor.models.constant_size_cache import ConstantSizeCache - - -@dataclass -class MambaCacheParams: - conv_state: torch.Tensor = torch.Tensor() - ssm_state: torch.Tensor = torch.Tensor() - state_indices_tensor: torch.Tensor = torch.Tensor() - - def at_layer_idx(self, layer_idx): - return MambaCacheParams(self.conv_state[layer_idx], - self.ssm_state[layer_idx], - self.state_indices_tensor) - - -class MambaCacheManager(ConstantSizeCache): - - def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int, - conv_state_shape: tuple[int, int], - temporal_state_shape: tuple[int, int], - conv_state_dtype: torch.dtype, - temporal_state_dtype: torch.dtype): - - self.conv_state_dtype = conv_state_dtype - self.temporal_state_dtype = temporal_state_dtype - - # Determine max batch size to set size of MambaCache - max_batch_size = vllm_config.scheduler_config.max_num_seqs - if not vllm_config.model_config.enforce_eager: - max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size) - - # Initialize parent class - super().__init__(max_batch_size) - - # assume conv_state = (dim, state_len) - assert conv_state_shape[0] > conv_state_shape[1] - conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + - (conv_state_shape[1], conv_state_shape[0]), - dtype=self.conv_state_dtype, - device="cuda").transpose(-1, -2) - temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + - temporal_state_shape, - dtype=self.temporal_state_dtype, - device="cuda") - - self._mamba_cache = (conv_state, temporal_state) - - @property - def cache(self): - return self._mamba_cache - - def _copy_cache(self, from_index: int, to_index: int): - for cache_t in self.cache: - cache_t[:, to_index].copy_(cache_t[:, from_index], - non_blocking=True) - - def current_run_tensors(self, **kwargs) -> MambaCacheParams: - """ - Return the tensors for 
the current run's conv and ssm state. - """ - cache_tensors, state_indices_tensor = super().current_run_tensors( - **kwargs) - return MambaCacheParams(cache_tensors[0], cache_tensors[1], - state_indices_tensor) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - """ - Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Mamba Cache during the CUDA graph - replay runs. - """ - return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size, - dtype=torch.int32, - device="cuda") diff --git a/vllm/model_executor/models/minimax_cache.py b/vllm/model_executor/models/minimax_cache.py deleted file mode 100644 index 9164ac06a3b0a..0000000000000 --- a/vllm/model_executor/models/minimax_cache.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass - -import torch - -from vllm.model_executor.models.constant_size_cache import ConstantSizeCache - - -@dataclass -class MinimaxCacheParams: - minimax_cache: torch.Tensor = torch.Tensor() - state_indices_tensor: torch.Tensor = torch.Tensor() - - def at_layer_idx(self, layer_idx): - return MinimaxCacheParams(self.minimax_cache[layer_idx, ...], - self.state_indices_tensor) - - -class MinimaxCacheManager(ConstantSizeCache): - - def __init__(self, dtype, cache_shape): - super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1] - self._minimax_cache = torch.empty(size=cache_shape, - dtype=dtype, - device="cuda") - - @property - def cache(self): - return self._minimax_cache - - def _copy_cache(self, from_index: int, to_index: int): - assert len(self.cache) > 0 - for cache_t in self.cache: - cache_t[:, to_index].copy_(cache_t[:, from_index], - non_blocking=True) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 1d2c7dea811e0..cc9a959f63313 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -14,7 +14,6 @@ import torch.distributed from torch import nn from transformers import MiniMaxConfig -from vllm import envs from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig @@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid -from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module): def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - kv_caches: Union[list[dict], Optional[torch.Tensor]], attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], is_warmup: bool = False, @@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module): hidden_states=layernorm_output, output=self_attention_output, positions=positions, - kv_caches=kv_caches, ) residual = residual * self.layernorm_attention_alpha @@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module): self._dtype = _dummy.dtype del _dummy - if not envs.VLLM_USE_V1: - self.minimax_cache = MinimaxCacheManager( - dtype=torch.float32, cache_shape=self.cache_shape) - norm_kwargs = {} if hasattr(config, "rms_norm_eps"): norm_kwargs["eps"] = config.rms_norm_eps @@ -614,25 +606,6 @@ class 
MiniMaxText01Model(nn.Module): **kwargs) -> Union[torch.Tensor, IntermediateTensors]: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - if not envs.VLLM_USE_V1 and attn_metadata is None: - return None - if not envs.VLLM_USE_V1: - if "request_ids_to_seq_ids" not in kwargs: - kwargs["request_ids_to_seq_ids"] = {} - if "finished_requests_ids" not in kwargs: - kwargs["finished_requests_ids"] = [] - ( - minimax_cache_tensors, - state_indices_tensor, - ) = self.minimax_cache.current_run_tensors(**kwargs) - if getattr(attn_metadata, "num_prefills", 0) > 0: - self._clear_prefill_cache(attn_metadata, minimax_cache_tensors, - **kwargs) - - minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors, - state_indices_tensor) - else: - minimax_cache_params = None if get_pp_group().is_first_rank: if inputs_embeds is None: @@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - minimax_cache_index = 0 - for layer in islice(self.layers, self.start_layer, self.end_layer): - _caches = None - if not envs.VLLM_USE_V1 and isinstance( - layer.self_attn, MiniMaxText01LinearAttention): - current_state_layer = minimax_cache_index - _caches = minimax_cache_params.at_layer_idx( - current_state_layer) - minimax_cache_index += 1 hidden_states, residual = layer( hidden_states=hidden_states, positions=positions, - kv_caches=_caches, attn_metadata=attn_metadata, residual=residual, ) @@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, ...], ...]: """Calculate shape for MiniMaxText01LinearAttention cache. 
Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 131a66b713235..50521b5937862 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -50,7 +50,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -58,6 +57,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llama4 import Llama4ForCausalLM from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, merge_multimodal_embeddings) +from .vision import run_dp_sharded_vision_model class Llama4ImagePatchInputs(TensorSchema): diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index ff571541a60a5..987920ecc3310 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -23,21 +23,17 @@ from typing import Optional import torch from torch import nn -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -49,14 +45,11 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.model_executor.models.utils import ( AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronHConfig -from vllm.utils import LayerBlockType class NemotronHMLP(nn.Module): @@ -181,8 +174,6 @@ class NemotronHMambaDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -192,7 +183,7 @@ class NemotronHMambaDecoderLayer(nn.Module): hidden_states, residual = self.norm(hidden_states, residual) output = torch.empty_like(hidden_states) - self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mixer(hidden_states, output) return output, residual @@ -370,22 +361,10 @@ class NemotronHModel(nn.Module): self, input_ids: torch.Tensor, 
positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -398,22 +377,11 @@ class NemotronHModel(nn.Module): residual = intermediate_tensors["residual"] residual = None - num_non_mamba_layers = 0 for i, layer in enumerate(self.layers): - layer_mamba_cache_params = None - if isinstance(layer, - NemotronHMambaDecoderLayer) and mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - i - num_non_mamba_layers) - else: - num_non_mamba_layers += 1 - hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank: @@ -508,13 +476,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -533,7 +499,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, head_dim=hf_config.mamba_head_dim, state_size=hf_config.ssm_state_size, conv_kernel=hf_config.conv_kernel, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -566,8 +531,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, if not lora_config else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. 
- self.mamba_cache: Optional[MambaCacheManager] = None self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -584,40 +547,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - - num_mamba_layers = \ - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba - ) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py deleted file mode 100644 index ae153558e37aa..0000000000000 --- a/vllm/model_executor/models/phi4flash.py +++ /dev/null @@ -1,731 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math -from collections.abc import Iterable -from typing import Optional, Union - -import torch -import torch.nn as nn -from transformers.activations import ACT2FN - -import vllm.envs as envs -from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.attention.selector import _Backend -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.forward_context import ForwardContext, get_forward_context -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) -from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsV0Only) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.sequence import IntermediateTensors - -from .utils import make_layers, maybe_prefix - -logger = init_logger(__name__) - - -class SwiGLUActivation(nn.Module): - - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - return x1 * nn.functional.silu(x2) - - -class SambaYMLP(nn.Module): - """Gated Linear Unit. - - Reference: - Language Modeling with Gated Convolutional Networks. - https://arxiv.org/pdf/1612.08083v3.pdf. 
- - """ - - def __init__(self, config): - super().__init__() - - self.config = config - self.fc1 = nn.Linear(config.hidden_size, - 2 * config.intermediate_size, - bias=False) - self.fc2 = nn.Linear(config.intermediate_size, - config.hidden_size, - bias=False) - - self.activation_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - y = self.fc1(hidden_states) - gate, y = y.chunk(2, dim=-1) - y = y * self.activation_fn(gate) - return self.fc2(y) - - -def get_virtual_engine(): - forward_context: ForwardContext = get_forward_context() - return forward_context.virtual_engine - - -class SambaYAttention(nn.Module): - - def __init__(self, - config, - layer_idx: Optional[int] = None, - yoco_cross: bool = False, - cache_config: Optional[CacheConfig] = None, - prefix: str = ""): - super().__init__() - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing " - "a `layer_idx` is not recommended and will lead to errors " - "during the forward call if caching is used. Please make " - "sure to provide a `layer_idx` when creating this class.") - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.yoco_cross = yoco_cross - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError("hidden_size must be divisible by num_heads " - f"(got `hidden_size`: {self.hidden_size} and " - f"`num_heads`: {self.num_heads}).") - - op_size = self.num_heads * self.head_dim + 2 * ( - self.num_key_value_heads * self.head_dim) - self.out_proj = nn.Linear(self.num_heads * self.head_dim, - self.hidden_size, - bias=True) - if yoco_cross: - self.Wqkv = nn.Linear(self.hidden_size, - self.num_heads * self.head_dim, - bias=True) - else: - self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True) - - # disable sliding window for the second half of the model - is_sliding = config.layer_types[layer_idx] == "sliding_attention" - sliding_window = config.sliding_window if is_sliding else None - - assert self.num_heads % 2 == 0, 'num_heads should be even' - assert self.num_key_value_heads % 2 == 0, 'num_heads should be even' - - self.lambda_init = self.lambda_init_fn(layer_idx) - self.lambda_q1 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.lambda_k1 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.lambda_q2 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.lambda_k2 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.subln = nn.RMSNorm(2 * self.head_dim, - eps=1e-5, - elementwise_affine=True) - - params = { - 'differential_flash_attention_config': { - 'lambda_init': self.lambda_init, - 'lambda_q1': self.lambda_q1, - 'lambda_k1': self.lambda_k1, - 'lambda_q2': self.lambda_q2, - 'lambda_k2': self.lambda_k2, - "subln": self.subln, - } - } - - if yoco_cross: - kv_shared_layer_index = config.num_hidden_layers // 2 + 1 - kv_sharing_target_layer_name = \ - f"model.layers.{kv_shared_layer_index}.self_attn.attn" - else: - kv_sharing_target_layer_name = None - - self.attn = Attention( - self.num_heads, - self.head_dim, - self.head_dim**-0.5, - num_kv_heads=self.num_key_value_heads, - cache_config=cache_config, - per_layer_sliding_window=sliding_window, - prefix=f"{prefix}.attn", - attn_type=AttentionType.DECODER, - 
kv_sharing_target_layer_name=kv_sharing_target_layer_name, - **params) - assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\ - "DIFFERENTIAL_FLASH_ATTN required" - - def lambda_init_fn(self, depth): - return 0.8 - 0.6 * math.exp(-0.3 * depth) - - def forward( - self, - hidden_states: torch.Tensor, - ): - - if not self.yoco_cross: # need to generate kv-cache - qkv = self.Wqkv(hidden_states) - q, k, v = qkv.split([ - self.hidden_size, self.num_key_value_heads * self.head_dim, - self.num_key_value_heads * self.head_dim - ], - dim=-1) - attn_output = self.attn(q, k, v) - else: # reuse the kv cache, full attention - q = self.Wqkv(hidden_states) - attn_output = self.attn(q, None, None) - attn_output = attn_output.view(-1, self.num_heads * self.head_dim) - return self.out_proj(attn_output) - - -class Phi4Mamba(nn.Module): - - def __init__( - self, - d_model, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - dt_init="random", # difference - dt_scale=1.0, # difference - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, - yoco_cross=False, - yoco_kv=False, - ): - factory_kwargs = {"params_dtype": dtype} # difference - super().__init__() - self.yoco_cross = yoco_cross - self.yoco_kv = yoco_kv - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / - 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path - self.layer_idx = layer_idx - self.swiGluActivation = SwiGLUActivation() - if self.yoco_cross: - self.in_proj = MergedColumnParallelLinear(self.d_model, - [self.d_inner], - bias=bias, - **factory_kwargs) - self.out_proj = RowParallelLinear(self.d_inner, - self.d_model, - bias=bias, - **factory_kwargs) - return - self.conv1d = ColumnParallelLinear( - input_size=d_conv, - output_size=self.d_inner, - bias=conv_bias, - params_dtype=dtype, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. - # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - self.in_proj = MergedColumnParallelLinear( - self.d_model, - [self.d_inner] * 2, - bias=bias, - params_dtype=dtype, - ) - - # selective projection used to make dt, B and C input dependent - self.x_proj = RowParallelLinear( - self.d_inner, - self.dt_rank + self.d_state * 2, - bias=False, - params_dtype=dtype, - ) - - # time step projection (discretization) - - # In the forward we need to apply dt_proj without the bias, - # as the bias is added in the selective scan kernel. 
- self.dt_proj = ColumnParallelLinear( - self.dt_rank, - self.d_inner, - bias=True, - skip_bias_add=True, - params_dtype=dtype, - ) - - # # D "skip" parameter - # self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 - self.A = nn.Parameter( - torch.empty( - self.d_inner, - self.d_state, - dtype=torch.float32, - )) - self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32)) - - self.out_proj = RowParallelLinear( - self.d_inner, - self.d_model, - bias=bias, - input_is_parallel=True, - params_dtype=dtype, - ) - self.activation = "silu" - - def forward(self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams, - yoco_key_values=None) -> torch.Tensor: - - if self.yoco_cross: - out = self.in_proj(hidden_states)[0] - out = self.swiGluActivation(yoco_key_values, out) - out = self.out_proj(out) - return out[0], yoco_key_values - - # 1. Gated MLP's linear projection - # projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) - projected_states = self.in_proj( - hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1) - hidden_states, gate = projected_states.chunk(2, dim=-2) - - # 2. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - hidden_states = causal_conv1d_fn( - hidden_states, - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, - cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc) - else: - hidden_states = causal_conv1d_update( - hidden_states.transpose(0, 1), - mamba_cache_params.conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor) - hidden_states = hidden_states.transpose(0, 1) - - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] - - time_step, B, C = torch.split( - ssm_parameters, - [self.dt_rank, self.d_state, self.d_state], - dim=-1, - ) - - # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. 
- - discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - scan_outputs = selective_scan_fn( - hidden_states, - mamba_cache_params.ssm_state, - discrete_time_step, - self.A, - B.transpose(-2, -1), - C.transpose(-2, -1), - self.D.float(), - # z, - None if self.yoco_kv else gate, - time_proj_bias, - delta_softplus=True, - cache_indices=mamba_cache_params.state_indices_tensor, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) - else: - scan_outputs = torch.empty_like(hidden_states.transpose(0, 1)) - selective_state_update( - mamba_cache_params.ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, - B, - C, - self.D, - # z - # gate.transpose(0, 1), - None if self.yoco_kv else gate.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor, - out=scan_outputs) - scan_outputs = scan_outputs.transpose(0, 1) - - # 4. Final linear projection - if self.yoco_kv: - # gate = gate.transpose(-1,-2).contiguous() - yoco_key_values = scan_outputs.transpose(-2, -1) - scan_outputs = self.swiGluActivation(scan_outputs, gate) - - contextualized_states = self.out_proj(scan_outputs.transpose(-2, - -1))[0] - - return contextualized_states, yoco_key_values - - -class SambaYDecoderLayer(nn.Module): - - def __init__( - self, - config, - layer_idx, - cache_config, - prefix: str = "", - ) -> None: - super().__init__() - - self.config = config - self.layer_idx = layer_idx - - self.mlp = SambaYMLP(config) - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - - self.yoco_mb = False - self.yoco_cross = False - if layer_idx >= config.num_hidden_layers // 2: - self.yoco_mb = True - self.yoco_cross = (layer_idx - >= (config.num_hidden_layers // 2 + 2)) - self.use_mamba = config.mb_per_layer > 0 and \ - layer_idx % config.mb_per_layer == 0 - if self.use_mamba: - factory_kwargs = {"dtype": None} - self.attn = Phi4Mamba(config.hidden_size, - layer_idx=layer_idx, - yoco_cross=self.yoco_cross, - yoco_kv=self.yoco_mb, - **factory_kwargs) - else: - self.attn = SambaYAttention(config, - layer_idx=layer_idx, - yoco_cross=self.yoco_cross, - cache_config=cache_config, - prefix=f"{prefix}.self_attn") - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - positions: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams, - ssm_output: Optional[torch.LongTensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if self.use_mamba: - assert mamba_cache_params is not None - else: - assert mamba_cache_params is None - - residual = hidden_states - hidden_states = self.input_layernorm( - hidden_states.to(dtype=self.input_layernorm.weight.dtype)) - - if self.use_mamba: - attn_outputs, ssm_output = self.attn(hidden_states, - attn_metadata, - mamba_cache_params, - yoco_key_values=ssm_output) - residual = residual.to(torch.float32) - else: - attn_outputs = self.attn(hidden_states, ) - hidden_states = residual + attn_outputs - residual = hidden_states - hidden_states = self.post_attention_layernorm( - hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype)) - 
hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states, ssm_output - - -class SambaYModel(nn.Module): - - def __init__(self, - config, - cache_config=None, - quant_config=None, - lora_config=None, - prefix: str = "") -> None: - super().__init__() - self.config = config - self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - - # Pipeline parallel is not supported since the second half of - # the layers share the kv cache. - if get_pp_group().world_size != 1: - raise ValueError("Pipeline Parallel not supported") - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: SambaYDecoderLayer(config, - int(prefix.split('.')[-1]), - cache_config, - prefix=prefix), - prefix=f"{prefix}.layers") - self.final_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - - mamba_state_idx = 0 - ssm_output = None - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - if i == self.config.num_hidden_layers // 2 + 2: - # profile run - kv_cache_idx = self.config.num_hidden_layers // 2 + 1 - cache_layer = self.layers[kv_cache_idx] - kv_cache = cache_layer.attn.attn.kv_cache - if kv_cache[0].numel() == 0: - break - - # Starting from this layer, we do not need to calculate - # the kv cache since we reuse the kv cache from last layer. - # If in prefill phase, we can prune> truncate - # the hidden state to save computation cost. 
- if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1: - selected_token_indices = torch.cumsum( - attn_metadata.seq_lens_tensor, dim=0) - 1 - hidden_states = hidden_states.index_select( - 0, selected_token_indices) - ssm_output = ssm_output.index_select( - 0, selected_token_indices) - - if layer.use_mamba: - if i < self.config.num_hidden_layers // 2 or \ - not layer.yoco_cross: - mamba_cache = mamba_cache_params.at_layer_idx( - mamba_state_idx) - mamba_state_idx += 1 - else: - mamba_cache = mamba_cache_params.at_layer_idx( - mamba_state_idx - 1) - - hidden_states, ssm_output = layer(hidden_states, - positions, - attn_metadata, - mamba_cache, - ssm_output=ssm_output) - else: - hidden_states, ssm_output = layer( - hidden_states, - positions, - attn_metadata, - None, # mamba_cache_params - ssm_output=ssm_output) - - hidden_states = self.final_layernorm( - hidden_states.to(dtype=self.final_layernorm.weight.dtype)) - return hidden_states - - -class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - lora_config = vllm_config.lora_config - quant_config = vllm_config.quant_config - scheduler_config = vllm_config.scheduler_config - self.compilation_config = vllm_config.compilation_config - self.vllm_config = vllm_config - # Prefix caching and chunked prefill is not supported for this model. - assert not cache_config.enable_prefix_caching, \ - "Phi4flash currently does not support prefix caching" - assert not scheduler_config.chunked_prefill_enabled, \ - "Phi4Flash currently does not support prefix caching" - super().__init__() - self.config = config - self.model_config = vllm_config.model_config - self.scheduler_config = scheduler_config - self.model = SambaYModel(config, - cache_config=cache_config, - prefix=maybe_prefix(prefix, "model")) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - self.embedding_bias = None - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logits_as_input=False) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - if self.mamba_cache is None: - num_mamba_layers = self.config.num_hidden_layers \ - // 2 // self.config.mb_per_layer + 1 - self.mamba_cache = MambaCacheManager( - self.vllm_config, - num_mamba_layers, - *self._get_mamba_cache_shape(), - self.lm_head.weight.dtype, - self.lm_head.weight.dtype, - ) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - attn_metadata = get_forward_context().attn_metadata - # input_ids and hidden_states isn't a one-to-one mapping in prefill - # stage due to YOCO optimization. 
- hidden_states = self.model(input_ids, positions, attn_metadata, - mamba_cache_params, intermediate_tensors, - inputs_embeds) - return hidden_states - - def _get_mamba_cache_shape( - self - ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - mamba_expand = self.config.mamba_expand # 2 - mamba_d_conv = self.config.mamba_d_conv # 4 - mamba_d_state = self.config.mamba_d_state # 16 - conv_state_shape = ( - mamba_expand * hidden_size // world_size, - mamba_d_conv - 1, - ) - temporal_state_shape = ( - mamba_expand * hidden_size // world_size, - mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: - processed_logits = self.logits_processor( - self.lm_head, - hidden_states, - self.embedding_bias, - ) - return processed_logits - - def load_weights( - self, - weights: Iterable[tuple[str, torch.Tensor]], - ): - weights = {name: weight for name, weight in weights} - adjusted_weights = {} - for name, weight in weights.items(): - if "A_log" in name: - name = name.replace("A_log", "A") - weight = -torch.exp(weight.float()) - if "inner_cross_attn." in name: - name = name.replace("inner_cross_attn.", "") - adjusted_weights[name] = weight - adjusted_weights["lm_head.weight"] = weights[ - "model.embed_tokens.weight"] - loaded_params: set[str] = set() - for name, param in self.named_parameters(): - weight = adjusted_weights.get(name) - if weight is not None and weight.shape != param.shape: - logger.warning("Shape mismatch: %s %s %s", name, weight.shape, - param.shape) - loaded_params.add(name) - missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, - strict=False) - assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}" - assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" - return loaded_params diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 33ee1cf44afd1..0292f3bf8317d 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -12,7 +12,6 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile @@ -29,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata, update_metadata) from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( @@ -47,15 +44,13 @@ from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, SupportsPP) -from 
vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) from vllm.model_executor.models.utils import ( is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType, direct_register_custom_op +from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata @@ -194,17 +189,13 @@ class Plamo2MambaMixer(MambaBase, CustomOp): self.chunk_size = self.config.mamba_chunk_size - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. - # The inner tuple is (conv_state, ssm_state) - self.kv_cache = [(torch.tensor([]), torch.tensor([]))] - assert self.chunk_size != -1, "chunk_size must be set for v1" + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The tuple is (conv_state, ssm_state) + self.kv_cache = (torch.tensor([]), torch.tensor([])) + assert self.chunk_size != -1, "chunk_size must be set for v1" self.prefix = prefix @@ -227,8 +218,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): pass @@ -237,59 +226,43 @@ class Plamo2MambaMixer(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): - if not envs.VLLM_USE_V1: - CustomOp.forward(self, hidden_states, output, mamba_cache_params, - mamba2_metadata) - else: - torch.ops.vllm.plamo2_mamba_mixer( - hidden_states, - output, - self.prefix, - ) + torch.ops.vllm.plamo2_mamba_mixer( + hidden_states, + output, + self.prefix, + ) def forward_cuda( self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): forward_context = get_forward_context() - # mamba2_metadata contains metadata necessary for the mamba2 triton + # attn_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = forward_context.attn_metadata - if envs.VLLM_USE_V1: - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - mamba2_metadata = attn_metadata - assert isinstance(attn_metadata, Mamba2AttentionMetadata) - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - # conv_state = (..., dim, width-1) yet contiguous along 'dim' - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - state_indices_tensor = attn_metadata.state_indices_tensor - else: - conv_state = 
mamba_cache_params.conv_state - ssm_state = mamba_cache_params.ssm_state - state_indices_tensor = mamba_cache_params.state_indices_tensor - # Common members between V1 metadata and V0 metadata - if mamba2_metadata is not None: - has_initial_states_p = mamba2_metadata.has_initial_states_p - prep_initial_states = mamba2_metadata.prep_initial_states - chunk_size = mamba2_metadata.chunk_size - seq_idx_p = mamba2_metadata.seq_idx_p - chunk_indices_p = mamba2_metadata.chunk_indices_p - chunk_offsets_p = mamba2_metadata.chunk_offsets_p + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # conv_state = (..., dim, width-1) yet contiguous along 'dim' + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states_p + prep_initial_states = attn_metadata.prep_initial_states + chunk_size = attn_metadata.chunk_size + seq_idx_p = attn_metadata.seq_idx_p + chunk_indices_p = attn_metadata.chunk_indices_p + chunk_offsets_p = attn_metadata.chunk_offsets_p # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) @@ -299,8 +272,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp): conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if envs.VLLM_USE_V1 and attn_metadata is None: - # V1 profile run + if attn_metadata is None: + # profile run hidden_states = (hidden_states.transpose(0, 1).clone().transpose( 0, 1)).contiguous() output[:] = self.out_proj(hidden_states) @@ -316,42 +289,23 @@ class Plamo2MambaMixer(MambaBase, CustomOp): # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension - if envs.VLLM_USE_V1: - hidden_states_d, hidden_states_p = torch.split( - hidden_states[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - dim=0, - ) - gate_d, gate_p = torch.split(gate[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - dim=0) - # Split along batch dimension - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor, - [num_decodes, num_prefills], - dim=0, - ) - query_start_loc_p = ( - attn_metadata.query_start_loc[-num_prefills - 1:] - - num_decodes if has_prefill else None) - else: - hidden_states_p, hidden_states_d = torch.split( - hidden_states, - [num_prefill_tokens, num_decodes], - dim=0, - ) - gate_p, gate_d = torch.split(gate, - [num_prefill_tokens, num_decodes], - dim=0) - # Split along batch dimension - state_indices_tensor_p, state_indices_tensor_d = torch.split( - state_indices_tensor, - [num_prefills, num_decodes], - dim=0, - ) - query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + - 1] - if has_prefill else None) + hidden_states_d, hidden_states_p = torch.split( + hidden_states[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + gate_d, gate_p = torch.split(gate[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) # Preallocate output tensor to 
avoid memcpy cost for merging prefill # and decode outputs @@ -363,18 +317,11 @@ class Plamo2MambaMixer(MambaBase, CustomOp): dtype=hidden_states.dtype, device=hidden_states.device, ) - if envs.VLLM_USE_V1: - preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( - preallocated_ssm_out, - [num_decodes, num_prefill_tokens], - dim=0, - ) - else: - preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( - preallocated_ssm_out, - [num_prefill_tokens, num_decodes], - dim=0, - ) + preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( + preallocated_ssm_out, + [num_decodes, num_prefill_tokens], + dim=0, + ) # Process prefill requests if has_prefill: @@ -383,9 +330,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp): # pointed to by "state_indices_tensor" x = hidden_states_p.transpose( 0, 1) # this is the form that causal-conv see - if mamba2_metadata.cu_seqlen is None: - mamba2_metadata = update_metadata(x, query_start_loc_p, - mamba2_metadata) hidden_states_p = causal_conv1d_fn( x, conv_weights, @@ -394,7 +338,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp): conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, - metadata=mamba2_metadata, + metadata=attn_metadata, query_start_loc=query_start_loc_p) hidden_states_p = hidden_states_p.transpose(0, 1) hidden_states_p = hidden_states_p[:num_prefill_tokens] @@ -470,7 +414,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp): -1, self.num_heads // self.tp_size, self.head_dim) # - the hidden is reshaped into (bs, num_heads, head_dim) - # - mamba_cache_params.ssm_state's slots will be selected + # - ssm_state's slots will be selected # using state_indices_tensor_d # NOTE: final output is an in-place update of out tensor @@ -530,10 +474,7 @@ def plamo2_mamba_mixer( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, - output=output, - mamba_cache_params=None, - mamba2_metadata=None) + self.forward_cuda(hidden_states=hidden_states, output=output) def plamo2_mamba_mixer_fake( @@ -731,8 +672,6 @@ class Plamo2DecoderLayer(nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -747,8 +686,6 @@ class Plamo2DecoderLayer(nn.Module): output = torch.empty_like(hidden_states) mixer_kwargs = { "output": output, - "mamba_cache_params": mamba_cache_params, - "mamba2_metadata": mamba2_metadata, } else: mixer_kwargs = { @@ -790,23 +727,12 @@ class Plamo2Decoder(torch.nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: - mamba_cache_index = 0 for layer in islice(self.layers, self.start_layer, self.end_layer): - layer_mamba_cache_params = None - if layer.is_mamba and mamba_cache_params is not None: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - mamba_cache_index) - mamba_cache_index += 1 - hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) return hidden_states, residual @@ -844,7 +770,6 @@ class Plamo2Model(torch.nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: 
Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -859,23 +784,10 @@ class Plamo2Model(torch.nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - if not envs.VLLM_USE_V1: - attn_metadata: AttentionMetadata = get_forward_context( - ).attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - hidden_states, residual = self.layers( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -925,9 +837,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): if self.config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.config.vocab_size) self.make_empty_intermediate_tensors = ( @@ -942,39 +851,11 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = ( - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba)) - mamba_state_shape = self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - else: - # NOTE: mamba_cache_params is not needed for v1 - mamba_cache_params = None - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - @classmethod def get_mamba_state_dtype_from_config( cls, @@ -991,12 +872,10 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. 
Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: - conv_state_shape: Shape for convolutional state cache @@ -1015,7 +894,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): head_dim=hf_config.hidden_size_per_head, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def compute_logits( diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 73b27572a8ebd..b740e6d87b745 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -59,7 +59,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig -from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -74,7 +73,7 @@ from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import get_vit_attn_backend +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index b3c42c2572566..472e8b061a9e1 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -77,7 +77,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE, from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import get_vit_attn_backend +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) @@ -217,17 +217,20 @@ class Qwen2VisionMLP(nn.Module): act_layer: type[nn.Module] = QuickGELU, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.fc1 = ColumnParallelLinear(in_features, hidden_features, quant_config=quant_config, - prefix=f"{prefix}.fc1") + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel) self.act = act_layer() self.fc2 = RowParallelLinear(hidden_features, in_features, quant_config=quant_config, - prefix=f"{prefix}.fc2") + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel) def forward(self, x: torch.Tensor) -> torch.Tensor: x_parallel, _ = self.fc1(x) @@ -293,25 +296,28 @@ class Qwen2VisionAttention(nn.Module): projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() # Per attention head and per partition values. 
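# Illustrative sketch (not part of the patch): what the `use_data_parallel` /
# `disable_tp` flag introduced in these Qwen2 vision hunks changes about sharding.
# With TP the attention heads are split across ranks; with the DP encoder path
# tp_size is forced to 1, every rank keeps all heads, and whole images are
# distributed across ranks instead. `heads_per_partition` is invented for the
# sketch; the real code uses dist_utils.divide, which likewise requires an
# exact split.
def heads_per_partition(num_heads: int, world_size: int,
                        use_data_parallel: bool) -> int:
    tp_size = 1 if use_data_parallel else world_size
    assert num_heads % tp_size == 0, "divide() requires an exact split"
    return num_heads // tp_size


if __name__ == "__main__":
    print(heads_per_partition(16, 4, use_data_parallel=False))  # 4 heads per rank (TP)
    print(heads_per_partition(16, 4, use_data_parallel=True))   # 16 heads per rank (DP)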
- world_size = parallel_state.get_tensor_model_parallel_world_size() - self.tp_size = world_size + self.tp_size = (1 if use_data_parallel else + parallel_state.get_tensor_model_parallel_world_size()) self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size) + num_heads, self.tp_size) self.qkv = ColumnParallelLinear(input_size=embed_dim, output_size=3 * projection_size, quant_config=quant_config, - prefix=f"{prefix}.qkv") + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel) self.proj = RowParallelLinear(input_size=projection_size, output_size=embed_dim, quant_config=quant_config, - prefix=f"{prefix}.proj") + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel) # Detect attention implementation. self.attn_backend = get_vit_attn_backend( @@ -453,6 +459,7 @@ class Qwen2VisionBlock(nn.Module): norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() if norm_layer is None: @@ -465,12 +472,14 @@ class Qwen2VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel) self.mlp = Qwen2VisionMLP(dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) def forward( self, @@ -531,6 +540,7 @@ class Qwen2VisionPatchMerger(nn.Module): spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) @@ -542,13 +552,15 @@ class Qwen2VisionPatchMerger(nn.Module): self.hidden_size, bias=True, quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), + prefix=f"{prefix}.mlp.0", + disable_tp=use_data_parallel), nn.GELU(), RowParallelLinear(self.hidden_size, d_model, bias=True, quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), + prefix=f"{prefix}.mlp.2", + disable_tp=use_data_parallel), ]) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -600,6 +612,7 @@ class Qwen2VisionTransformer(nn.Module): norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -613,6 +626,9 @@ class Qwen2VisionTransformer(nn.Module): num_heads = vision_config.num_heads mlp_ratio = vision_config.mlp_ratio + self.use_data_parallel = use_data_parallel + self.out_hidden_size = vision_config.hidden_size + self.spatial_merge_size = spatial_merge_size self.num_heads = num_heads self.embed_dim = embed_dim @@ -634,7 +650,8 @@ class Qwen2VisionTransformer(nn.Module): mlp_ratio=mlp_ratio, norm_layer=norm_layer, quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel) for layer_idx in range(depth) ]) self.merger = Qwen2VisionPatchMerger( @@ -643,6 +660,7 @@ class Qwen2VisionTransformer(nn.Module): norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype()) @@ -659,8 +677,9 @@ class 
Qwen2VisionTransformer(nn.Module): def device(self) -> torch.device: return self.patch_embed.proj.weight.device - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: pos_ids = [] + max_grid_size = 0 for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) @@ -678,8 +697,8 @@ class Qwen2VisionTransformer(nn.Module): ).permute(0, 2, 1, 3).flatten() pos_ids.append( torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + max_grid_size = max(max_grid_size, h, w) pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb @@ -698,7 +717,7 @@ class Qwen2VisionTransformer(nn.Module): def forward( self, x: torch.Tensor, - grid_thw: torch.Tensor, + grid_thw: list[list[int]], ) -> torch.Tensor: # patchify x = x.to(device=self.device, dtype=self.dtype) @@ -708,8 +727,9 @@ class Qwen2VisionTransformer(nn.Module): rotary_pos_emb = self.rot_pos_emb(grid_thw) # compute cu_seqlens - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, 0]).cumsum( + grid_thw_ = torch.tensor(grid_thw) + cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2], + grid_thw_[:, 0]).cumsum( dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) @@ -1112,6 +1132,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, "model.": "language_model.model.", }) + supports_encoder_tp_data = True + def get_mrope_input_positions( self, input_tokens: list[int], @@ -1239,6 +1261,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config self.multimodal_config = multimodal_config @@ -1249,6 +1272,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=self._maybe_ignore_quant_config(quant_config), prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, ) else: self.visual = None @@ -1357,7 +1381,15 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_embeds = image_input["image_embeds"] else: pixel_values = image_input["pixel_values"] - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model(self.visual, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + else: + image_embeds = self.visual(pixel_values, + grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size @@ -1377,7 +1409,14 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, video_embeds = video_input["video_embeds"] else: pixel_values_videos = video_input["pixel_values_videos"] - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model(self.visual, + pixel_values_videos, + grid_thw_list, + rope_type="rope_3d") + else: + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. 
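# Illustrative sketch (not part of the patch): the cu_seqlens computation from the
# Qwen2VisionTransformer.forward hunk above, now that grid_thw arrives as a plain
# list[list[int]]. Each (t, h, w) entry contributes t sequences of h*w patch tokens;
# cu_seqlens is their prefix sum with a leading zero, which is the layout the varlen
# attention kernels expect. The example values are made up; requires torch.
import torch
import torch.nn.functional as F

grid_thw = [[1, 4, 4], [2, 2, 2]]  # one 4x4 image, one 2-frame 2x2 video
grid_thw_ = torch.tensor(grid_thw)
cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
                                     grid_thw_[:, 0]).cumsum(dim=0,
                                                             dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
print(cu_seqlens.tolist())  # [0, 16, 20, 24]

# max_grid_size is now tracked while iterating the list instead of grid_thw[:, 1:].max()
max_grid_size = max(max(h, w) for _, h, w in grid_thw)
print(max_grid_size)  # 4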
merge_size = self.visual.spatial_merge_size diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 24cebc5bfdd82..ab23b494e561e 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -11,7 +11,6 @@ from einops import rearrange from torch import nn from transformers.activations import ACT2FN -from vllm import envs from vllm.attention import Attention, AttentionBackend, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, @@ -35,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata from vllm.model_executor.layers.mamba.mamba_mixer2 import ( mamba_v2_sharded_weight_loader) from vllm.model_executor.layers.mamba.mamba_utils import ( @@ -51,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -198,14 +195,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.gated_delta_net_state_shape( - self.tp_size, - self.num_k_heads, - self.num_v_heads, - self.head_k_dim, - self.head_v_dim, - self.conv_kernel_size, - self.num_spec, - use_v1=True) + self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim, + self.head_v_dim, self.conv_kernel_size, self.num_spec) def __init__( self, @@ -394,7 +385,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): self, hidden_states: torch.Tensor, output: torch.Tensor, - cache_params: Optional[MambaCacheParams] = None, ): return torch.ops.vllm.gdn_attention( hidden_states, @@ -416,7 +406,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] - conv_metadata = attn_metadata assert isinstance(attn_metadata, GDNAttentionMetadata) has_initial_state = attn_metadata.has_initial_state spec_query_start_loc = attn_metadata.spec_query_start_loc @@ -479,12 +468,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): # 2.2: process the remaining part if attn_metadata.num_prefills > 0: mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) - if conv_metadata.cu_seqlen is None: - conv_metadata = update_metadata(mixed_qkv_non_spec_T, - non_spec_query_start_loc, - conv_metadata) # - "cache_indices" updates the conv_state cache in positions - # pointed to by "mamba_cache_params.state_indices_tensor" + # pointed to by "state_indices_tensor" mixed_qkv_non_spec = causal_conv1d_fn( mixed_qkv_non_spec_T, conv_weights, @@ -494,7 +479,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): has_initial_state=has_initial_state, cache_indices=non_spec_state_indices_tensor, query_start_loc=non_spec_query_start_loc, - metadata=conv_metadata, + metadata=attn_metadata, ).transpose(0, 1) elif attn_metadata.num_decodes 
> 0: mixed_qkv_non_spec = causal_conv1d_update( @@ -1075,7 +1060,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, scheduler_config = vllm_config.scheduler_config assert not cache_config.enable_prefix_caching, \ "Qwen3Next currently does not support prefix caching" - assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1" self.quant_config = vllm_config.quant_config super().__init__() @@ -1195,14 +1179,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, num_spec = (vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0) return MambaStateShapeCalculator.gated_delta_net_state_shape( - tp_size, - hf_config.linear_num_key_heads, - hf_config.linear_num_value_heads, - hf_config.linear_key_head_dim, - hf_config.linear_value_head_dim, - hf_config.linear_conv_kernel_dim, - num_spec, - use_v1=True) + tp_size, hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim, + num_spec) def compute_logits( self, diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 98d65dea27393..ee6703f7229e5 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -83,7 +83,7 @@ from .qwen2_vl import Qwen2VLProcessingInfo from .qwen3 import Qwen3ForCausalLM, Qwen3Model from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, maybe_prefix, merge_multimodal_embeddings) -from .vision import get_vit_attn_backend +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) @@ -1214,8 +1214,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) if self.use_data_parallel: - from vllm.multimodal.utils import ( - run_dp_sharded_mrope_vision_model) return run_dp_sharded_mrope_vision_model(self.visual, pixel_values, grid_thw_list, @@ -1245,8 +1243,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) if self.use_data_parallel: - from vllm.multimodal.utils import ( - run_dp_sharded_mrope_vision_model) return run_dp_sharded_mrope_vision_model(self.visual, pixel_values_videos, grid_thw_list, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 5dc5d545bb9c5..6ab3fa902c387 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -134,7 +134,6 @@ _TEXT_GENERATION_MODELS = { "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), @@ -219,6 +218,7 @@ _MULTIMODAL_MODELS = { "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), + "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"), "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", 
"Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501 "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index f667266b77bfa..5f6ad58850439 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -31,7 +31,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -40,6 +39,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) +from .vision import run_dp_sharded_vision_model class Step3VLImagePixelInputs(TypedDict): diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 81f86db7e1875..08ad8fbeb4246 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +import math from abc import ABC, abstractmethod -from typing import Final, Generic, Optional, Protocol, TypeVar, Union +from typing import Final, Generic, Literal, Optional, Protocol, TypeVar, Union import torch from transformers import PretrainedConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.logger import init_logger from vllm.platforms import _Backend, current_platform @@ -123,3 +128,277 @@ def resolve_visual_encoder_outputs( if post_layer_norm is not None and uses_last_layer: hs_pool[-1] = post_layer_norm(encoder_outputs) return torch.cat(hs_pool, dim=-1) + + +def run_dp_sharded_vision_model(image_input: torch.Tensor, + vision_model: torch.nn.Module) -> torch.Tensor: + """Run a vision model with data parallelism (DP) sharding. The function + will shard the input image tensor on the first dimension and run the vision + model + + Args: + image_input (torch.Tensor): Image input tensor. + vision_model (torch.nn.Module): Vision model. + Returns: + torch.Tensor: Output image embeddings + """ + + num_chunks = image_input.shape[0] + mp_world_size = get_tensor_model_parallel_world_size() + num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size + num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks + pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) + image_input_padded = torch.nn.functional.pad(image_input, pad) + rank = get_tensor_model_parallel_rank() + image_input_per_rank = image_input_padded[rank * + num_chunks_per_rank:(rank + 1) * + num_chunks_per_rank, ...] + + vision_embeddings = vision_model(image_input_per_rank) + # Ensure tensor is contiguous before all_gather + vision_embeddings = vision_embeddings.contiguous() + vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, + dim=0) + vision_embeddings = vision_embeddings[:num_chunks, ...] 
+ return vision_embeddings + + +def get_load_balance_assignment( + sizes: list[int], + num_gpus: int = 2, +) -> tuple[list[int], list[int], list[int]]: + """ + Generate load balancing assignment and metadata + for distributing data across GPUs. + The load is determined by the total image sizes, + not the number of images. + + Args: + sizes: The size of each image + num_gpus: Number of GPUs to balance across + + Returns: + shuffle_indices: + Indices to reorder data for balanced loading + gpu_sample_counts: + Number of samples assigned to each GPU + grouped_sizes_per_gpu: + Total size assigned to each GPU + + Example: + ``` + sizes = [1000, 100, 200, 50] + num_gpus=2 + ``` + + """ + + n_samples = len(sizes) + + # Handle edge cases + if n_samples == 0: + return [], [0] * num_gpus, [0] * num_gpus + + # Use greedy algorithm - balance by total size, not sample count + gpu_assignments = [list[int]() for _ in range(num_gpus)] + gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count + + # Sort indices by size (largest first for better load balancing) + # sizes = [1000, 100, 200, 50] + # large_to_small_indices = [0, 2, 1, 3] + large_to_small_indices = sorted(range(n_samples), + key=lambda i: sizes[i], + reverse=True) + + for idx in large_to_small_indices: + # Find GPU with minimum current load (by total size) + min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) + gpu_assignments[min_gpu].append(idx) + gpu_loads[min_gpu] += sizes[idx] + + # Create shuffle indices and counts + shuffle_indices = list[int]() + gpu_sample_counts = list[int]() + for gpu_id in range(num_gpus): + # GPU_0 = [1000] = [0] + # GPU_1 = [200, 100, 50] = [2, 1, 3] + # shuffle_indices = [0, 2, 1, 3] + shuffle_indices.extend(gpu_assignments[gpu_id]) + # GPU_0 = [1] + # GPU_1 = [3] + # gpu_sample_counts = [1, 3] + gpu_sample_counts.append(len(gpu_assignments[gpu_id])) + + return (shuffle_indices, gpu_sample_counts, gpu_loads) + + +def run_dp_sharded_mrope_vision_model( + vision_model: torch.nn.Module, + pixel_values: torch.Tensor, + grid_thw_list: list[list[int]], + *, + rope_type: Literal["rope_3d", "rope_2d"], +) -> tuple[torch.Tensor, ...]: + """Run a vision model with data parallelism (DP) sharding. + The function will shard the input image tensor on the + first dimension and run the vision model. + This function is used to run the vision model with mrope. + + Args: + vision_model (torch.nn.Module): Vision model. + pixel_values (torch.Tensor): Image/Video input tensor. + grid_thw_list: List of grid dimensions for each image + rope_type: Type of rope used in the vision model. + Different rope types have different dimension to do ViT. 
+ "rope_3d" for 3D rope (e.g., Qwen2.5-VL) + "rope_2d" for 2D rope (e.g., Kimi-VL) + Returns: + torch.Tensor: Output image embeddings + + Example: + ``` + vision_model.out_hidden_size = 64 + vision_model.spatial_merge_size = 2 + pixel_values.shape = (1350, channel) + grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] + tp_size=2 + ``` + + """ + tp_size = get_tensor_model_parallel_world_size() + + # GPU_0 tp_rank_local = 0 + # GPU_1 tp_rank_local = 1 + tp_rank_local = get_tensor_model_parallel_rank() + + # patches_per_image = [1000, 100, 200, 50] + patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list] + # patches_per_image = [0, 1000, 1100, 1300, 1350] + cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)] + + # Get load balancing assignment with all metadata + # image_to_tp_rank = [0, 2, 1, 3] + # gpu_sample_counts = [1, 3] + # grouped_pixel_values_len = [1000, 350] + (image_to_tp_rank, gpu_sample_counts, + grouped_pixel_values_len) = get_load_balance_assignment( + patches_per_image, tp_size) + + # cu_gpu_sample_counts = [0, 1, 4] + cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] + + # GPU_0 image_idxs_local = [0] + # GPU_1 image_idxs_local = [2, 1, 3] + image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]: + cum_gpu_sample_counts[tp_rank_local + + 1]] + + # Get the pixel values for the local images based on the image_idxs_local + if len(image_idxs_local) > 0: + pixel_values_local = torch.cat([ + pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]] + for i in image_idxs_local + ]) + else: + # Handle case where this rank has no images + pixel_values_local = torch.empty((0, pixel_values.shape[1]), + device=pixel_values.device, + dtype=pixel_values.dtype) + # embed_dim_reduction_factor = 2 * 2 + if rope_type == "rope_2d": + embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] * + vision_model.merge_kernel_size[1]) + else: + embed_dim_reduction_factor = (vision_model.spatial_merge_size * + vision_model.spatial_merge_size) + + # Find the max length across all ranks + # The output embedding of every DP rank has to be + # padded to this length for tensor_model_parallel_all_gather + # to work + max_len_per_rank = max( + grouped_pixel_values_len) // embed_dim_reduction_factor + local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] + + # Run the vision model on the local pixel_values_local + if rope_type == "rope_2d": + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model( + pixel_values_local, torch.tensor(local_grid_thw_list)) + if isinstance(image_embeds_local, list): + image_embeds_local = torch.cat(image_embeds_local, dim=0) + else: + out_dim = getattr(vision_model.config, "hidden_size", None) + image_embeds_local = torch.empty( + (0, embed_dim_reduction_factor, out_dim), + device=pixel_values.device, + dtype=pixel_values.dtype) + else: + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model(pixel_values_local, + local_grid_thw_list) + else: + # Handle empty case + image_embeds_local = torch.empty((0, vision_model.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + # Pad the output based on max_len_per_rank + # for tensor_model_parallel_all_gather to work + current_len = image_embeds_local.shape[0] + if current_len < max_len_per_rank: + padding_size = max_len_per_rank - current_len + if rope_type == "rope_2d": + padding = torch.empty((padding_size, image_embeds_local.shape[1], + 
image_embeds_local.shape[2]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device) + else: + padding = torch.empty((padding_size, image_embeds_local.shape[1]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device) + image_embeds_local_padded = torch.cat([image_embeds_local, padding], + dim=0) + else: + image_embeds_local_padded = image_embeds_local + + # Do all_gather to collect embeddings from all ranks + gathered_embeds = tensor_model_parallel_all_gather( + image_embeds_local_padded, dim=0) + + # Remove padding and reconstruct per-rank embeddings + rank_embeddings = list[torch.Tensor]() + for rank in range(tp_size): + start_idx = rank * max_len_per_rank + end_idx = start_idx + (grouped_pixel_values_len[rank] // + embed_dim_reduction_factor) + rank_embeddings.append(gathered_embeds[start_idx:end_idx]) + + patches_per_output_image = [(patch_size // embed_dim_reduction_factor) + for patch_size in patches_per_image] + + # Reconstruct embeddings in the original order + original_order_embeddings = [None] * len(grid_thw_list) + current_idx = 0 + for rank in range(tp_size): + count = gpu_sample_counts[rank] + if count > 0: + # Get images assigned to this rank in shuffled order + # GPU_0 = image_idxs_local [0] + # GPU_1 = image_idxs_local [2, 1, 3] + rank_images = image_to_tp_rank[current_idx:current_idx + count] + + rank_embed = rank_embeddings[rank] + # Split rank embeddings back to individual images + embed_start = 0 + for img_idx in rank_images: + img_patches = patches_per_output_image[img_idx] + original_order_embeddings[img_idx] = rank_embed[ + embed_start:embed_start + img_patches] + embed_start += img_patches + current_idx += count + out_embeddings = tuple(embed for embed in original_order_embeddings + if embed is not None) + assert len(out_embeddings) == len( + original_order_embeddings), "Found unassigned embeddings" + return out_embeddings diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 4350e38e02f96..a0d93045b74cf 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -15,12 +15,10 @@ import torch from torch import nn from transformers import Zamba2Config -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -29,8 +27,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) @@ -39,8 +35,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import 
(MambaCacheManager, - MambaCacheParams) from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid @@ -515,8 +509,6 @@ class Zamba2MambaDecoderLayer(nn.Module): def forward( self, hidden_states: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, transformer_hidden_states: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None, original_hidden_states: Optional[torch.Tensor] = None, @@ -525,8 +517,6 @@ class Zamba2MambaDecoderLayer(nn.Module): Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] - mamba_cache_params: Parameters for Mamba's state caches - (one for conv, one for ssm) transformer_hidden_states: Optional output from transformer path Added to input if provided (used in hybrid architecture) positions: Optional position IDs (unused in Mamba) @@ -555,8 +545,6 @@ class Zamba2MambaDecoderLayer(nn.Module): self.mamba( hidden_states, output, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) # residual connection after mamba @@ -607,8 +595,6 @@ class Zamba2HybridLayer(nn.Module): hidden_states: torch.Tensor, original_hidden_states: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: """Forward pass through the hybrid layer. @@ -623,8 +609,6 @@ class Zamba2HybridLayer(nn.Module): original_hidden_states: Original input for transformer residual connection positions: Position IDs for positional embeddings - mamba_cache_params: Parameters for Mamba's state caches - (one for conv, one for ssm) Returns: Output tensor combining transformer and Mamba representations @@ -644,8 +628,6 @@ class Zamba2HybridLayer(nn.Module): layer_outputs = self.mamba_decoder( hidden_states, transformer_hidden_states=transformer_hidden_states, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) return layer_outputs @@ -752,7 +734,6 @@ class Zamba2Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: """Forward pass through the model. 
@@ -760,8 +741,6 @@ class Zamba2Model(nn.Module): Args: input_ids: Input token IDs positions: Position IDs for embeddings - mamba_cache_params: Parameters for Mamba's state caches - (one for conv, one for ssm) inputs_embeds: Optional pre-computed input embeddings Returns: @@ -773,33 +752,13 @@ class Zamba2Model(nn.Module): inputs_embeds = self.get_input_embeddings(input_ids) hidden_states = inputs_embeds - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - # Process through layers original_hidden_states = torch.clone(hidden_states) for layer_idx, layer in enumerate(self.layers): - - layer_mamba_cache_params = None - if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer)) - and mamba_cache_params): - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - layer_idx) - layer_outputs = layer( hidden_states, original_hidden_states=original_hidden_states, positions=positions, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) hidden_states = layer_outputs @@ -870,13 +829,11 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -896,7 +853,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): head_dim=hf_config.mamba_headdim, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: @@ -945,9 +901,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): # Tie weights with input embeddings if using same dimensions self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - # Initialize logits processing and sampling self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -977,61 +930,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): Returns: Output hidden states """ - # Initialize Mamba cache if needed - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = self.config.num_hidden_layers - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - # Get cache parameters for current run - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - # Forward pass through model hidden_states = self.model( input_ids, positions, - mamba_cache_params, inputs_embeds, ) return hidden_states - def copy_inputs_before_cuda_graphs( - self, input_buffers: dict[str, torch.Tensor], - **kwargs: Any) -> dict[str, torch.Tensor]: - """Copy inputs before CUDA graph capture. 
- - Args: - input_buffers: Dictionary of input tensors - **kwargs: Additional arguments passed to cache manager - - Returns: - Updated input buffers - """ - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs( - self, batch_size: int) -> dict[str, torch.Tensor]: - """Get inputs for sequence-length-agnostic graph capture. - - Args: - batch_size: Size of batch to capture - Returns: - Dictionary of capture inputs - """ - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 7ffa732cf3708..8ea79078465e6 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .base import MultiModalPlaceholderMap from .hasher import MultiModalHasher from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, MultiModalDataDict, MultiModalKwargs, @@ -27,7 +26,6 @@ __all__ = [ "MultiModalKwargs", "MultiModalKwargsItems", "MultiModalPlaceholderDict", - "MultiModalPlaceholderMap", "MultiModalUUIDDict", "NestedTensors", "MULTIMODAL_REGISTRY", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index e0edb3e883ed6..faffddd57199d 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -3,83 +3,11 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Generic, NamedTuple, TypeVar +from typing import Generic, TypeVar _T = TypeVar("_T") -class MultiModalPlaceholderMap: - """ - Relates multi-modal embeddings to their corresponding placeholders. - - Note: This is only used in V0. - """ - - class IndexMap(NamedTuple): - src: list[int] - dest: list[int] - - src_ranges: list[range] - """ - The indices of the multi-modal embeddings that will replace the - corresponding placeholder embeddings pointed to by ``dest_ranges``. - """ - - src_len: int - """ - The total number of flattened multi-modal embeddings. - """ - - dest_ranges: list[range] - """ - The indices of the placeholder embeddings that will be replaced by the - multimodal embeddings. - """ - - dest_len: int - """ - The total number of embeddings in the destination tensor. - """ - - def __init__(self): - self.src_ranges = [] - self.src_len = 0 - self.dest_ranges = [] - self.dest_len = 0 - - def extend(self, other: "MultiModalPlaceholderMap"): - """ - Adds the placeholders from another ``MultiModalPlaceholderMap`` to this - instance based on the source and destination tensors being - concatenated. - """ - - self.src_ranges.extend( - range(self.src_len + r.start, self.src_len + r.stop) - for r in other.src_ranges) - self.src_len += other.src_len - self.dest_ranges.extend( - range(self.dest_len + r.start, self.dest_len + r.stop) - for r in other.dest_ranges) - self.dest_len += other.dest_len - - def index_map(self) -> "IndexMap": - """ - Finalizes the placeholder map into lists of indices that can be used to - index the source and destination tensors. 
- """ - - src_indices = [i for r in self.src_ranges for i in r] - dest_indices = [i for r in self.dest_ranges for i in r] - - if len(src_indices) != len(dest_indices): - raise ValueError( - f"The number of source ({len(src_indices)}) and destination " - f"indices ({len(dest_indices)}) must be the same.") - - return self.IndexMap(src=src_indices, dest=dest_indices) - - class MediaIO(ABC, Generic[_T]): @abstractmethod diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 38adbf8f3536a..5d485bc361d11 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -12,8 +12,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, cached_tokenizer_from_config) from vllm.utils import ClassRegistry -from .cache import (BaseMultiModalProcessorCache, - processor_only_cache_from_config) +from .cache import BaseMultiModalProcessorCache from .processing import BaseMultiModalProcessor, BaseProcessingInfo from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, DummyEncoderData, MultiModalProfiler) @@ -176,35 +175,6 @@ class MultiModalRegistry: if mm_limits[key] > 0 } - # TODO: Remove once V0 is gone - def get_max_tokens_by_modality( - self, - model_config: "ModelConfig", - ) -> Mapping[str, int]: - """ - Get the maximum number of tokens from each modality - for profiling the memory usage of a model. - """ - cache = processor_only_cache_from_config(model_config, self) - mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) - max_tokens_per_item = self.get_max_tokens_per_item_by_modality( - model_config, - cache=cache, - ) - - return { - key: mm_limits[key] * max_tokens_per_mm_item - for key, max_tokens_per_mm_item in max_tokens_per_item.items() - } - - # TODO: Remove once V0 is gone - def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: - """ - Get the maximum number of multi-modal tokens - for profiling the memory usage of a model. 
- """ - return sum(self.get_max_tokens_by_modality(model_config).values()) - def get_mm_limits_per_prompt( self, model_config: "ModelConfig", diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index f4e2ed72e2d7d..0f8aeceb39448 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -3,13 +3,11 @@ import asyncio import atexit -import itertools -import math from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from itertools import groupby from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from urllib.parse import ParseResult, urlparse from urllib.request import url2pathname @@ -21,9 +19,6 @@ from typing_extensions import deprecated import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) from .audio import AudioMediaIO from .base import MediaIO @@ -33,12 +28,10 @@ from .video import VideoMediaIO _M = TypeVar("_M") if TYPE_CHECKING: - from .inputs import (BatchedTensorInputs, MultiModalKwargs, - MultiModalKwargsItem, MultiModalKwargsItems, - MultiModalPlaceholderDict) + from .inputs import (BatchedTensorInputs, MultiModalKwargsItem, + MultiModalKwargsItems, MultiModalPlaceholderDict) else: BatchedTensorInputs = Any - MultiModalKwargs = Any MultiModalKwargsItem = Any MultiModalKwargsItems = Any MultiModalPlaceholderDict = Any @@ -93,7 +86,7 @@ class MediaConnector: self, url_spec: ParseResult, media_io: MediaIO[_M], - ) -> _M: + ) -> _M: # type: ignore[type-var] data_spec, data = url_spec.path.split(",", 1) media_type, data_type = data_spec.split(";", 1) @@ -107,7 +100,7 @@ class MediaConnector: self, url_spec: ParseResult, media_io: MediaIO[_M], - ) -> _M: + ) -> _M: # type: ignore[type-var] allowed_local_media_path = self.allowed_local_media_path if allowed_local_media_path is None: raise RuntimeError("Cannot load local files without " @@ -127,7 +120,7 @@ class MediaConnector: media_io: MediaIO[_M], *, fetch_timeout: Optional[int] = None, - ) -> _M: + ) -> _M: # type: ignore[type-var] url_spec = urlparse(url) if url_spec.scheme.startswith("http"): @@ -434,280 +427,6 @@ def group_mm_kwargs_by_modality( yield modality, len(items_lst), mm_kwargs_group -def run_dp_sharded_vision_model(image_input: torch.Tensor, - vision_model: torch.nn.Module) -> torch.Tensor: - """Run a vision model with data parallelism (DP) sharding. The function - will shard the input image tensor on the first dimension and run the vision - model - - Args: - image_input (torch.Tensor): Image input tensor. - vision_model (torch.nn.Module): Vision model. - Returns: - torch.Tensor: Output image embeddings - """ - - num_chunks = image_input.shape[0] - mp_world_size = get_tensor_model_parallel_world_size() - num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size - num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks - pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) - image_input_padded = torch.nn.functional.pad(image_input, pad) - rank = get_tensor_model_parallel_rank() - image_input_per_rank = image_input_padded[rank * - num_chunks_per_rank:(rank + 1) * - num_chunks_per_rank, ...] 
- - vision_embeddings = vision_model(image_input_per_rank) - # Ensure tensor is contiguous before all_gather - vision_embeddings = vision_embeddings.contiguous() - vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, - dim=0) - vision_embeddings = vision_embeddings[:num_chunks, ...] - return vision_embeddings - - -def get_load_balance_assignment( - sizes: list[int], - num_gpus: int = 2, -) -> tuple[list[int], list[int], list[int]]: - """ - Generate load balancing assignment and metadata - for distributing data across GPUs. - The load is determined by the total image sizes, - not the number of images. - - Args: - sizes: The size of each image - num_gpus: Number of GPUs to balance across - - Returns: - shuffle_indices: - Indices to reorder data for balanced loading - gpu_sample_counts: - Number of samples assigned to each GPU - grouped_sizes_per_gpu: - Total size assigned to each GPU - - Example: - ``` - sizes = [1000, 100, 200, 50] - num_gpus=2 - ``` - - """ - - n_samples = len(sizes) - - # Handle edge cases - if n_samples == 0: - return [], [0] * num_gpus, [0] * num_gpus - - # Use greedy algorithm - balance by total size, not sample count - gpu_assignments = [list[int]() for _ in range(num_gpus)] - gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count - - # Sort indices by size (largest first for better load balancing) - # sizes = [1000, 100, 200, 50] - # large_to_small_indices = [0, 2, 1, 3] - large_to_small_indices = sorted(range(n_samples), - key=lambda i: sizes[i], - reverse=True) - - for idx in large_to_small_indices: - # Find GPU with minimum current load (by total size) - min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) - gpu_assignments[min_gpu].append(idx) - gpu_loads[min_gpu] += sizes[idx] - - # Create shuffle indices and counts - shuffle_indices = list[int]() - gpu_sample_counts = list[int]() - for gpu_id in range(num_gpus): - # GPU_0 = [1000] = [0] - # GPU_1 = [200, 100, 50] = [2, 1, 3] - # shuffle_indices = [0, 2, 1, 3] - shuffle_indices.extend(gpu_assignments[gpu_id]) - # GPU_0 = [1] - # GPU_1 = [3] - # gpu_sample_counts = [1, 3] - gpu_sample_counts.append(len(gpu_assignments[gpu_id])) - - return (shuffle_indices, gpu_sample_counts, gpu_loads) - - -def run_dp_sharded_mrope_vision_model( - vision_model: torch.nn.Module, - pixel_values: torch.Tensor, - grid_thw_list: list[list[int]], - *, - rope_type: Literal["rope_3d", "rope_2d"], -) -> tuple[torch.Tensor, ...]: - """Run a vision model with data parallelism (DP) sharding. - The function will shard the input image tensor on the - first dimension and run the vision model. - This function is used to run the vision model with mrope. - - Args: - vision_model (torch.nn.Module): Vision model. - pixel_values (torch.Tensor): Image/Video input tensor. - grid_thw_list: List of grid dimensions for each image - rope_type: Type of rope used in the vision model. - Different rope types have different dimension to do ViT. 
- "rope_3d" for 3D rope (e.g., Qwen2.5-VL) - "rope_2d" for 2D rope (e.g., Kimi-VL) - Returns: - torch.Tensor: Output image embeddings - - Example: - ``` - vision_model.out_hidden_size = 64 - vision_model.spatial_merge_size = 2 - pixel_values.shape = (1350, channel) - grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] - tp_size=2 - ``` - - """ - tp_size = get_tensor_model_parallel_world_size() - - # GPU_0 tp_rank_local = 0 - # GPU_1 tp_rank_local = 1 - tp_rank_local = get_tensor_model_parallel_rank() - - # patches_per_image = [1000, 100, 200, 50] - patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list] - # patches_per_image = [0, 1000, 1100, 1300, 1350] - cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)] - - # Get load balancing assignment with all metadata - # image_to_tp_rank = [0, 2, 1, 3] - # gpu_sample_counts = [1, 3] - # grouped_pixel_values_len = [1000, 350] - (image_to_tp_rank, gpu_sample_counts, - grouped_pixel_values_len) = get_load_balance_assignment( - patches_per_image, tp_size) - - # cu_gpu_sample_counts = [0, 1, 4] - cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] - - # GPU_0 image_idxs_local = [0] - # GPU_1 image_idxs_local = [2, 1, 3] - image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]: - cum_gpu_sample_counts[tp_rank_local + - 1]] - - # Get the pixel values for the local images based on the image_idxs_local - if len(image_idxs_local) > 0: - pixel_values_local = torch.cat([ - pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]] - for i in image_idxs_local - ]) - else: - # Handle case where this rank has no images - pixel_values_local = torch.empty((0, pixel_values.shape[1]), - device=pixel_values.device, - dtype=pixel_values.dtype) - # embed_dim_reduction_factor = 2 * 2 - if rope_type == "rope_2d": - embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] * - vision_model.merge_kernel_size[1]) - else: - embed_dim_reduction_factor = (vision_model.spatial_merge_size * - vision_model.spatial_merge_size) - - # Find the max length across all ranks - # The output embedding of every DP rank has to be - # padded to this length for tensor_model_parallel_all_gather - # to work - max_len_per_rank = max( - grouped_pixel_values_len) // embed_dim_reduction_factor - local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] - - # Run the vision model on the local pixel_values_local - if rope_type == "rope_2d": - if pixel_values_local.shape[0] > 0: - image_embeds_local = vision_model( - pixel_values_local, torch.tensor(local_grid_thw_list)) - if isinstance(image_embeds_local, list): - image_embeds_local = torch.cat(image_embeds_local, dim=0) - else: - out_dim = getattr(vision_model.config, "hidden_size", None) - image_embeds_local = torch.empty( - (0, embed_dim_reduction_factor, out_dim), - device=pixel_values.device, - dtype=pixel_values.dtype) - else: - if pixel_values_local.shape[0] > 0: - image_embeds_local = vision_model(pixel_values_local, - local_grid_thw_list) - else: - # Handle empty case - image_embeds_local = torch.empty((0, vision_model.out_hidden_size), - device=pixel_values.device, - dtype=pixel_values.dtype) - - # Pad the output based on max_len_per_rank - # for tensor_model_parallel_all_gather to work - current_len = image_embeds_local.shape[0] - if current_len < max_len_per_rank: - padding_size = max_len_per_rank - current_len - if rope_type == "rope_2d": - padding = torch.empty((padding_size, image_embeds_local.shape[1], - 
image_embeds_local.shape[2]), - dtype=image_embeds_local.dtype, - device=image_embeds_local.device) - else: - padding = torch.empty((padding_size, image_embeds_local.shape[1]), - dtype=image_embeds_local.dtype, - device=image_embeds_local.device) - image_embeds_local_padded = torch.cat([image_embeds_local, padding], - dim=0) - else: - image_embeds_local_padded = image_embeds_local - - # Do all_gather to collect embeddings from all ranks - gathered_embeds = tensor_model_parallel_all_gather( - image_embeds_local_padded, dim=0) - - # Remove padding and reconstruct per-rank embeddings - rank_embeddings = list[torch.Tensor]() - for rank in range(tp_size): - start_idx = rank * max_len_per_rank - end_idx = start_idx + (grouped_pixel_values_len[rank] // - embed_dim_reduction_factor) - rank_embeddings.append(gathered_embeds[start_idx:end_idx]) - - patches_per_output_image = [(patch_size // embed_dim_reduction_factor) - for patch_size in patches_per_image] - - # Reconstruct embeddings in the original order - original_order_embeddings = [None] * len(grid_thw_list) - current_idx = 0 - for rank in range(tp_size): - count = gpu_sample_counts[rank] - if count > 0: - # Get images assigned to this rank in shuffled order - # GPU_0 = image_idxs_local [0] - # GPU_1 = image_idxs_local [2, 1, 3] - rank_images = image_to_tp_rank[current_idx:current_idx + count] - - rank_embed = rank_embeddings[rank] - # Split rank embeddings back to individual images - embed_start = 0 - for img_idx in rank_images: - img_patches = patches_per_output_image[img_idx] - original_order_embeddings[img_idx] = rank_embed[ - embed_start:embed_start + img_patches] - embed_start += img_patches - current_idx += count - out_embeddings = tuple(embed for embed in original_order_embeddings - if embed is not None) - assert len(out_embeddings) == len( - original_order_embeddings), "Found unassigned embeddings" - return out_embeddings - - def fetch_audio( audio_url: str, audio_io_kwargs: Optional[dict[str, Any]] = None, diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index cd41832bc2ea4..1e15dc6a91aa0 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -328,23 +328,6 @@ class CpuPlatform(Platform): def supports_structured_output(cls) -> bool: return True - @classmethod - def supports_v1(cls, model_config) -> bool: - """Returns whether the current platform can support v1 for the supplied - model configuration. - """ - return True - - @classmethod - def default_v1(cls, model_config) -> bool: - """Returns whether the current platform can use v1 by default for the - supplied model configuration. 
- """ - arch = cls.get_cpu_architecture() - return (cls.supports_v1(model_config) - and arch in (CpuArchEnum.X86, CpuArchEnum.POWERPC, - CpuArchEnum.ARM, CpuArchEnum.S390X)) - @classmethod def opaque_attention_op(cls) -> bool: return True diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 05f129f513a0a..d5f3599acb1cc 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -286,6 +286,9 @@ class CudaPlatformBase(Platform): TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 + use_fp8_kv_cache = (kv_cache_dtype is not None + and kv_cache_dtype.startswith("fp8")) + if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") if cls.has_device_capability(100): @@ -334,10 +337,11 @@ class CudaPlatformBase(Platform): # FlashAttention is the default for SM 8.0+ GPUs if cls.has_device_capability(80): - if has_sink and not cls.is_device_capability(90): + if (has_sink or + use_fp8_kv_cache) and not cls.is_device_capability(90): logger.info_once("Using Triton backend on V1 engine.") return TRITON_ATTN_VLLM_V1 - if is_default_backend_supported := is_attn_backend_supported( + elif is_default_backend_supported := is_attn_backend_supported( FLASH_ATTN_V1, head_size, dtype, allow_import_error=False): logger.info_once("Using Flash Attention backend on " @@ -380,10 +384,6 @@ class CudaPlatformBase(Platform): def supports_fp8(cls) -> bool: return cls.has_device_capability(89) - @classmethod - def supports_v1(cls, model_config: "ModelConfig") -> bool: - return True - @classmethod def use_custom_allreduce(cls) -> bool: return True @@ -498,6 +498,10 @@ class CudaPlatformBase(Platform): def support_hybrid_kv_cache(cls) -> bool: return True + @classmethod + def support_static_graph_mode(cls) -> bool: + return True + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index c43580ac5da13..7dd935d2eb31c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -67,6 +67,7 @@ class _Backend(enum.Enum): FLEX_ATTENTION = enum.auto() TREE_ATTN = enum.auto() XFORMERS_VLLM_V1 = enum.auto() + ROCM_ATTN_VLLM_V1 = enum.auto() class PlatformEnum(enum.Enum): @@ -481,20 +482,6 @@ class Platform: or parallel_config.distributed_executor_backend == "external_launcher") - @classmethod - def supports_v1(cls, model_config: ModelConfig) -> bool: - """Returns whether the current platform can support v1 for the supplied - model configuration. - """ - return False - - @classmethod - def default_v1(cls, model_config: ModelConfig) -> bool: - """ - Returns whether the current platform supports v1 by default. - """ - return cls.supports_v1(model_config) - @classmethod def use_custom_allreduce(cls) -> bool: """ @@ -587,6 +574,13 @@ class Platform: """ return False + @classmethod + def support_static_graph_mode(cls) -> bool: + """ + Returns if the graph mode is supported by the current platform. + """ + return False + @classmethod def use_sync_weight_loader(cls) -> bool: """ @@ -610,6 +604,21 @@ class Platform: return _synced_weight_loader + @classmethod + def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: + """ + Returns a mapping from device_type to a tuple of supported + kv_buffer_device for nixl. 
+ """ + return {} + + @classmethod + def get_nixl_memory_type(cls) -> Optional[str]: + """ + Returns the nixl memory type for the current platform. + """ + return None + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 9470434aa428b..878718489fa88 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -231,7 +231,17 @@ class RocmPlatform(Platform): logger.info("Using Flash Attention backend on V1 engine.") return ("vllm.v1.attention.backends." "rocm_aiter_fa.AiterFlashAttentionBackend") + elif (envs.VLLM_ROCM_USE_AITER and + envs.VLLM_USE_AITER_UNIFIED_ATTENTION) or \ + envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or \ + selected_backend == _Backend.ROCM_ATTN_VLLM_V1: + # rocm specific backend, with aiter and/or + # triton prefix-prefill + logger.info("Using Rocm/Aiter Attention backend on V1 engine.") + return ("vllm.v1.attention.backends." + "rocm_attn.RocmAttentionBackend") else: + # default case, using triton unified attention logger.info("Using Triton Attention backend on V1 engine.") return ("vllm.v1.attention.backends." "triton_attn.TritonAttentionBackend") @@ -386,11 +396,6 @@ class RocmPlatform(Platform): else: return torch.float8_e4m3fn - @classmethod - def supports_v1(cls, model_config: "ModelConfig") -> bool: - # V1 support on AMD gpus is experimental - return True - @classmethod def use_custom_allreduce(cls) -> bool: # We only enable custom allreduce for MI300 series @@ -477,3 +482,7 @@ class RocmPlatform(Platform): @classmethod def support_hybrid_kv_cache(cls) -> bool: return True + + @classmethod + def support_static_graph_mode(cls) -> bool: + return True diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 9852d948bc4b8..e4c73b1bae6fb 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -174,11 +174,6 @@ class TpuPlatform(Platform): def use_all_gather(cls) -> bool: return True - @classmethod - def supports_v1(cls, model_config: ModelConfig) -> bool: - # V1 support on TPU is experimental - return True - @classmethod def validate_request( cls, diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 4d3bef4b42947..af61db5e312a4 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -113,12 +113,11 @@ class XPUPlatform(Platform): # lazy import to avoid circular import from vllm.config import CompilationLevel, CUDAGraphMode compilation_config = vllm_config.compilation_config - if compilation_config.cudagraph_mode is None or \ - compilation_config.cudagraph_mode.max_cudagraph_mode() \ - != CUDAGraphMode.NONE: - logger.info("[XPU] CUDA graph is not supported on XPU, disabling " - "cudagraphs. 
Fallback to cudagraph_mode=NONE") - compilation_config.cudagraph_mode = CUDAGraphMode.NONE + if compilation_config.compile_sizes is None: + compilation_config.compile_sizes = [] + + assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, \ + "CUDA graph mode should be NONE on XPU" if vllm_config.lora_config is not None: compilation_config.level = CompilationLevel.NO_COMPILATION @@ -169,6 +168,10 @@ class XPUPlatform(Platform): def support_hybrid_kv_cache(cls) -> bool: return True + @classmethod + def support_static_graph_mode(cls) -> bool: + return False + @classmethod def is_pin_memory_available(cls): return True @@ -193,10 +196,6 @@ class XPUPlatform(Platform): def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa - @classmethod - def supports_v1(cls, model_config: ModelConfig) -> bool: - return True - @classmethod def device_count(cls) -> int: return torch.xpu.device_count() diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 91bfeb8c55ee5..52fa49ad302bd 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -9,6 +9,7 @@ Model configs may be defined in this directory for the following reasons: from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config +from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig from vllm.transformers_utils.configs.eagle import EAGLEConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the @@ -36,6 +37,7 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ "ChatGLMConfig", "DeepseekVLV2Config", + "DotsOCRConfig", "EAGLEConfig", "RWConfig", "JAISConfig", diff --git a/vllm/transformers_utils/configs/dotsocr.py b/vllm/transformers_utils/configs/dotsocr.py new file mode 100644 index 0000000000000..6bb3c12d9c7eb --- /dev/null +++ b/vllm/transformers_utils/configs/dotsocr.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.models.qwen2 import Qwen2Config + + +class DotsVisionConfig(PretrainedConfig): + model_type: str = "dots_vit" + + def __init__( + self, + embed_dim: int = 1536, # vision encoder embed size + hidden_size: int = 1536, # after merger hidden size + intermediate_size: int = 4224, + num_hidden_layers: int = 42, + num_attention_heads: int = 12, + num_channels: int = 3, + patch_size: int = 14, + spatial_merge_size: int = 2, + temporal_patch_size: int = 1, + rms_norm_eps: float = 1e-5, + use_bias: bool = False, + attn_implementation="flash_attention_2", + initializer_range=0.02, + init_merger_std=0.02, + is_causal=False, # ve causal forward + post_norm=True, + gradient_checkpointing=False, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.rms_norm_eps = rms_norm_eps + self.use_bias = use_bias + 
self.attn_implementation = attn_implementation + self.initializer_range = initializer_range + self.init_merger_std = init_merger_std + self.is_causal = is_causal + self.post_norm = post_norm + self.gradient_checkpointing = gradient_checkpointing + + +class DotsOCRConfig(Qwen2Config): + model_type = "dots_ocr" + + def __init__(self, + image_token_id=151665, + video_token_id=151656, + vision_config: Optional[dict] = None, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_config = DotsVisionConfig(**(vision_config or {})) + + def save_pretrained(self, save_directory, **kwargs): + self._auto_class = None + super().save_pretrained(save_directory, **kwargs) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 834ec9b1d30b4..3399d00fbabbd 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -88,64 +88,6 @@ DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 -# Exception strings for non-implemented encoder/decoder scenarios - -# Reminder: Please update docs/features/compatibility_matrix.md -# If the feature combo become valid - -STR_NOT_IMPL_ENC_DEC_SWA = \ - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ - "Prefix caching for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ - "Chunked prefill for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = ( - "Models with logits_soft_cap " - "require FlashInfer backend, which is " - "currently not supported for encoder/decoder " - "models.") - -STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently " - "supported with encoder/decoder " - "models.") - -STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not " - "currently supported with " - "encoder/decoder models.") - -STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently " - "supported with encoder/decoder " - "models.") - -STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not " - "currently supported with encoder/" - "decoder models.") - -STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only " - "backends currently supported with encoder/" - "decoder models.") - -# Efficiently import all enc/dec error strings -# rather than having to import all of the above -STR_NOT_IMPL_ENC_DEC_ERR_STRS = { - "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA, - "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL": - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, - "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP, - "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA, - "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP, - "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM, - "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, - "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, -} - # Constants related to forcing the attention backend selection # String name of register which may be set in order to @@ -609,9 +551,10 @@ class AsyncMicrobatchTokenizer: # If every request uses identical kwargs we can run a single # batched tokenizer call for a big speed-up. 
if can_batch and len(prompts) > 1: - encode_fn = partial(self.tokenizer, prompts, **kwargs) + batch_encode_fn = partial(self.tokenizer, prompts, + **kwargs) results = await self._loop.run_in_executor( - self._executor, encode_fn) + self._executor, batch_encode_fn) for i, fut in enumerate(result_futures): if not fut.done(): @@ -947,7 +890,7 @@ def get_open_port() -> int: def get_open_ports_list(count: int = 5) -> list[int]: """Get a list of open ports.""" - ports = set() + ports = set[int]() while len(ports) < count: ports.add(get_open_port()) return list(ports) @@ -1337,7 +1280,7 @@ def as_list(maybe_list: Iterable[T]) -> list[T]: def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]: if isinstance(obj, str) or not isinstance(obj, Iterable): - obj = [obj] + return [obj] # type: ignore[list-item] return obj @@ -3449,7 +3392,7 @@ def length_from_prompt_token_ids_or_embeds( prompt_token_ids: Optional[list[int]], prompt_embeds: Optional[torch.Tensor], ) -> int: - """Calculate the request length (in number of tokens) give either + """Calculate the request length (in number of tokens) give either prompt_token_ids or prompt_embeds. """ prompt_token_len = None if prompt_token_ids is None else len( @@ -3470,3 +3413,16 @@ def length_from_prompt_token_ids_or_embeds( f" prompt_token_ids={prompt_token_len}" f" prompt_embeds={prompt_embeds_len}") return prompt_token_len + + +@contextlib.contextmanager +def set_env_var(key, value): + old = os.environ.get(key) + os.environ[key] = value + try: + yield + finally: + if old is None: + del os.environ[key] + else: + os.environ[key] = old diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 38d92f01192b1..4083193d7650e 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -135,7 +135,7 @@ DEFAULT_BLOCK_SIZE = [128, 128] # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38 -# TODO(wentao): optimize this function, using triton or cuda kernel +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def per_block_cast_to_fp8( x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, @@ -187,4 +187,4 @@ __all__ = [ "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", "should_use_deepgemm_for_fp8_linear", -] +] \ No newline at end of file diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py index 21d3249fe1547..d75dbcd5401b2 100644 --- a/vllm/utils/tensor_schema.py +++ b/vllm/utils/tensor_schema.py @@ -22,9 +22,8 @@ class TensorShape: self.dims = dims self.dynamic_dims = dynamic_dims if dynamic_dims else set() - def resolve(self, **bindings: dict[str, - int]) -> tuple[Union[int, str], ...]: - resolved = [] + def resolve(self, **bindings: int) -> tuple[Union[int, str], ...]: + resolved = list[Union[int, str]]() for dim in self.dims: if isinstance(dim, str) and dim in bindings: resolved.append(bindings[dim]) @@ -159,7 +158,7 @@ class TensorSchema: def validate(self) -> None: type_hints = get_type_hints(self.__class__, include_extras=True) - shape_env = {} + shape_env = dict[str, int]() for field_name, field_type in type_hints.items(): # Check if field is missing diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 6627164c98798..7e485fea2689d 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -425,7 +425,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): num_prompt_req], # prefill 
query_start_loc=query_start_loc_cpu[:num_reqs + 1], # for logits index - multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, ) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 20f1904b3be6f..d564cf9988ea6 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -8,6 +8,7 @@ import numpy as np import torch from vllm import _custom_ops as ops +from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, is_quantized_kv_cache) @@ -33,9 +34,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) -# NOTE(woosuk): This is an arbitrary number. Tune it if needed. -_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16 - class FlashAttentionBackend(AttentionBackend): @@ -215,7 +213,8 @@ class FlashAttentionMetadataBuilder( # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + self.max_num_splits = ( + envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH) # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 5dadc52d0fb1c..06a87a4a3c8b2 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -12,6 +12,7 @@ from vllm.config import VllmConfig from vllm.v1.attention.backends.utils import (AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + compute_causal_conv1d_metadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -52,7 +53,6 @@ class GDNAttentionMetadata: # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None batch_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None @@ -134,6 +134,7 @@ class GDNAttentionMetadataBuilder( context_lens = m.num_computed_tokens_cpu context_lens_tensor = context_lens.to(query_start_loc.device) seq_lens_tensor = m.seq_lens + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None if (not self.use_spec_decode or num_draft_tokens is None or num_draft_tokens.sum().item() == 0): @@ -210,6 +211,8 @@ class GDNAttentionMetadataBuilder( has_initial_state = context_lens_tensor > 0 if spec_sequence_masks is not None: has_initial_state = has_initial_state[~spec_sequence_masks] + nums_dict, batch_ptr, token_chunk_offset_ptr = \ + compute_causal_conv1d_metadata(non_spec_query_start_loc) else: has_initial_state = None num_actual_tokens = num_prefill_tokens + num_decode_tokens + \ @@ -297,6 +300,9 @@ class GDNAttentionMetadataBuilder( spec_sequence_masks=spec_sequence_masks, spec_token_masks=spec_token_masks, num_accepted_tokens=num_accepted_tokens, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, ) return attn_metadata diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 2fe1f14ca1db0..f45fc75334a21 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -7,11 +7,12 @@ from typing import Optional import torch from vllm.attention.backends.abstract import AttentionBackend -from 
vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig from vllm.v1.attention.backends.mamba_attn import ( BaseMambaAttentionMetadataBuilder) -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, +from vllm.v1.attention.backends.utils import (PAD_SLOT_ID, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec @@ -131,7 +132,6 @@ class Mamba2AttentionMetadata: # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None batch_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None @@ -161,6 +161,9 @@ class Mamba2AttentionMetadataBuilder( has_initial_states_p = None prep_initial_states = False + # for causal_conv1d + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( @@ -198,6 +201,9 @@ class Mamba2AttentionMetadataBuilder( query_start_loc_p, self.chunk_size, num_prefill_tokens)) + nums_dict, batch_ptr, token_chunk_offset_ptr = \ + compute_causal_conv1d_metadata(query_start_loc_p) + elif num_decodes <= self.decode_cudagraph_max_bs: # Pad state tensor for CUDA graph num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) @@ -220,5 +226,8 @@ class Mamba2AttentionMetadataBuilder( chunk_indices_p=chunk_indices_p, chunk_offsets_p=chunk_offsets_p, state_indices_tensor=state_indices_tensor, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, ) return attn_metadata diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 5b307810de930..a177117a50bd1 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -481,7 +481,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # which would result in up-projected context being # 2*(192*128)*(64*1024) = 3gb # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) + 64 * 1024) assert self.chunked_prefill_workspace_size >= \ scheduler_config.max_num_seqs * cache_config.block_size if self.dcp_world_size > 1: @@ -942,6 +942,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): qk_head_dim: int, v_head_dim: int, kv_b_proj: ColumnParallelLinear, + q_pad_num_heads: Optional[int] = None, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") @@ -959,6 +960,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj + self.q_pad_num_heads = q_pad_num_heads if use_flashinfer_prefill(): logger.debug_once("Using FlashInfer prefill for MLA") @@ -1134,7 +1136,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): True, #Indicates actual_seq_lens are on GPU or CPU. 
) - def _v_up_proj(self, x): + def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) if is_rocm_aiter_fp8bmm_enabled(): @@ -1146,12 +1148,23 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): transpose_bm=True) # Convert from (B, N, V) to (B, N * V) x = x.reshape(-1, self.num_heads * self.v_head_dim) + # Copy result + out.copy_(x) else: + # Convert from (B, N * V) to (N, B, V) + out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) + torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" + # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - return x + out_new = out.transpose(0, 1).reshape( + -1, self.num_heads * self.v_head_dim) + + # Adjust output buffer shape back to the original (B, N * V) + N, B, V = out.shape + out.resize_((B, N * V)) + out.copy_(out_new) # Copy result def process_weights_after_loading(self, act_dtype: torch.dtype): @@ -1559,6 +1572,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) + # Pads the head_dim if necessary (for the underlying kernel) + if self.q_pad_num_heads is not None: + B, N, L = decode_q_pe.shape + decode_pe_padded = decode_q_pe.new_empty( + (B, self.q_pad_num_heads, L)) + decode_pe_padded.resize_((B, N, L)) + decode_pe_padded.copy_(decode_q_pe) + decode_q_pe = decode_pe_padded + if is_rocm_aiter_fp8bmm_enabled(): # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope, @@ -1567,8 +1589,19 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): group_size=128, transpose_bm=True) else: + # Pads the head_dim if necessary (for the underlying kernel) + N, B, P = decode_q_nope.shape + _, _, L = self.W_UK_T.shape + if self.q_pad_num_heads is not None: + decode_ql_nope = decode_q_nope.new_empty( + (self.q_pad_num_heads, B, L)) + decode_ql_nope.resize_((N, B, L)) + + else: + decode_ql_nope = decode_q_nope.new_empty((N, B, L)) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) @@ -1603,5 +1636,5 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) # v_up projection - output[:num_decode_tokens] = self._v_up_proj(attn_out) + self._v_up_proj(attn_out, out=output[:num_decode_tokens]) return output_padded diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index ae534f3207b51..d44e20f2cb6be 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -74,6 +74,8 @@ class SM100Workspace: g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB +MAX_HEADS = 128 + class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): can_return_lse_for_decode: bool = True @@ -92,10 +94,18 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - 
kv_sharing_target_layer_name, **mla_args) + super().__init__(num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + q_pad_num_heads=MAX_HEADS, + **mla_args) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): @@ -157,14 +167,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): MAX_HEADS = 128 assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}" - if H < MAX_HEADS: - q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope)) - q_nope_padded[:, :H] = q_nope - q_nope = q_nope_padded - - q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe)) - q_pe_padded[:, :H] = q_pe - q_pe = q_pe_padded assert len(page_table.shape) == 2 B_block_table, block_num = page_table.shape @@ -206,9 +208,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): ) if H < MAX_HEADS: + # Extract the subsets of the outputs + lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse out = out[:, :H] - if self.need_to_return_lse_for_decode: - lse = lse[:, :H].contiguous() return out, lse diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 472095e13615b..4ad9a13b61d8e 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -6,6 +6,7 @@ from typing import ClassVar, Optional, Union import torch +from vllm import envs from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, is_quantized_kv_cache) from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, @@ -24,10 +25,6 @@ from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata logger = init_logger(__name__) -# NOTE(matt): This is an arbitrary number, copied from -# woosuk's implementation in standard FlashAttention backend -_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16 - class FlashAttnMLABackend(MLACommonBackend): @@ -97,7 +94,8 @@ class FlashAttnMLAMetadataBuilder( # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + self.max_num_splits = ( + envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH) # TODO(lucas): Until we add support for the DCP custom masking we need # to restrict decodes to q_len == 1 when DCP is enabled. 
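Editorial note (not part of the patch): the MLA changes above share one pattern — the new `q_pad_num_heads` argument and the `out=`-based `_v_up_proj` both allocate decode-time buffers with the head capacity the kernel requires (e.g. `MAX_HEADS = 128` for the CUTLASS MLA decode path), shrink the tensor back to the logical head count with `resize_()`, and then write results straight into that buffer via `torch.bmm(..., out=...)` or `copy_()`, so no per-step padded copies are needed. A minimal standalone sketch of that trick follows; shapes and variable names are made up for illustration and this is not vLLM code.

```python
# Sketch of the "allocate padded, resize_ to logical, write with out=" pattern.
import torch

N, B, P, L = 16, 4, 512, 64      # logical heads, batch, q_nope dim, latent dim
PAD_NUM_HEADS = 128              # head capacity assumed by the decode kernel

q_nope = torch.randn(N, B, P)    # (N, B, P)
w_uk_t = torch.randn(N, P, L)    # (N, P, L)

# Allocate with the padded head capacity, then shrink the view to the
# logical shape; the underlying storage keeps the padded capacity.
ql_nope = q_nope.new_empty((PAD_NUM_HEADS, B, L))
ql_nope.resize_((N, B, L))

# Write the up-projection result directly into the pre-sized buffer:
# (N, B, P) x (N, P, L) -> (N, B, L), no extra allocation per decode step.
torch.bmm(q_nope, w_uk_t, out=ql_nope)
assert ql_nope.shape == (N, B, L)
```

The point of the design is that the kernel sees storage large enough for its fixed head count while all surrounding math keeps operating on the un-padded logical shape.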
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py new file mode 100644 index 0000000000000..365df5f0d6eca --- /dev/null +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -0,0 +1,426 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with PagedAttention and Triton prefix prefill.""" +from dataclasses import dataclass +from functools import cache +from typing import ClassVar, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.ops.chunked_prefill_paged_decode import ( + chunked_prefill_paged_decode) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym) +from vllm.platforms import current_platform +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata) +from vllm.v1.kv_cache_interface import AttentionSpec + +logger = init_logger(__name__) + + +@dataclass +class RocmAttentionMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + # For cascade attention. + use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: Optional[torch.Tensor] + prefix_kv_lens: Optional[torch.Tensor] + suffix_kv_lens: Optional[torch.Tensor] + + # Optional aot scheduling + scheduler_metadata: Optional[torch.Tensor] = None + prefix_scheduler_metadata: Optional[torch.Tensor] = None + + +class RocmAttentionMetadataBuilder( + AttentionMetadataBuilder[RocmAttentionMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + + self.block_size = kv_cache_spec.block_size + + model_config = vllm_config.model_config + self.num_heads_q = model_config.get_num_attention_heads( + vllm_config.parallel_config) + self.num_heads_kv = model_config.get_num_kv_heads( + vllm_config.parallel_config) + self.headdim = model_config.get_head_size() + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata + ) -> RocmAttentionMetadata: + attn_metadata = self.build(0, common_attn_metadata) + # When doing full graph capture, setting seq_lens to + # max_model_len will cause graph capture to be extremely + # slow, so here we set it to 1. 
+ attn_metadata.seq_lens.fill_(1) + return attn_metadata + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> RocmAttentionMetadata: + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + + max_seq_len = common_attn_metadata.max_seq_len + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + use_cascade = common_prefix_len > 0 + + if use_cascade: + cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device=self.device) + prefix_kv_lens = torch.tensor([common_prefix_len], + dtype=torch.int32, + device=self.device) + suffix_kv_lens = (common_attn_metadata.seq_lens_cpu - + common_prefix_len) + suffix_kv_lens = suffix_kv_lens.to(self.device) + else: + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + attn_metadata = RocmAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + prefix_scheduler_metadata=prefix_scheduler_metadata, + ) + return attn_metadata + + +class RocmAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. 
" + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + + @staticmethod + def get_name() -> str: + return "ROCM_ATTN_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["RocmAttentionImpl"]: + return RocmAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return RocmAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + @staticmethod + def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]: + return RocmAttentionMetadataBuilder + + +@cache +def use_aiter_unified_attention() -> bool: + """Check if aiter unified attention should be used.""" + # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set + # to 1 as default + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_USE_AITER_UNIFIED_ATTENTION + + +class RocmAttentionImpl(AttentionImpl): + + def fused_output_quant_supported(self, quant_key: QuantKey): + return quant_key == kFp8StaticTensorSym + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + RocmAttentionBackend.validate_head_size(head_size) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "RocmAttentionImpl") + + self.fp8_dtype = current_platform.fp8_dtype() + self.force_prefill_decode_attn = \ + envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + + if not self.force_prefill_decode_attn: + # If not using prefill decode attention, we use the Triton + # unified attention implementation. + if use_aiter_unified_attention(): + logger.info_once( + "Using aiter unified attention for RocmAttentionImpl") + from aiter.ops.triton.unified_attention import ( + unified_attention) + self.unified_attention = unified_attention + else: + logger.info_once( + "Using vllm unified attention for RocmAttentionImpl") + from vllm.attention.ops.triton_unified_attention import ( + unified_attention) + self.unified_attention = unified_attention + + self.sinks = sinks + if sinks is not None: + assert sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + f"heads in the layer. 
Sinks shape: {sinks.shape}, " + f"num_heads: {num_heads}.") + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if output_block_scale is not None: + raise NotImplementedError( + "fused block_scale output quantization is not yet supported" + " for RocmAttentionImpl") + + if attn_metadata is None: + # Profiling run. + return output + + assert attn_metadata.use_cascade is False + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + use_prefill_decode_attn = self.force_prefill_decode_attn + num_actual_tokens = attn_metadata.num_actual_tokens + + if use_prefill_decode_attn: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + else: + key_cache, value_cache = kv_cache.unbind(0) + + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + if use_prefill_decode_attn: + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + else: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + num_tokens, num_heads, head_size = query.shape + assert layer._q_scale_float == 1.0, \ + "A non 1.0 q_scale is not currently supported." + if current_platform.is_cuda(): + # Skip Q quantization on ROCm and XPU, enable this on cuda + # only, since dequantizing back to f32 in the attention kernel + # is not supported. + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + + if use_prefill_decode_attn: + # Compute attention and update output up to `num_actual_tokens`. 
+ chunked_prefill_paged_decode( + query=query[:num_actual_tokens], + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + output=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + key_cache=key_cache, + value_cache=value_cache, + block_table=block_table, + query_start_loc=cu_seqlens_q, + seq_lens=seqused_k, + max_seq_len=max_seqlen_k, + max_query_len=max_seqlen_q, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + sm_scale=self.scale, + output_scale=output_scale, + sinks=self.sinks, + ) + + else: + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + + self.unified_attention( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + sinks=self.sinks, + output_scale=output_scale) + + return output diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index 717c40b37ecfb..428e409659798 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata, + compute_causal_conv1d_metadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -33,7 +34,6 @@ class ShortConvAttentionMetadata: # For causal_conv1d nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None batch_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None @@ -57,6 +57,9 @@ class ShortConvAttentionMetadataBuilder( state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + # for causal_conv1d + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, @@ -70,6 +73,12 @@ class ShortConvAttentionMetadataBuilder( has_initial_states = has_initial_states_cpu.to( query_start_loc.device) + query_start_loc_p = common_attn_metadata.query_start_loc[ + -num_prefills - 1:] - num_decode_tokens + + nums_dict, batch_ptr, token_chunk_offset_ptr = \ + compute_causal_conv1d_metadata(query_start_loc_p) + attn_metadata = ShortConvAttentionMetadata( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, @@ -78,5 +87,8 @@ class ShortConvAttentionMetadataBuilder( query_start_loc=query_start_loc, has_initial_states=has_initial_states, state_indices_tensor=state_indices_tensor, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, ) return attn_metadata diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 784912a122f68..722c23f150cd3 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -1,24 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
-"""Attention layer with PagedAttention and Triton prefix prefill.""" +"""High-Performance Triton-only Attention layer.""" from dataclasses import dataclass -from functools import cache from typing import ClassVar, Optional import torch -from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) -from vllm.attention.ops.paged_attn import PagedAttention +from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym) from vllm.platforms import current_platform -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import (AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata) @@ -144,20 +139,15 @@ class TritonAttentionBackend(AttentionBackend): @classmethod def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] + return [torch.float16, torch.bfloat16, torch.float32] @classmethod def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") + # Triton Attention supports any head size above 32 + if head_size < 32: raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " + f"Head size {head_size} is not supported by TritonAttention." + f"Head sizes need to be larger or equal 32 for this backend. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " "FlexAttention backend which supports all head sizes.") @@ -182,7 +172,7 @@ class TritonAttentionBackend(AttentionBackend): ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + return (num_blocks, 2, block_size, num_kv_heads, head_size) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: @@ -193,15 +183,6 @@ class TritonAttentionBackend(AttentionBackend): return TritonAttentionMetadataBuilder -@cache -def use_aiter_unified_attention() -> bool: - """Check if aiter unified attention should be used.""" - # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set - # to 1 as default - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_USE_AITER_UNIFIED_ATTENTION - - class TritonAttentionImpl(AttentionImpl): def fused_output_quant_supported(self, quant_key: QuantKey): @@ -250,24 +231,6 @@ class TritonAttentionImpl(AttentionImpl): "TritonAttentionImpl") self.fp8_dtype = current_platform.fp8_dtype() - self.force_prefill_decode_attn = \ - envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - - if not self.force_prefill_decode_attn: - # If not using prefill decode attention, we use the Triton - # unified attention implementation. 
- if use_aiter_unified_attention(): - logger.info_once( - "Using aiter unified attention for TritonAttentionImpl") - from aiter.ops.triton.unified_attention import ( - unified_attention) - self.unified_attention = unified_attention - else: - logger.info_once( - "Using vllm unified attention for TritonAttentionImpl") - from vllm.attention.ops.triton_unified_attention import ( - unified_attention) - self.unified_attention = unified_attention self.sinks = sinks if sinks is not None: @@ -283,19 +246,19 @@ class TritonAttentionImpl(AttentionImpl): key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, + attn_metadata: TritonAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Forward pass with FlashAttention. + """Forward pass with Paged Attention impl. in Triton. Args: query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] kv_cache: shape = - [2, num_blocks, block_size, num_kv_heads, head_size] + [num_blocks, 2, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -322,40 +285,22 @@ class TritonAttentionImpl(AttentionImpl): # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. - use_prefill_decode_attn = self.force_prefill_decode_attn num_actual_tokens = attn_metadata.num_actual_tokens - - if use_prefill_decode_attn: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - else: - key_cache, value_cache = kv_cache.unbind(0) + key_cache, value_cache = kv_cache.unbind(1) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. - if use_prefill_decode_attn: - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - else: - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(self.fp8_dtype) @@ -379,52 +324,28 @@ class TritonAttentionImpl(AttentionImpl): max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table - if use_prefill_decode_attn: - # Compute attention and update output up to `num_actual_tokens`. 
- chunked_prefill_paged_decode( - query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], - output=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - key_cache=key_cache, - value_cache=value_cache, - block_table=block_table, - query_start_loc=cu_seqlens_q, - seq_lens=seqused_k, - max_seq_len=max_seqlen_k, - max_query_len=max_seqlen_q, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale, - output_scale=output_scale, - sinks=self.sinks, - ) + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - else: - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - - self.unified_attention( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - q_descale=None, # Not supported - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - sinks=self.sinks, - output_scale=output_scale) + unified_attention( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + sinks=self.sinks, + output_scale=output_scale, + ) return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 63326d19194f0..6ef489f5a7a28 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -34,6 +34,8 @@ logger = init_logger(__name__) KVCacheLayoutType = Literal["NHD", "HND"] _KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None +PAD_SLOT_ID = -1 + def is_valid_kv_cache_layout(value: str) -> bool: return value in get_args(KVCacheLayoutType) @@ -838,3 +840,52 @@ def create_fast_prefill_custom_backend( builder_cls=FastPrefillAttentionBuilder) return attn_backend + + +def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): + + # Needed for causal_conv1d + seqlens = query_start_loc_p.diff().to('cpu') + nums_dict = {} # type: ignore + batch_ptr = None + token_chunk_offset_ptr = None + for BLOCK_M in [8]: # cover all BLOCK_M values + nums = -(-seqlens // BLOCK_M) + nums_dict[BLOCK_M] = {} + nums_dict[BLOCK_M]['nums'] = nums + nums_dict[BLOCK_M]['tot'] = nums.sum().item() + mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums)) + nums_dict[BLOCK_M]['mlist'] = mlist + mlist_len = len(nums_dict[BLOCK_M]['mlist']) + nums_dict[BLOCK_M]['mlist_len'] = mlist_len + MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2 + offsetlist = [] # type: ignore + for idx, num in enumerate(nums): + offsetlist.extend(range(num)) + offsetlist = torch.tensor(offsetlist, dtype=torch.int32) + nums_dict[BLOCK_M]['offsetlist'] = offsetlist + + if batch_ptr is None: + # Update default value after class definition + batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, 
+ device='cuda') + token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, + device='cuda') + else: + if batch_ptr.nelement() < MAX_NUM_PROGRAMS: + batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) + token_chunk_offset_ptr.resize_( # type: ignore + MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) + + batch_ptr[0:mlist_len].copy_(mlist) + token_chunk_offset_ptr[ # type: ignore + 0:mlist_len].copy_(offsetlist) + nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr + nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (token_chunk_offset_ptr + ) # type: ignore + + return nums_dict, batch_ptr, token_chunk_offset_ptr diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 37f2cd80445d1..cf2f490621ebc 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -463,10 +463,6 @@ class Scheduler(SchedulerInterface): # always padded to the maximum length. If we support other # encoder-decoder models, this will need to be updated if we # want to only allocate what is needed. - assert ("whisper" - in self.vllm_config.model_config.model.lower()), ( - "Whisper is the only supported " - "encoder-decoder model.") num_encoder_tokens =\ self.scheduler_config.max_num_encoder_input_tokens else: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 907656d1b24cb..92c861d9e91fe 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -11,6 +11,7 @@ from typing_extensions import TypeVar import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group +from vllm.distributed.parallel_state import get_dp_group from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType from vllm.logger import init_logger @@ -77,10 +78,15 @@ class LLMEngine: if self.log_stats: self.stat_logger = PrometheusStatLogger(vllm_config) + executor_backend = ( + self.vllm_config.parallel_config.distributed_executor_backend) + parallel_config = vllm_config.parallel_config + self.external_launcher_dp = (parallel_config.data_parallel_size > 1 and + executor_backend == "external_launcher") # important: init dp group before init the engine_core # In the decoupled engine case this is handled in EngineCoreProc. - parallel_config = vllm_config.parallel_config - if not multiprocess_mode and parallel_config.data_parallel_size > 1: + if not multiprocess_mode and parallel_config.data_parallel_size > 1 \ + and not self.external_launcher_dp: self.dp_group = parallel_config.stateless_init_dp_group() else: self.dp_group = None @@ -120,6 +126,11 @@ class LLMEngine: # for v0 compatibility self.model_executor = self.engine_core.engine_core.model_executor # type: ignore + if self.external_launcher_dp: + # If we use DP in external launcher mode, we reuse the + # existing DP group used for data communication. 
+ self.dp_group = get_dp_group().cpu_group + # Don't keep the dummy data in memory self.reset_mm_cache() @@ -331,5 +342,6 @@ class LLMEngine: return self.collective_rpc("apply_model", args=(func, )) def __del__(self): - if dp_group := getattr(self, "dp_group", None): + if dp_group := getattr(self, "dp_group", + None) and not self.external_launcher_dp: stateless_destroy_torch_distributed_process_group(dp_group) diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py new file mode 100644 index 0000000000000..b85d375fe63e2 --- /dev/null +++ b/vllm/v1/kv_offload/cpu.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterator +from typing import Optional + +import torch + +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.platforms import current_platform +from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager +from vllm.v1.kv_offload.backends.cpu import CPUBackend +from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager +from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec +from vllm.v1.kv_offload.spec import OffloadingSpec +from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler +from vllm.v1.kv_offload.worker.worker import OffloadingHandler + + +class CPUOffloadingSpec(OffloadingSpec): + + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + + num_cpu_blocks = self.extra_config.get("num_cpu_blocks") + if not num_cpu_blocks: + raise Exception("num_cpu_blocks must be specified " + "in kv_connector_extra_config") + self.num_cpu_blocks: int = num_cpu_blocks + + # scheduler-side + self._manager: Optional[OffloadingManager] = None + + # worker-side + self._handler: Optional[OffloadingHandler] = None + + def get_manager(self) -> OffloadingManager: + if not self._manager: + kv_events_config = self.vllm_config.kv_events_config + enable_events = (kv_events_config is not None + and kv_events_config.enable_kv_cache_events) + self._manager = LRUOffloadingManager(CPUBackend( + block_size=self.offloaded_block_size, + num_blocks=self.num_cpu_blocks), + enable_events=enable_events) + return self._manager + + def get_handlers( + self, kv_caches: dict[str, torch.Tensor] + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], + OffloadingHandler]]: + if not self._handler: + if not current_platform.is_cuda(): + raise Exception("CPU Offloading is currently only supported" + " on CUDA GPUs") + + layer_names = list(kv_caches.keys()) + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) + attn_backends = { + layer_name: layers[layer_name].get_attn_backend() + for layer_name in layer_names + } + + self._handler = CpuGpuOffloadingHandler( + attn_backends=attn_backends, + gpu_block_size=self.gpu_block_size, + cpu_block_size=self.offloaded_block_size, + num_cpu_blocks=self.num_cpu_blocks, + gpu_caches=kv_caches) + + assert self._handler is not None + yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler + yield CPULoadStoreSpec, GPULoadStoreSpec, self._handler diff --git a/vllm/v1/kv_offload/factory.py b/vllm/v1/kv_offload/factory.py index 6365ab4a6db75..f9bef6cea9038 100644 --- a/vllm/v1/kv_offload/factory.py +++ b/vllm/v1/kv_offload/factory.py @@ -51,3 +51,6 @@ class OffloadingSpecFactory: # Register various specs here. 
+OffloadingSpecFactory.register_spec("CPUOffloadingSpec", + "vllm.v1.kv_offload.cpu", + "CPUOffloadingSpec") diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5dacf60886966..dc97d5c8f39d4 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -823,15 +823,29 @@ class EagleProposer: else: target_language_model = target_model # share embed_tokens with the target model if needed - if get_pp_group().world_size == 1 \ - and self.model.model.embed_tokens.weight.shape \ - == target_language_model.model.embed_tokens.weight.shape: - logger.info( - "Assuming the EAGLE head shares the same vocab embedding" - " with the target model.") - del self.model.model.embed_tokens - self.model.model.embed_tokens = ( - target_language_model.model.embed_tokens) + if get_pp_group().world_size == 1: + if hasattr(target_language_model.model, 'embed_tokens'): + target_embed_tokens = target_language_model.model.embed_tokens + elif hasattr(target_language_model.model, 'embedding'): + target_embed_tokens = target_language_model.model.embedding + else: + raise AttributeError( + "Target model does not have 'embed_tokens' or 'embedding' " + "attribute") + + # Check if shapes match and we found the embedding + eagle_shape = self.model.model.embed_tokens.weight.shape + target_shape = target_embed_tokens.weight.shape + if eagle_shape == target_shape: + logger.info( + "Assuming the EAGLE head shares the same vocab embedding" + " with the target model.") + del self.model.model.embed_tokens + self.model.model.embed_tokens = target_embed_tokens + else: + logger.info( + "The EAGLE head's vocab embedding will be loaded separately" + " from the target model.") else: logger.info( "The EAGLE head's vocab embedding will be loaded separately" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b0cd0f4133079..89b9a3c34f2ac 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -55,7 +55,7 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, check_use_alibi, get_dtype_size, + GiB_bytes, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, length_from_prompt_token_ids_or_embeds, round_up, supports_dynamo) @@ -2913,12 +2913,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Note: Overriding max_query_len to be the prefill tokens max_query_len = num_prefill_tokens elif uniform_decode: - num_reqs = num_tokens // max_query_len + assert not create_mixed_batch + num_reqs = cdiv(num_tokens, max_query_len) assert num_reqs <= max_num_reqs, \ "Do not capture num_reqs > max_num_reqs for uniform batch" num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: - num_scheduled_tokens_list[-1] += num_tokens % max_query_len + num_scheduled_tokens_list[-1] = num_tokens % max_query_len else: num_reqs = min(num_tokens, max_num_reqs) min_tokens_per_req = num_tokens // num_reqs diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0590f5cbd9989..ec58fa43099c3 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -392,7 +392,7 @@ class Worker(WorkerBase): f"utilize gpu memory. 
Current kv cache memory in use is " f"{int(self.available_kv_cache_memory_bytes)} bytes.") - logger.info(msg) + logger.debug(msg) # Warm up sampler and preallocate memory buffer for logits and other # sampling related tensors of max possible shape to avoid memory @@ -491,7 +491,7 @@ class Worker(WorkerBase): sort_by="self_cuda_time_total")) def execute_dummy_batch(self) -> None: - self.model_runner._dummy_run(1) + self.model_runner._dummy_run(1, uniform_decode=True) def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 016a90c196ba3..7eaff924ecc1f 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -123,6 +123,7 @@ class KVConnectorModelRunnerMixin: output.kv_connector_stats = KVConnectorModelRunnerMixin.\ get_kv_connector_stats() + kv_connector.clear_connector_metadata() @staticmethod def get_kv_connector_stats() -> Optional[KVConnectorStats]: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index dd11b1dcbe94c..4cbf991a14c11 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -10,6 +10,7 @@ import numpy as np import torch import torch.nn as nn # TPU XLA related +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.distributed.spmd as xs import torch_xla.runtime as xr @@ -846,10 +847,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # 2. A list or tuple (length: num_items) of tensors, each of shape # (feature_size, hidden_size) in case the feature size is dynamic # depending on the input multimodal items. - xm.mark_step() + torch_xla.sync(wait=False) curr_group_outputs = self.model.get_multimodal_embeddings( **mm_kwargs_group) - xm.mark_step() + torch_xla.sync(wait=False) sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -952,7 +953,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_embeds = self._gather_mm_embeddings(scheduler_output) else: mm_embeds = [] - xm.mark_step() + torch_xla.sync(wait=False) # Prepare inputs, the requests might be split into multiple # executions, combine the result of each execution. start_index = 0 @@ -969,7 +970,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): end_index = self._prepare_inputs(scheduler_output, start_index) input_ids, inputs_embeds = self._get_model_inputs( self.input_ids, mm_embeds) - xm.mark_step() + torch_xla.sync(wait=False) # Run the decoder with set_forward_context( attn_metadata, @@ -1183,7 +1184,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Sync all pending XLA execution during model initialization and weight # loading. 
- xm.mark_step() + torch_xla.sync(wait=False) xm.wait_device_ops() if not hasattr(self, "model"): self.model = model @@ -1267,10 +1268,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, lora_requests) -> None: - xm.mark_step() # Captures input updates + torch_xla.sync(wait=False) # Captures input updates super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, lora_requests) - xm.mark_step() # Captures metadata updates + torch_xla.sync(wait=False) # Captures metadata updates def _precompile_mm_encoder(self) -> None: if not self.supports_mm_inputs: @@ -1297,10 +1298,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_items, ) # Run multimodal encoder. - xm.mark_step() + torch_xla.sync(wait=False) mm_embeds = self.model.get_multimodal_embeddings( **batched_dummy_mm_inputs) - xm.mark_step() + torch_xla.sync(wait=False) num_patches = mm_embeds[0].shape[0] items_size = num_patches * num_items @@ -1325,7 +1326,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): a, b = self._get_model_inputs(placeholders_ids, [mm_embeds]) assert a is None - xm.mark_step() + torch_xla.sync(wait=False) # Pre-compile `get_input_embeddings` when mm_embeddings are not # present. Chunk is only made of text, no mm_placeholders. @@ -1336,7 +1337,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): placeholders_ids = placeholders_ids.to(self.device) a, b = self._get_model_inputs(placeholders_ids, []) assert a is None - xm.mark_step() + torch_xla.sync(wait=False) xm.wait_device_ops() end = time.perf_counter() @@ -1532,11 +1533,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Isolate encoder graph from post-processing to minimize # impact of recompilation until it's fixed. start = time.perf_counter() - xm.mark_step() + torch_xla.sync(wait=False) dummy_encoder_outputs = \ self.model.get_multimodal_embeddings( **batched_dummy_mm_inputs) - xm.mark_step() + torch_xla.sync(wait=False) xm.wait_device_ops() end = time.perf_counter() logger.info( @@ -1559,7 +1560,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self._dummy_run(num_tokens, self.num_reqs_most_model_len, self.num_blocks_per_most_len_req) - xm.mark_step() + torch_xla.sync(wait=False) xm.wait_device_ops() self.encoder_cache.clear() gc.collect() @@ -1927,11 +1928,11 @@ def replace_set_lora(model): # to a tensor doesn't seem to work anymore. This might be fixed with a # later release of torch_xla. self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias) - xm.mark_step() + torch_xla.sync(wait=False) def _tpu_reset_lora(self, index: int): self._original_reset_lora(index) - xm.mark_step() + torch_xla.sync(wait=False) for _, module in model.named_modules(): if isinstance(module, BaseLayerWithLoRA):
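Editorial note (not part of the patch): the TPU runner edits above replace `xm.mark_step()` with `torch_xla.sync(wait=False)`, the newer torch_xla spelling for cutting and dispatching the pending lazy-tensor graph without blocking, while `xm.wait_device_ops()` is still used where the code must wait for the device to go idle. A minimal usage sketch, assuming a TPU/XLA runtime and a torch_xla release that provides the top-level `torch_xla.sync` API:

```python
# Flush pending lazy ops (new API) and then block until the device is idle.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = x @ x                     # recorded lazily, not yet executed

torch_xla.sync(wait=False)    # replaces xm.mark_step(): dispatch the graph
xm.wait_device_ops()          # block until the device has finished
print(y.cpu())
```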