diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 0e5b21ddf25b3..864eb470bb0a7 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -59,7 +59,7 @@ while true; do fi done -echo "--- Pulling container" +echo "--- Pulling container" image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}" container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" docker pull "${image_name}" @@ -177,13 +177,13 @@ if [[ -z "$render_gid" ]]; then exit 1 fi -# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. +# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. if [[ $commands == *"--shard-id="* ]]; then - # assign job count as the number of shards used - commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "} + # assign job count as the number of shards used + commands=$(echo "$commands" | sed -E "s/--num-shards[[:blank:]]*=[[:blank:]]*[0-9]*/--num-shards=${PARALLEL_JOB_COUNT} /g" | sed 's/ \\ / /g') for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do # assign shard-id for each shard - commands_gpu=${commands//"--shard-id= "/"--shard-id=${GPU} "} + commands_gpu=$(echo "$commands" | sed -E "s/--shard-id[[:blank:]]*=[[:blank:]]*[0-9]*/--shard-id=${GPU} /g" | sed 's/ \\ / /g') echo "Shard ${GPU} commands:$commands_gpu" echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES" docker run \ diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index 27ed67c4517e2..d49f3e2f47cf1 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -46,6 +46,6 @@ docker run \ pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py pytest -v -s v1/structured_output pytest -v -s v1/spec_decode 
--ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py --ignore=v1/spec_decode/test_speculators_eagle3.py - pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py + pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py pytest -v -s v1/test_serial_utils.py ' diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 5fd048c2ad0c6..2471b509a9fff 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -226,6 +226,27 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd +- label: Distributed Tests (8 GPUs) # 4min + timeout_in_minutes: 10 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_8 + # grade: Blocking + gpu: h100 + num_gpus: 8 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - examples/offline_inference/torchrun_dp_example.py + - vllm/config/parallel.py + - vllm/distributed/ + - vllm/v1/engine/llm_engine.py + - vllm/v1/executor/uniproc_executor.py + - vllm/v1/worker/gpu_worker.py + commands: + # https://github.com/NVIDIA/nccl/issues/1838 + #- export NCCL_CUMEM_HOST_ENABLE=0 + # test with torchrun tp=2 and dp=4 with ep + - torchrun --nproc-per-node=8 ../examples/offline_inference/torchrun_dp_example.py --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep + - label: EPLB Algorithm Test # 5min mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 @@ -238,11 +259,11 @@ steps: commands: - pytest -v -s distributed/test_eplb_algo.py -- label: EPLB Execution Test # 5min +- label: EPLB Execution Test # 10min mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # 
grade: Blocking - timeout_in_minutes: 15 + timeout_in_minutes: 20 working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -250,6 +271,7 @@ steps: - tests/distributed/test_eplb_execute.py commands: - pytest -v -s distributed/test_eplb_execute.py + - pytest -v -s distributed/test_eplb_spec_decode.py - label: Metrics, Tracing Test # 12min timeout_in_minutes: 20 @@ -273,7 +295,7 @@ steps: - label: Regression Test # 7min timeout_in_minutes: 20 - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 grade: Blocking source_file_dependencies: @@ -288,7 +310,7 @@ steps: timeout_in_minutes: 40 mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 - #grade: Blocking + # grade: Blocking source_file_dependencies: - vllm/ - tests/engine @@ -337,6 +359,7 @@ steps: - tests/v1 commands: # split the test to avoid interference + - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt - pytest -v -s -m 'not cpu_test' v1/core - pytest -v -s v1/executor - pytest -v -s v1/kv_offload @@ -344,7 +367,7 @@ steps: - pytest -v -s v1/logits_processors - pytest -v -s v1/worker - pytest -v -s v1/spec_decode - - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_lmcache_integration.py + - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit - pytest -v -s -m 'not cpu_test' v1/metrics - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_request.py @@ -353,6 +376,20 @@ steps: - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine +# TODO: Add the "V1 Test attetion (MI300)" test group + +- label: V1 Test attention (H100) # 10min + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + timeout_in_minutes: 30 + gpu: h100 + source_file_dependencies: + - 
vllm/v1/attention + - tests/v1/attention + commands: + - pytest -v -s v1/attention + - label: V1 Test others (CPU) # 5 mins mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 @@ -479,10 +516,11 @@ steps: - tests/compile commands: - pytest -v -s compile/test_basic_correctness.py + - pytest -v -s compile/test_multimodal_compile.py - pytest -v -s compile/piecewise/ -- label: PyTorch Fullgraph Test # 22min - timeout_in_minutes: 35 +- label: PyTorch Fullgraph Test # 27min + timeout_in_minutes: 40 mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking @@ -491,8 +529,23 @@ steps: - vllm/ - tests/compile commands: - - pytest -v -s compile/test_full_graph.py - - pytest -v -s compile/test_fusions_e2e.py + - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile' + # Limit to no custom ops to reduce running time + # Wrap with quotes to escape yaml and avoid starting -k string with a - + - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'" + +- label: Cudagraph test + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + source_file_dependencies: + - tests/v1/cudagraph + - vllm/v1/cudagraph_dispatcher.py + - vllm/config/compilation.py + - vllm/compilation + commands: + - pytest -v -s v1/cudagraph/test_cudagraph_dispatch.py + - pytest -v -s v1/cudagraph/test_cudagraph_mode.py - label: Kernels Core Operation Test # 48min timeout_in_minutes: 75 @@ -544,6 +597,8 @@ steps: - tests/kernels/moe - vllm/model_executor/layers/fused_moe/ - vllm/distributed/device_communicators/ + - vllm/envs.py + - vllm/config commands: - pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 @@ -562,10 +617,13 @@ steps: - label: Model Executor Test # 23min timeout_in_minutes: 35 + torch_nightly: true mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking 
source_file_dependencies: + - vllm/engine/arg_utils.py + - vllm/config/model.py - vllm/model_executor - tests/model_executor - tests/entrypoints/openai/test_tensorizer_entrypoint.py @@ -861,9 +919,10 @@ steps: - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Accuracy Eval (Small Models) # 10min + timeout_in_minutes: 70 mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 - timeout_in_minutes: 15 + # grade: Blocking working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: - vllm/multimodal/ @@ -934,6 +993,7 @@ steps: - label: Transformers Nightly Models Test mirror_hardwares: [amdexperimental] agent_pool: mi325_1 + # grade: Blocking working_dir: "/vllm-workspace/" optional: true commands: @@ -961,11 +1021,16 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/attention/backends/mla/cutlass_mla.py + - vllm/v1/attention/backends/mla/flashinfer_mla.py + - vllm/platforms/cuda.py + - vllm/attention/selector.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py # Attention # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 + - pytest -v -s tests/kernels/attention/test_attention_selector.py - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py @@ -1002,12 +1067,39 @@ steps: - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py # this runner has 2 GPUs available even though num_gpus=2 is not set - pytest -v -s tests/compile/test_fusion_all_reduce.py - - pytest -v -s tests/compile/test_fusions_e2e.py + # Limit to 
Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time + # Wrap with quotes to escape yaml + - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'" -- label: Blackwell GPT-OSS Eval - timeout_in_minutes: 60 +- label: Blackwell Fusion E2E Tests # 30 min + timeout_in_minutes: 40 working_dir: "/vllm-workspace/" gpu: b200 + optional: true + num_gpus: 2 + source_file_dependencies: + - csrc/quantization/fp4/ + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + - tests/compile/test_fusions_e2e.py + - tests/compile/test_full_graph.py + commands: + - nvidia-smi + # Run all e2e fusion tests + - pytest -v -s tests/compile/test_fusions_e2e.py + # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) + - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile + +- label: ROCm GPT-OSS Eval + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + agent_pool: mi325_1 + mirror_hardwares: [amdproduction] optional: true # run on nightlies source_file_dependencies: - tests/evals/gpt_oss @@ -1016,7 +1108,7 @@ steps: - vllm/v1/attention/backends/flashinfer.py commands: - uv pip install --system 'gpt-oss[eval]==0.0.5' - - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 + - VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 - label: Blackwell Quantized MoE Test timeout_in_minutes: 60 @@ -1253,6 +1345,7 @@ steps: - label: NixlConnector PD accuracy tests (Distributed) # 30min 
mirror_hardwares: [amdexperimental] agent_pool: mi325_4 + # grade: Blocking timeout_in_minutes: 30 working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -1267,6 +1360,9 @@ steps: ##### A100 test ##### - label: Distributed Tests (A100) # optional + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking gpu: a100 optional: true num_gpus: 4 @@ -1281,6 +1377,9 @@ steps: - pytest -v -s -x lora/test_mixtral.py - label: LM Eval Large Models # optional + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_4 + # grade: Blocking gpu: a100 optional: true num_gpus: 4 @@ -1292,8 +1391,27 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 +##### H100 test ##### +- label: LM Eval Large Models (H100) # optional + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_4 + # grade: Blocking + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100 + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4 + ##### H200 test ##### - label: Distributed Tests (H200) # optional + mirror_hardwares: [amdexperimental] + agent_pool: mi325_2 + # grade: Blocking gpu: h200 optional: true working_dir: "/vllm-workspace/" @@ -1305,6 +1423,7 @@ steps: - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - pytest -v -s tests/distributed/test_context_parallel.py - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + - pytest -v -s 
tests/v1/distributed/test_dbo.py ##### B200 test ##### - label: Distributed Tests (B200) # optional @@ -1315,6 +1434,7 @@ steps: commands: - pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + - pytest -v -s tests/v1/distributed/test_dbo.py ##### RL Integration Tests ##### - label: Prime-RL Integration Test # 15min @@ -1330,3 +1450,27 @@ steps: - .buildkite/scripts/run-prime-rl-test.sh commands: - bash .buildkite/scripts/run-prime-rl-test.sh + +- label: DeepSeek V2-Lite Accuracy + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010 + +- label: Qwen3-30B-A3B-FP8-block Accuracy + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh 0.8 200 8020 diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index be1b79ddc4324..4ac76aba67b9c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -445,6 +445,7 @@ steps: - vllm/ - tests/compile commands: + - pytest -v -s compile/test_graph_partition.py - pytest -v -s compile/test_config.py - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py @@ -477,10 +478,11 @@ steps: - vllm/ - tests/compile commands: + # fp8 kv scales not supported on sm89, tested on Blackwell instead - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile' # Limit to no custom ops to reduce running time # Wrap with quotes to escape yaml and avoid starting -k string with a - - - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and 
-quant_fp8'" + - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'" - label: Cudagraph test timeout_in_minutes: 20 @@ -924,7 +926,7 @@ steps: - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_flashinfer.py -- label: Blackwell Fusion Tests # 30 min +- label: Blackwell Fusion and Compile Tests # 30 min timeout_in_minutes: 40 working_dir: "/vllm-workspace/" gpu: b200 @@ -945,7 +947,9 @@ steps: - pytest -v -s tests/compile/test_fusion_all_reduce.py # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time # Wrap with quotes to escape yaml - - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'" + - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'" + # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) + - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile - label: Blackwell Fusion E2E Tests # 30 min timeout_in_minutes: 40 @@ -968,8 +972,6 @@ steps: - nvidia-smi # Run all e2e fusion tests - pytest -v -s tests/compile/test_fusions_e2e.py - # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) - - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 @@ -1265,7 +1267,8 @@ steps: - pytest -v -s tests/compile/test_async_tp.py - pytest -v -s tests/compile/test_sequence_parallelism.py - pytest -v -s tests/compile/test_fusion_all_reduce.py - - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm + - "pytest -v -s tests/compile/test_fusions_e2e.py -k 'not Llama-4'" + - pytest -v -s tests/distributed/test_sequence_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py - 
CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - pytest -v -s tests/v1/distributed/test_dbo.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index f26c782bccf2c..6e178bb690c56 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,8 +3,8 @@ # This lists cover the "core" components of vLLM that require careful review /vllm/attention @LucasWilkinson -/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn +/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @njhill +/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @njhill @22quinn /vllm/model_executor/layers/fused_moe @mgoin @pavanimajety /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety /vllm/model_executor/layers/mamba @tdoublep @@ -20,15 +20,15 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson # Any change to the VllmConfig changes can have a large user-facing impact, # so spam a lot of people -/vllm/config @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg -/vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345 +/vllm/config @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg +/vllm/config/cache.py @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345 # vLLM V1 /vllm/v1/attention @LucasWilkinson /vllm/v1/attention/backends/mla @pavanimajety 
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety /vllm/v1/attention/backends/triton_attn.py @tdoublep -/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC +/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC /vllm/v1/sample @22quinn @houseroad @njhill /vllm/v1/spec_decode @benchislett @luccafong /vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett @@ -36,11 +36,11 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /vllm/v1/offloading @ApostaC # Test ownership -/.buildkite/lm-eval-harness @mgoin @simon-mo +/.buildkite/lm-eval-harness @mgoin /tests/distributed/test_multi_node_assignment.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_same_node.py @youkaichao -/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm @NickLucche +/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche /tests/evals @mgoin /tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256 /tests/models @DarkLight1337 @ywang96 @@ -49,7 +49,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/test_inputs.py @DarkLight1337 @ywang96 /tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm /tests/v1/structured_output @mgoin @russellb @aarnphm -/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC +/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC /tests/weight_loading @mgoin @youkaichao @yewentao256 /tests/lora @jeejeelee /tests/models/language/generation/test_hybrid.py @tdoublep @@ -57,7 +57,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/v1/kv_connector @ApostaC /tests/v1/offloading @ApostaC -# Transformers backend +# Transformers modeling backend /vllm/model_executor/models/transformers @hmellor /tests/models/test_transformers.py @hmellor diff 
--git a/.github/workflows/macos-smoke-test.yml b/.github/workflows/macos-smoke-test.yml new file mode 100644 index 0000000000000..8d40aa587bf00 --- /dev/null +++ b/.github/workflows/macos-smoke-test.yml @@ -0,0 +1,76 @@ +name: macOS Apple Silicon Smoke Test + +on: + workflow_dispatch: # Manual trigger + +jobs: + macos-m1-smoke-test: + runs-on: macos-latest + timeout-minutes: 20 + + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v7 + with: + enable-cache: true + cache-dependency-glob: | + requirements/**/*.txt + pyproject.toml + python-version: '3.12' + + - name: Install dependencies + run: | + uv pip install -r requirements/cpu-build.txt + uv pip install -r requirements/cpu.txt + + - name: Build vLLM + run: uv pip install -v -e . + env: + CMAKE_BUILD_PARALLEL_LEVEL: 4 + + - name: Verify installation + run: | + python -c "import vllm; print(f'vLLM version: {vllm.__version__}')" + python -c "import torch; print(f'PyTorch: {torch.__version__}')" + + - name: Smoke test vllm serve + timeout-minutes: 10 + run: | + # Start server in background + vllm serve Qwen/Qwen3-0.6B \ + --max-model-len=2048 \ + --load-format=dummy \ + --enforce-eager \ + --port 8000 & + + SERVER_PID=$! 
+ + # Wait for server to start + for i in {1..30}; do + if curl -s http://localhost:8000/health > /dev/null; then + echo "Server started successfully" + break + fi + if [ "$i" -eq 30 ]; then + echo "Server failed to start" + kill "$SERVER_PID" + exit 1 + fi + sleep 2 + done + + # Test health endpoint + curl -f http://localhost:8000/health + + # Test completion + curl -f http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-0.6B", + "prompt": "Hello", + "max_tokens": 5 + }' + + # Cleanup + kill "$SERVER_PID" diff --git a/.markdownlint.yaml b/.markdownlint.yaml index cd9df57cd9803..937487f47364d 100644 --- a/.markdownlint.yaml +++ b/.markdownlint.yaml @@ -3,10 +3,9 @@ MD007: MD013: false MD024: siblings_only: true +MD031: + list_items: false MD033: false -MD045: false MD046: false -MD051: false MD052: false -MD053: false MD059: false diff --git a/CMakeLists.txt b/CMakeLists.txt index dcc44be87e557..3a37040edbf1a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -861,7 +861,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # Hadacore kernels - cuda_archs_loose_intersection(HADACORE_ARCHS "8.0;8.9;9.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") if(HADACORE_ARCHS) set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") set_gencode_flags_for_srcs( diff --git a/benchmarks/benchmark_batch_invariance.py b/benchmarks/benchmark_batch_invariance.py new file mode 100755 index 0000000000000..b5c16c42de467 --- /dev/null +++ b/benchmarks/benchmark_batch_invariance.py @@ -0,0 +1,380 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark to measure the performance overhead of VLLM_BATCH_INVARIANT mode. + +This benchmark runs the same workload twice: +1. With VLLM_BATCH_INVARIANT=0 (baseline) +2. 
With VLLM_BATCH_INVARIANT=1 (batch invariant mode) + +And reports the timing and throughput metrics for comparison. + +Environment variables: + VLLM_BENCH_MODEL: Model to benchmark (default: "Qwen/Qwen3-1.7B") + VLLM_BENCH_TP_SIZE: Tensor parallel size (default: 1, use 8 for deepseek) + VLLM_BENCH_BATCH_SIZE: Max batch size (default: 128) + VLLM_BENCH_NUM_TRIALS: Number of trials to run (default: 5) + VLLM_BENCH_MIN_PROMPT: Min prompt length in words (default: 1024) + VLLM_BENCH_MAX_PROMPT: Max prompt length in words (default: 2048) + VLLM_BENCH_MAX_TOKENS: Max tokens to generate (default: 128) + VLLM_BENCH_TEMPERATURE: Temperature for sampling (default: 0.0) + VLLM_BENCH_GPU_MEMORY_UTILIZATION: GPU memory utilization (default: 0.4) + VLLM_BENCH_MAX_MODEL_LEN: Max model length (default: 5120) + VLLM_BENCH_BACKEND: Attention backend (default: FLASH_ATTN) + +Example usage: + # Benchmark qwen3 (default) + python benchmarks/benchmark_batch_invariance.py + + # Benchmark deepseek with 8 GPUs + VLLM_BENCH_MODEL="deepseek-ai/DeepSeek-V3" VLLM_BENCH_TP_SIZE=8 \\ + python benchmarks/benchmark_batch_invariance.py + + # Quick test with fewer trials + VLLM_BENCH_NUM_TRIALS=2 VLLM_BENCH_BATCH_SIZE=32 \\ + python benchmarks/benchmark_batch_invariance.py +""" + +import contextlib +import os +import random +import time + +from vllm import LLM, SamplingParams +from vllm.platforms import current_platform + + +def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: + """Generate a random prompt for benchmarking.""" + prompt_templates = [ + "Question: What is the capital of France?\nAnswer: The capital of France is", + "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which", + "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is", + "Once upon a time in a distant galaxy, there lived", + "The old man walked slowly down the street, remembering", + "In the year 2157, humanity finally discovered", + "To implement a binary 
search tree in Python, first we need to", + "The algorithm works by iterating through the array and", + "Here's how to optimize database queries using indexing:", + "The Renaissance was a period in European history that", + "Climate change is caused by several factors including", + "The human brain contains approximately 86 billion neurons which", + "I've been thinking about getting a new laptop because", + "Yesterday I went to the store and bought", + "My favorite thing about summer is definitely", + ] + + base_prompt = random.choice(prompt_templates) + + if max_words < min_words: + max_words = min_words + target_words = random.randint(min_words, max_words) + + if target_words > 50: + padding_text = ( + " This is an interesting topic that deserves more explanation. " + * (target_words // 50) + ) + base_prompt = base_prompt + padding_text + + return base_prompt + + +def run_benchmark_with_batch_invariant( + model: str, + tp_size: int, + max_batch_size: int, + num_trials: int, + min_prompt: int, + max_prompt: int, + max_tokens: int, + temperature: float, + gpu_mem_util: float, + max_model_len: int, + backend: str, + batch_invariant: bool, + seed: int = 12345, +) -> dict: + """ + Run the benchmark with the specified configuration. + + Returns a dict with timing and throughput metrics. 
+ """ + random.seed(seed) + + # Set environment variables + os.environ["VLLM_ATTENTION_BACKEND"] = backend + if batch_invariant: + os.environ["VLLM_BATCH_INVARIANT"] = "1" + else: + os.environ["VLLM_BATCH_INVARIANT"] = "0" + + print(f"\n{'=' * 80}") + print(f"BENCHMARK: VLLM_BATCH_INVARIANT={int(batch_invariant)}") + print(f" Model: {model}") + print(f" TP Size: {tp_size}") + print(f" Backend: {backend}") + print(f" Max Batch Size: {max_batch_size}") + print(f" Trials: {num_trials}") + print(f" Max Tokens: {max_tokens}") + print(f"{'=' * 80}\n") + + sampling = SamplingParams( + temperature=temperature, + top_p=0.95, + max_tokens=max_tokens, + seed=20240919, + ) + + needle_prompt = "There once was a " + + llm = None + try: + # Create LLM engine + start_init = time.perf_counter() + llm = LLM( + model=model, + max_num_seqs=max_batch_size, + gpu_memory_utilization=gpu_mem_util, + max_model_len=max_model_len, + dtype="bfloat16", + tensor_parallel_size=tp_size, + enable_prefix_caching=False, + ) + init_time = time.perf_counter() - start_init + print(f"Engine initialization time: {init_time:.2f}s\n") + + # Generate baseline + print("Generating baseline (warmup)...") + baseline_out = llm.generate([needle_prompt], sampling) + assert len(baseline_out) == 1 + baseline_text = baseline_out[0].outputs[0].text + print(f"Baseline output: '{baseline_text[:50]}...'\n") + + # Run trials and measure timing + trial_times: list[float] = [] + total_tokens = 0 + total_prompts = 0 + + for trial in range(num_trials): + # Create a batch + prompts: list[str] = [] + batch_size = random.randint(max_batch_size // 2, max_batch_size) + needle_pos = random.randint(0, batch_size - 1) + for i in range(batch_size): + if i == needle_pos: + prompts.append(needle_prompt) + else: + prompts.append(_random_prompt(min_prompt, max_prompt)) + + # Measure time for this trial + start_time = time.perf_counter() + outputs = llm.generate(prompts, sampling) + trial_time = time.perf_counter() - start_time + + 
trial_times.append(trial_time) + total_prompts += len(prompts) + + # Count tokens + for output in outputs: + if output.outputs: + total_tokens += len(output.outputs[0].token_ids) + + print( + f"Trial {trial + 1}/{num_trials}: " + f"batch_size={batch_size}, " + f"time={trial_time:.2f}s" + ) + + # Verify needle output still matches + needle_output = outputs[needle_pos] + assert needle_output.prompt == needle_prompt + + # Compute statistics + avg_time = sum(trial_times) / len(trial_times) + min_time = min(trial_times) + max_time = max(trial_times) + throughput = total_tokens / sum(trial_times) + prompts_per_sec = total_prompts / sum(trial_times) + + print(f"\n{'=' * 80}") + print("RESULTS:") + print(f" Average time per trial: {avg_time:.2f}s") + print(f" Min time: {min_time:.2f}s") + print(f" Max time: {max_time:.2f}s") + print(f" Total tokens generated: {total_tokens}") + print(f" Total prompts processed: {total_prompts}") + print(f" Throughput: {throughput:.2f} tokens/s") + print(f" Prompts/s: {prompts_per_sec:.2f}") + print(f"{'=' * 80}\n") + + return { + "init_time": init_time, + "avg_time": avg_time, + "min_time": min_time, + "max_time": max_time, + "total_tokens": total_tokens, + "total_prompts": total_prompts, + "throughput": throughput, + "prompts_per_sec": prompts_per_sec, + "trial_times": trial_times, + } + + finally: + # Cleanup + if llm is not None: + with contextlib.suppress(Exception): + llm.shutdown() + + +def main(): + # Check platform support + if not (current_platform.is_cuda() and current_platform.has_device_capability(90)): + print("ERROR: Requires CUDA and >= Hopper (SM90)") + print(f"Current platform: {current_platform.device_type}") + if current_platform.is_cuda(): + print(f"Device capability: {current_platform.get_device_capability()}") + return 1 + + # Read configuration from environment + model = os.getenv("VLLM_BENCH_MODEL", "Qwen/Qwen3-1.7B") + tp_size = int(os.getenv("VLLM_BENCH_TP_SIZE", "1")) + max_batch_size = 
int(os.getenv("VLLM_BENCH_BATCH_SIZE", "128")) + num_trials = int(os.getenv("VLLM_BENCH_NUM_TRIALS", "5")) + min_prompt = int(os.getenv("VLLM_BENCH_MIN_PROMPT", "1024")) + max_prompt = int(os.getenv("VLLM_BENCH_MAX_PROMPT", "2048")) + max_tokens = int(os.getenv("VLLM_BENCH_MAX_TOKENS", "128")) + temperature = float(os.getenv("VLLM_BENCH_TEMPERATURE", "0.0")) + gpu_mem_util = float(os.getenv("VLLM_BENCH_GPU_MEMORY_UTILIZATION", "0.4")) + max_model_len = int(os.getenv("VLLM_BENCH_MAX_MODEL_LEN", "5120")) + backend = os.getenv("VLLM_BENCH_BACKEND", "FLASH_ATTN") + + print("\n" + "=" * 80) + print("VLLM BATCH INVARIANCE BENCHMARK") + print("=" * 80) + print("\nConfiguration:") + print(f" Model: {model}") + print(f" Tensor Parallel Size: {tp_size}") + print(f" Attention Backend: {backend}") + print(f" Max Batch Size: {max_batch_size}") + print(f" Number of Trials: {num_trials}") + print(f" Prompt Length Range: {min_prompt}-{max_prompt} words") + print(f" Max Tokens to Generate: {max_tokens}") + print(f" Temperature: {temperature}") + print(f" GPU Memory Utilization: {gpu_mem_util}") + print(f" Max Model Length: {max_model_len}") + print("=" * 80) + + # Run benchmark WITHOUT batch invariance (baseline) + print("\n" + "=" * 80) + print("PHASE 1: Running WITHOUT batch invariance (baseline)") + print("=" * 80) + baseline_results = run_benchmark_with_batch_invariant( + model=model, + tp_size=tp_size, + max_batch_size=max_batch_size, + num_trials=num_trials, + min_prompt=min_prompt, + max_prompt=max_prompt, + max_tokens=max_tokens, + temperature=temperature, + gpu_mem_util=gpu_mem_util, + max_model_len=max_model_len, + backend=backend, + batch_invariant=False, + ) + + # Run benchmark WITH batch invariance + print("\n" + "=" * 80) + print("PHASE 2: Running WITH batch invariance") + print("=" * 80) + batch_inv_results = run_benchmark_with_batch_invariant( + model=model, + tp_size=tp_size, + max_batch_size=max_batch_size, + num_trials=num_trials, + min_prompt=min_prompt, + 
max_prompt=max_prompt, + max_tokens=max_tokens, + temperature=temperature, + gpu_mem_util=gpu_mem_util, + max_model_len=max_model_len, + backend=backend, + batch_invariant=True, + ) + + # Compare results + print("\n" + "=" * 80) + print("COMPARISON: Batch Invariance vs Baseline") + print("=" * 80) + + init_overhead_pct = ( + (batch_inv_results["init_time"] - baseline_results["init_time"]) + / baseline_results["init_time"] + * 100 + ) + time_overhead_pct = ( + (batch_inv_results["avg_time"] - baseline_results["avg_time"]) + / baseline_results["avg_time"] + * 100 + ) + throughput_change_pct = ( + (batch_inv_results["throughput"] - baseline_results["throughput"]) + / baseline_results["throughput"] + * 100 + ) + + print("\nInitialization Time:") + print(f" Baseline: {baseline_results['init_time']:.2f}s") + print(f" Batch Invariant: {batch_inv_results['init_time']:.2f}s") + print(f" Overhead: {init_overhead_pct:+.2f}%") + + print("\nAverage Trial Time:") + print(f" Baseline: {baseline_results['avg_time']:.2f}s") + print(f" Batch Invariant: {batch_inv_results['avg_time']:.2f}s") + print(f" Overhead: {time_overhead_pct:+.2f}%") + + print("\nThroughput (tokens/s):") + print(f" Baseline: {baseline_results['throughput']:.2f}") + print(f" Batch Invariant: {batch_inv_results['throughput']:.2f}") + print(f" Change: {throughput_change_pct:+.2f}%") + + print("\nPrompts/s:") + print(f" Baseline: {baseline_results['prompts_per_sec']:.2f}") + print(f" Batch Invariant: {batch_inv_results['prompts_per_sec']:.2f}") + + print("\n" + "=" * 80) + print("SUMMARY") + print("=" * 80) + if time_overhead_pct > 0: + print( + f"Batch invariance mode adds approximately {time_overhead_pct:.1f}% " + "overhead" + ) + else: + print( + f"Batch invariance mode is approximately {-time_overhead_pct:.1f}% " + "faster (unexpected!)" + ) + + if abs(throughput_change_pct) < 1.0: + print("Throughput difference is negligible (< 1%)") + elif throughput_change_pct < 0: + print( + f"Throughput decreased by 
{-throughput_change_pct:.1f}% " + "with batch invariance" + ) + else: + print( + f"Throughput increased by {throughput_change_pct:.1f}% " + "with batch invariance (unexpected!)" + ) + + print("=" * 80 + "\n") + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 146c268a6b7f2..28fc383a318dd 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -69,7 +69,7 @@ def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]: # Remove the special tokens. return random.choices( - [v for k, v in vocab.items() if k not in all_special_ids], + [v for v in vocab.values() if v not in all_special_ids], k=length, ) diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py index ae9e9753441aa..772d685ad90ff 100644 --- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -561,8 +561,11 @@ async def client_main( f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501 ) - random.seed(args.seed) - np.random.seed(args.seed) + # Set unique seed per client (each client runs in its own process) + # Add 1 to ensure no client uses the same seed as the main process + client_seed = args.seed + client_id + 1 + random.seed(client_seed) + np.random.seed(client_seed) # Active conversations active_convs: ConversationsMap = {} @@ -1490,6 +1493,7 @@ async def main() -> None: f"Invalid --warmup-percentage={args.warmup_percentage}" ) from None + # Set global seeds for main process random.seed(args.seed) np.random.seed(args.seed) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index bb0179c79c108..aa84125818d10 100644 --- a/cmake/cpu_extension.cmake +++ 
b/cmake/cpu_extension.cmake @@ -242,7 +242,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON SUBBUILD_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-subbuild" SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-src" GIT_REPOSITORY https://github.com/ARM-software/ComputeLibrary.git - GIT_TAG v52.2.0 + GIT_TAG v52.6.0 GIT_SHALLOW TRUE GIT_PROGRESS TRUE ) @@ -310,7 +310,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON FetchContent_Declare( oneDNN GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.9 + GIT_TAG v3.10 GIT_PROGRESS TRUE GIT_SHALLOW TRUE ) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 29db9fa273a41..567c8959f0454 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 8e1b01d56210dc72030a2d0d41c2d8d266ba6309 + GIT_TAG 58e0626a692f09241182582659e3bf8f16472659 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp index 8f4c780998020..344296528b652 100644 --- a/csrc/cpu/cpu_attn_impl.hpp +++ b/csrc/cpu/cpu_attn_impl.hpp @@ -5,6 +5,10 @@ #include #include +#if defined(__APPLE__) + #include +#endif + #include "cpu_types.hpp" #include "scratchpad_manager.h" #include "cpu_attn_macros.h" @@ -741,9 +745,21 @@ class AttentionScheduler { static int64_t get_available_l2_size() { static int64_t size = []() { +#if defined(__APPLE__) + // macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname. 
+ int64_t l2_cache_size = 0; + size_t len = sizeof(l2_cache_size); + if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 && + l2_cache_size > 0) { + return l2_cache_size >> 1; // use 50% of L2 cache + } + // Fallback if sysctlbyname fails + return 128LL * 1024 >> 1; // use 50% of 128KB +#else long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE); TORCH_CHECK_NE(l2_cache_size, -1); return l2_cache_size >> 1; // use 50% of L2 cache +#endif }(); return size; } @@ -816,15 +832,21 @@ struct VecTypeTrait { using vec_t = vec_op::FP32Vec16; }; +// ARM only supports BF16 with ARMv8.6-A extension +#if (defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)) +#else template <> struct VecTypeTrait { using vec_t = vec_op::BF16Vec16; }; +#endif +#if !defined(__powerpc__) template <> struct VecTypeTrait { using vec_t = vec_op::FP16Vec16; }; +#endif template void print_logits(const char* name, T* ptr, int32_t row, int32_t col, @@ -1586,9 +1608,17 @@ class AttentionMainLoop { if (use_sink) { alignas(64) float s_aux_fp32[16]; +#if defined(__aarch64__) && !defined(ARM_BF16_SUPPORT) + // ARM without native BF16 support: manual conversion + for (int i = 0; i < 16; ++i) { + s_aux_fp32[i] = static_cast(curr_s_aux[i]); + } +#else + // All other platforms have BF16Vec16 available vec_op::BF16Vec16 vec_bf16(curr_s_aux); vec_op::FP32Vec16 vec_fp32(vec_bf16); vec_fp32.save(s_aux_fp32); +#endif float* __restrict__ curr_sum_buffer = sum_buffer; float* __restrict__ curr_max_buffer = max_buffer; diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 5e2aa70692566..9fefd88cd9b08 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -100,6 +100,9 @@ void cpu_attention_with_kv_cache( const torch::Tensor& scheduler_metadata, const std::optional& s_aux); +// Note: just for avoiding importing errors +void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); } + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -275,6 
+278,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "sliding_window_left, SymInt sliding_window_right, Tensor block_table, " "float softcap, Tensor sheduler_metadata, Tensor? s_aux) -> ()", &cpu_attention_with_kv_cache); + + // placeholders + ops.def("static_scaled_fp8_quant() -> ()", placeholder_op); + ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op); + ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { diff --git a/csrc/fused_qknorm_rope_kernel.cu b/csrc/fused_qknorm_rope_kernel.cu index 83017250ebcd5..baff8363162ef 100644 --- a/csrc/fused_qknorm_rope_kernel.cu +++ b/csrc/fused_qknorm_rope_kernel.cu @@ -37,6 +37,16 @@ #ifdef USE_ROCM #define FINAL_MASK 0xffffffffffffffffULL + + #if defined(HIP_VERSION) && HIP_VERSION < 70000000 +// On ROCm versions before 7.0, __syncwarp isn't defined. The below +// implementation is copy/pasted from the implementation in ROCm 7.0 +__device__ inline void __syncwarp() { + __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront"); + __builtin_amdgcn_wave_barrier(); + __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront"); +} + #endif #else #define FINAL_MASK 0xffffffff #endif diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 2521b2797e2c2..0c3bcf3b64b26 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -279,17 +279,17 @@ __device__ __forceinline__ void token_bounds(int32_t n_tokens, } template + typename scale_t, int THREADS, typename Idx_t, bool CEIL_UE8M0, + int GROUP_SIZE = 128, int NUM_STAGES = 3> __global__ void silu_mul_fp8_quant_deep_gemm_kernel( const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q, - float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert, + scale_t* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert, // sizes Idx_t E, Idx_t T, Idx_t H, // strides (in 
elements) Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e, Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t, - Idx_t stride_ys_g, Idx_t stride_counts_e) { + Idx_t stride_ys_g, Idx_t stride_ys_p, Idx_t stride_counts_e) { #ifndef USE_ROCM static constexpr int NUM_WARPS = THREADS / WARP_SIZE; @@ -466,9 +466,22 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel( __nv_fp8x4_e4m3* y_q_base_ptr = reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id; - auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g; + + Idx_t scale_group_offset = 0; + if constexpr (std::is_same::value) { + // packed int32_t format + int pack_id = warp_position_scales / 4; + int scale_in_pack = warp_position_scales % 4; + scale_group_offset = pack_id * stride_ys_p + scale_in_pack * stride_ys_g; + } else { + scale_group_offset = warp_position_scales * stride_ys_g; + } + + scale_t* const y_scale_base_ptr = _y_s + scale_group_offset; for (auto j = tokens_lower; j < tokens_upper; j++) { + int current_group_id = warp_position_scales; // Running count of which + // group is being processed const Idx_t base_ys = expert_id * stride_ys_e; auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t; __nv_fp8x4_e4m3* y_q_ptr = @@ -509,7 +522,7 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel( __nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv); - if constexpr (USE_UE8M0) { + if constexpr (CEIL_UE8M0) { y_s = hexp2(hceil(hlog2(y_s))); } @@ -527,8 +540,24 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel( y_q_ptr += WARP_SIZE * stride_yq_h; if (!lane_id) { - *y_s_ptr = y_s; - y_s_ptr += stride_ys_g; + // Store scales. + if constexpr (std::is_same::value) { + // Packed UE8M0 format. Remove Mantissa. + *y_s_ptr = reinterpret_cast(y_s) >> 7; + + bool const jump_pack = (current_group_id + 1) % 4 == 0; + // Minus 3 because we need to get to the first group in the + // next pack. + y_s_ptr += jump_pack ? 
(stride_ys_p - 3) : stride_ys_g; + + } else { + // float32 format + static_assert(std::is_same::value); + *y_s_ptr = y_s; + y_s_ptr += stride_ys_g; + } + + current_group_id += 1; } } } @@ -573,7 +602,7 @@ void persistent_masked_m_silu_mul_quant( const at::Tensor& tokens_per_expert, // (E) at::Tensor& y_q, // (E, T, H) [OUT] at::Tensor& y_s, // (E, T, H//group_size) [OUT] - bool use_ue8m0) { + bool cast_scale_ue8m0) { #ifndef USE_ROCM // This kernel currently only supports H % 128 == 0 and assumes a @@ -583,9 +612,12 @@ void persistent_masked_m_silu_mul_quant( TORCH_CHECK(input.dtype() == torch::kBFloat16); TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn || y_q.dtype() == torch::kFloat8_e4m3fnuz); - TORCH_CHECK(y_s.dtype() == torch::kFloat32); TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0); + bool const is_packed_ue8m0 = + (y_s.dtype() == torch::kInt32 && cast_scale_ue8m0); + TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0); + using Idx_t = int64_t; Idx_t E = input.size(0); @@ -597,15 +629,18 @@ void persistent_masked_m_silu_mul_quant( Idx_t stride_yq_e = y_q.stride(0); Idx_t stride_yq_t = y_q.stride(1); Idx_t stride_yq_h = y_q.stride(2); - Idx_t stride_ys_e = y_s.stride(0); - Idx_t stride_ys_t = y_s.stride(1); - Idx_t stride_ys_g = y_s.stride(2); Idx_t stride_counts_e = tokens_per_expert.stride(0); + int const NUM_GROUPS = H / GROUP_SIZE; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \ + // TODO: Get this from cuda_arch ? 
+ static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32; + + #define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \ + STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \ static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \ int sms = SILU_V2_BLOCK_COUNT; \ static constexpr int max_shared_mem_bytes = \ @@ -615,43 +650,86 @@ void persistent_masked_m_silu_mul_quant( VLLM_DISPATCH_FP8_TYPES( \ y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \ vllm::silu_mul_fp8_quant_deep_gemm_kernel< \ - BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t, \ - USE_UE8M0, GROUP_SIZE, STAGES> \ + BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \ + Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \ <<>>( \ reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ - (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ + (fp8_t*)y_q.data_ptr(), \ + reinterpret_cast(y_s.data_ptr()), \ reinterpret_cast(tokens_per_expert.data_ptr()), E, \ T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \ - stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t, \ - stride_ys_g, stride_counts_e); \ + stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \ + STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \ }); - static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32; + #define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \ + STRIDE_YS_P, CEIL_UE8M0) \ + if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \ + /* 8 warp config */ \ + static constexpr int NUM_STAGES = 4; \ + static constexpr int THREAD_COUNT = 256; \ + KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \ + STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \ + } else { \ + /* 1 warp config */ \ + static constexpr int THREAD_COUNT = 32; \ + KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \ + STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \ + } - int const NUM_GROUPS = H / GROUP_SIZE; - if (!use_ue8m0) { - if (H >= 4096 && (NUM_GROUPS % 8 == 0)) { - /* 8 
warps config */ - static constexpr int NUM_STAGES = 4; - static constexpr int THREAD_COUNT = 256; - KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES); - } else { - /* 1 warp config */ - static constexpr int THREAD_COUNT = 32; - KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2); - } - } else { - if (H >= 4096 && (NUM_GROUPS % 8 == 0)) { - /* 8 warps config */ - static constexpr int NUM_STAGES = 4; - static constexpr int THREAD_COUNT = 256; - KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES); - } else { - /* 1 warp config */ - static constexpr int THREAD_COUNT = 32; - KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2); - } + Idx_t stride_ys_e = y_s.stride(0); + Idx_t stride_ys_t = y_s.stride(1); + Idx_t stride_ys_g = y_s.stride(2); + Idx_t stride_ys_p = 0; + if (!cast_scale_ue8m0) { + TORCH_CHECK(!is_packed_ue8m0); + LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p, + false); + return; } + if (!is_packed_ue8m0) { + // UE8M0 but not packed + LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p, + true); + return; + } + + TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0); + TORCH_CHECK(y_s.dtype() == torch::kInt32); + + // Int32 packed ue8m0 scales tensor. + // Let E, T, G be the number of experts, number of tokens and number of groups + // respectively. Let E = 1, T = 4, G = 6; in this case the int32 scales + // tensor is of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected + // to be arranged as follows, + // [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,], + // [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,] + // [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,] + // [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]] + // where TxGy is the ue8m0 scale value of Token x, Group y. 
+ // + // In memory (in bytes) the scale values are arranged as, + // [T0G0, T0G1, T0G2, T0G3, T1G0, T1G1, T1G2, T1G3, T2G0, T2G1, T2G2, T2G3, + // T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5, + // X, X, T3G4, T3G5, X, X] + // + // An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented + // as a uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In + // English, ignoring the Experts dimension, the original int32 tensor is + // simply treated as two packed [4, 4] uint8 tensors (or two [4, 1] int32 + // tensors). The following strides setting reflects this change. Caveat: This + // means that the G dimension is no longer contiguous. i.e. Note that to move + // from G3 to G4, we need to jump along the packing dimension. The kernel + // handles this case. + + stride_ys_e *= sizeof(int32_t); + stride_ys_p = T * sizeof(int32_t); // Packing dimension + stride_ys_t = sizeof(int32_t); + stride_ys_g = 1; + + LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p, + true); + #endif } diff --git a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu index 8ba617a9e6555..e607107b3e77c 100644 --- a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu @@ -247,22 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, return out; } -torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight, - c10::SymInt size_k, c10::SymInt size_n, - int64_t num_bits) { - int const pack_factor = 32 / num_bits; - auto options = torch::TensorOptions() - .dtype(b_q_weight.dtype()) - .device(b_q_weight.device()); - return torch::empty_symint( - {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, - options); -} - TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("awq_marlin_repack", &awq_marlin_repack); } - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, 
Meta, m) { - m.impl("awq_marlin_repack", &awq_marlin_repack_meta); -} diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index 7c2d089a70d95..ad80d51ece94e 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -321,22 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, return out; } -torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, - torch::Tensor& perm, c10::SymInt size_k, - c10::SymInt size_n, int64_t num_bits) { - int const pack_factor = 32 / num_bits; - auto options = torch::TensorOptions() - .dtype(b_q_weight.dtype()) - .device(b_q_weight.device()); - return torch::empty_symint( - {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, - options); -} - TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("gptq_marlin_repack", &gptq_marlin_repack); } - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) { - m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta); -} diff --git a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu b/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu index 5369d409f9b21..aff11326d78e9 100644 --- a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu +++ b/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu @@ -802,7 +802,7 @@ torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace) { }); if (numel % 256 != 0) { - out = out.index({torch::indexing::Slice(0, numel / had_size)}); + out = out.narrow(0, 0, numel / had_size); } if (inplace && out.data_ptr() != x.data_ptr()) { diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh index 4ff3e65f2b2e1..b8433214be1ba 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh +++ 
b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh @@ -116,6 +116,26 @@ struct sm90_fp8_config_default { ClusterShape, KernelSchedule, EpilogueSchedule>>; }; +template +struct sm90_fp8_config_M8192_K6144 { + // M >= 8192, K >= 6144 + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileShape = Shape<_256, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm90_fp8, + cutlass_3x_gemm_sm90_fp8>; +}; + template struct sm90_fp8_config_M128 { // M in (64, 128] @@ -273,6 +293,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, using Cutlass3xGemmDefault = typename sm90_fp8_config_default::Cutlass3xGemm; + using Cutlass3xGemmM8192_K6144 = + typename sm90_fp8_config_M8192_K6144::Cutlass3xGemm; using Cutlass3xGemmM128 = typename sm90_fp8_config_M128::Cutlass3xGemm; @@ -291,6 +314,7 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, uint32_t const m = a.size(0); uint32_t const n = b.size(1); + uint32_t const k = a.size(1); if (m <= 16) { // m in [1, 16] @@ -312,6 +336,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, // m in (64, 128] return cutlass_gemm_caller_sm90_fp8( out, a, b, a_scales, b_scales, std::forward(args)...); + } else if (m >= 8192 && k >= 6144) { + return cutlass_gemm_caller_sm90_fp8( + out, a, b, a_scales, b_scales, std::forward(args)...); } else { // m in (128, inf) return cutlass_gemm_caller_sm90_fp8( diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 06d229f315bdc..731a97d93da1f 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -15,6 +15,17 @@ RUN apt-get update -q -y && apt-get install -q -y \ # Remove sccache RUN python3 -m pip install --upgrade pip RUN apt-get purge -y sccache; python3 -m pip uninstall -y 
sccache; rm -f "$(which sccache)" + +# Install UV +RUN curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR="/usr/local/bin" sh + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 +ENV UV_INDEX_STRATEGY="unsafe-best-match" +# Use copy mode to avoid hardlink failures with Docker cache mounts +ENV UV_LINK_MODE=copy + ARG COMMON_WORKDIR WORKDIR ${COMMON_WORKDIR} @@ -59,13 +70,15 @@ FROM base AS test RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/* -# Install vLLM +# Install vLLM using uv (inherited from base stage) +# Note: No -U flag to avoid upgrading PyTorch ROCm to CUDA version RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ + --mount=type=cache,target=/root/.cache/uv \ cd /install \ - && pip install -U -r requirements/rocm.txt \ - && pip install -U -r requirements/rocm-test.txt \ + && uv pip install --system -r requirements/rocm.txt \ + && uv pip install --system -r requirements/rocm-test.txt \ && pip uninstall -y vllm \ - && pip install *.whl + && uv pip install --system *.whl WORKDIR /vllm-workspace ARG COMMON_WORKDIR @@ -89,14 +102,17 @@ RUN case "$(which python3)" in \ rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \ *) ;; esac -RUN python3 -m pip install --upgrade huggingface-hub[cli] +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system --upgrade huggingface-hub[cli] -# Install vLLM +# Install vLLM using uv (inherited from base stage) +# Note: No -U flag to avoid upgrading PyTorch ROCm to CUDA version RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ + --mount=type=cache,target=/root/.cache/uv \ cd /install \ - && pip install -U -r requirements/rocm.txt \ + && uv pip install --system -r requirements/rocm.txt \ && pip uninstall -y vllm \ - && pip install *.whl + && uv pip install --system *.whl ARG 
COMMON_WORKDIR diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 19f7fa7e1468d..df4f9b6c26e7d 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete +ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.1-complete ARG TRITON_BRANCH="57c693b6" ARG TRITON_REPO="https://github.com/ROCm/triton.git" ARG PYTORCH_BRANCH="1c57644d" @@ -7,7 +7,7 @@ ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="0e60e394" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="9716b1b8" +ARG AITER_BRANCH="59bd8ff2" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base @@ -19,6 +19,9 @@ ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx11 ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} ENV AITER_ROCM_ARCH=gfx942;gfx950 +# Required for RCCL in ROCm7.1 +ENV HSA_NO_SCRATCH_RECLAIM=1 + ARG PYTHON_VERSION=3.12 RUN mkdir -p /app diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 4e6ef8f5ca13c..5d5b82c4fa5af 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -14,6 +14,7 @@ RUN apt clean && apt-get update -y && \ libxext6 \ libgl1 \ lsb-release \ + libaio-dev \ numactl \ wget \ vim \ @@ -68,8 +69,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \ RUN python3 -m pip install -e tests/vllm_test_utils # install nixl from source code +ENV NIXL_VERSION=0.7.0 RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py -ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages/.nixl.mesonpy.libs/plugins/" RUN --mount=type=cache,target=/root/.cache/pip \ pip uninstall oneccl oneccl-devel -y diff --git a/docs/.nav.yml b/docs/.nav.yml index c103ed476d76d..3151ea0e2ec22 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -46,7 +46,10 @@ nav: - contributing/model/multimodal.md - 
contributing/model/transcription.md - CI: contributing/ci - - Design Documents: design + - Design Documents: + - Plugins: + - design/*plugin*.md + - design/* - API Reference: - api/README.md - api/vllm diff --git a/docs/README.md b/docs/README.md index 0608794e7e650..0c279c19f96ca 100644 --- a/docs/README.md +++ b/docs/README.md @@ -30,8 +30,8 @@ Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at Where to get started with vLLM depends on the type of user. If you are looking to: - Run open-source models on vLLM, we recommend starting with the [Quickstart Guide](./getting_started/quickstart.md) -- Build applications with vLLM, we recommend starting with the [User Guide](./usage) -- Build vLLM, we recommend starting with [Developer Guide](./contributing) +- Build applications with vLLM, we recommend starting with the [User Guide](./usage/README.md) +- Build vLLM, we recommend starting with [Developer Guide](./contributing/README.md) For information about the development of vLLM, see: diff --git a/docs/cli/bench/latency.md b/docs/cli/bench/latency.md index 21ab13e63781a..ea7ea7321ffcd 100644 --- a/docs/cli/bench/latency.md +++ b/docs/cli/bench/latency.md @@ -4,6 +4,6 @@ --8<-- "docs/cli/json_tip.inc.md" -## Options +## Arguments ---8<-- "docs/argparse/bench_latency.md" +--8<-- "docs/argparse/bench_latency.inc.md" diff --git a/docs/cli/bench/serve.md b/docs/cli/bench/serve.md index f7c415c6becb5..f7dc8036cc262 100644 --- a/docs/cli/bench/serve.md +++ b/docs/cli/bench/serve.md @@ -4,6 +4,6 @@ --8<-- "docs/cli/json_tip.inc.md" -## Options +## Arguments ---8<-- "docs/argparse/bench_serve.md" +--8<-- "docs/argparse/bench_serve.inc.md" diff --git a/docs/cli/bench/sweep/plot.md b/docs/cli/bench/sweep/plot.md index f29bffb64655c..a101330e093cc 100644 --- a/docs/cli/bench/sweep/plot.md +++ b/docs/cli/bench/sweep/plot.md @@ -4,6 +4,6 @@ --8<-- "docs/cli/json_tip.inc.md" -## Options +## Arguments ---8<-- "docs/argparse/bench_sweep_plot.md" +--8<-- 
"docs/argparse/bench_sweep_plot.inc.md" diff --git a/docs/cli/bench/sweep/serve.md b/docs/cli/bench/sweep/serve.md index 5b5f91a951ed0..f0468f06fc287 100644 --- a/docs/cli/bench/sweep/serve.md +++ b/docs/cli/bench/sweep/serve.md @@ -4,6 +4,6 @@ --8<-- "docs/cli/json_tip.inc.md" -## Options +## Arguments ---8<-- "docs/argparse/bench_sweep_serve.md" +--8<-- "docs/argparse/bench_sweep_serve.inc.md" diff --git a/docs/cli/bench/sweep/serve_sla.md b/docs/cli/bench/sweep/serve_sla.md index 5f8ab6005e50b..5642ec67eb007 100644 --- a/docs/cli/bench/sweep/serve_sla.md +++ b/docs/cli/bench/sweep/serve_sla.md @@ -4,6 +4,6 @@ --8<-- "docs/cli/json_tip.inc.md" -## Options +## Arguments ---8<-- "docs/argparse/bench_sweep_serve_sla.md" +--8<-- "docs/argparse/bench_sweep_serve_sla.inc.md" diff --git a/docs/cli/bench/throughput.md b/docs/cli/bench/throughput.md index e4ff5ce43c9ce..e7f618fb4d147 100644 --- a/docs/cli/bench/throughput.md +++ b/docs/cli/bench/throughput.md @@ -4,6 +4,6 @@ --8<-- "docs/cli/json_tip.inc.md" -## Options +## Arguments ---8<-- "docs/argparse/bench_throughput.md" +--8<-- "docs/argparse/bench_throughput.inc.md" diff --git a/docs/cli/chat.md b/docs/cli/chat.md index b006cb8de60d0..0246bd431b101 100644 --- a/docs/cli/chat.md +++ b/docs/cli/chat.md @@ -1,5 +1,5 @@ # vllm chat -## Options +## Arguments ---8<-- "docs/argparse/chat.md" +--8<-- "docs/argparse/chat.inc.md" diff --git a/docs/cli/complete.md b/docs/cli/complete.md index 400359acf4fb8..eb2ffdaabac25 100644 --- a/docs/cli/complete.md +++ b/docs/cli/complete.md @@ -1,5 +1,5 @@ # vllm complete -## Options +## Arguments ---8<-- "docs/argparse/complete.md" +--8<-- "docs/argparse/complete.inc.md" diff --git a/docs/cli/run-batch.md b/docs/cli/run-batch.md index f7d401b8dad2b..758fbda283978 100644 --- a/docs/cli/run-batch.md +++ b/docs/cli/run-batch.md @@ -4,6 +4,6 @@ --8<-- "docs/cli/json_tip.inc.md" -## Options +## Arguments ---8<-- "docs/argparse/run-batch.md" +--8<-- "docs/argparse/run-batch.inc.md" diff 
--git a/docs/cli/serve.md b/docs/cli/serve.md index 2c8f9d320f5df..35652fec587b3 100644 --- a/docs/cli/serve.md +++ b/docs/cli/serve.md @@ -4,6 +4,6 @@ --8<-- "docs/cli/json_tip.inc.md" -## Options +## Arguments ---8<-- "docs/argparse/serve.md" +--8<-- "docs/argparse/serve.inc.md" diff --git a/docs/configuration/serve_args.md b/docs/configuration/serve_args.md index c1cc5577bc7ab..baaf21f01f066 100644 --- a/docs/configuration/serve_args.md +++ b/docs/configuration/serve_args.md @@ -5,7 +5,7 @@ The `vllm serve` command is used to launch the OpenAI-compatible server. ## CLI Arguments The `vllm serve` command is used to launch the OpenAI-compatible server. -To see the available options, take a look at the [CLI Reference](../cli/README.md#options)! +To see the available options, take a look at the [CLI Reference](../cli/README.md)! ## Configuration file diff --git a/docs/contributing/benchmarks.md b/docs/contributing/benchmarks.md index dca01eab5b426..c9bc9cfe28a35 100644 --- a/docs/contributing/benchmarks.md +++ b/docs/contributing/benchmarks.md @@ -10,8 +10,6 @@ vLLM provides comprehensive benchmarking tools for performance testing and evalu - **[Parameter sweeps](#parameter-sweeps)**: Automate `vllm bench` runs for multiple configurations - **[Performance benchmarks](#performance-benchmarks)**: Automated CI benchmarks for development -[Benchmark CLI]: #benchmark-cli - ## Benchmark CLI This section guides you through running benchmark tests with the extensive @@ -985,7 +983,7 @@ each document has close to 512 tokens. Please note that the `/v1/rerank` is also supported by embedding models. So if you're running with an embedding model, also set `--no_reranker`. Because in this case the query is -treated as a individual prompt by the server, here we send `random_batch_size - 1` documents +treated as an individual prompt by the server, here we send `random_batch_size - 1` documents to account for the extra prompt which is the query. 
The token accounting to report the throughput numbers correctly is also adjusted. diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md index f983c25f26ee1..09fd85a466eed 100644 --- a/docs/contributing/ci/update_pytorch_version.md +++ b/docs/contributing/ci/update_pytorch_version.md @@ -95,7 +95,7 @@ when manually triggering a build on Buildkite. This branch accomplishes two thin to warm it up so that future builds are faster.

- + Buildkite new build popup

## Update dependencies diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md index d8c40c5195735..13f3edb7e1af1 100644 --- a/docs/contributing/model/README.md +++ b/docs/contributing/model/README.md @@ -1,7 +1,7 @@ # Summary !!! important - Many decoder language models can now be automatically loaded using the [Transformers backend](../../models/supported_models.md#transformers) without having to implement them in vLLM. See if `vllm serve ` works first! + Many decoder language models can now be automatically loaded using the [Transformers modeling backend](../../models/supported_models.md#transformers) without having to implement them in vLLM. See if `vllm serve ` works first! vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features](../../features/README.md#compatibility-matrix) to optimize their performance. diff --git a/docs/contributing/model/transcription.md b/docs/contributing/model/transcription.md index a590ecd6a1a23..fca941acd5076 100644 --- a/docs/contributing/model/transcription.md +++ b/docs/contributing/model/transcription.md @@ -249,7 +249,7 @@ No extra registration is required beyond having your model class available via t ## Examples in-tree - Whisper encoder–decoder (audio-only): [vllm/model_executor/models/whisper.py](../../../vllm/model_executor/models/whisper.py) -- Voxtral decoder-only (audio embeddings + LLM): [vllm/model_executor/models/voxtral.py](../../../vllm/model_executor/models/voxtral.py) +- Voxtral decoder-only (audio embeddings + LLM): [vllm/model_executor/models/voxtral.py](../../../vllm/model_executor/models/voxtral.py). Make sure to have installed `mistral-common[audio]`. 
- Gemma3n decoder-only with fixed instruction prompt: [vllm/model_executor/models/gemma3n_mm.py](../../../vllm/model_executor/models/gemma3n_mm.py) ## Test with the API diff --git a/docs/deployment/frameworks/chatbox.md b/docs/deployment/frameworks/chatbox.md index 002935da56009..5f7cef1a87dfb 100644 --- a/docs/deployment/frameworks/chatbox.md +++ b/docs/deployment/frameworks/chatbox.md @@ -29,8 +29,8 @@ pip install vllm - API Path: `/chat/completions` - Model: `qwen/Qwen1.5-0.5B-Chat` - ![](../../assets/deployment/chatbox-settings.png) + ![Chatbox settings screen](../../assets/deployment/chatbox-settings.png) 1. Go to `Just chat`, and start to chat: - ![](../../assets/deployment/chatbox-chat.png) + ![Chatbot chat screen](../../assets/deployment/chatbox-chat.png) diff --git a/docs/deployment/frameworks/dify.md b/docs/deployment/frameworks/dify.md index 820ef0cbed9fa..673cbf4b6a24a 100644 --- a/docs/deployment/frameworks/dify.md +++ b/docs/deployment/frameworks/dify.md @@ -46,12 +46,12 @@ And install [Docker](https://docs.docker.com/engine/install/) and [Docker Compos - **Model Name for API Endpoint**: `Qwen/Qwen1.5-7B-Chat` - **Completion Mode**: `Completion` - ![](../../assets/deployment/dify-settings.png) + ![Dify settings screen](../../assets/deployment/dify-settings.png) 1. To create a test chatbot, go to `Studio → Chatbot → Create from Blank`, then select Chatbot as the type: - ![](../../assets/deployment/dify-create-chatbot.png) + ![Dify create chatbot screen](../../assets/deployment/dify-create-chatbot.png) 1. 
Click the chatbot you just created to open the chat interface and start interacting with the model: - ![](../../assets/deployment/dify-chat.png) + ![Dify chat screen](../../assets/deployment/dify-chat.png) diff --git a/docs/deployment/frameworks/hf_inference_endpoints.md b/docs/deployment/frameworks/hf_inference_endpoints.md index d39bb9a899c8a..05df0dacd8f11 100644 --- a/docs/deployment/frameworks/hf_inference_endpoints.md +++ b/docs/deployment/frameworks/hf_inference_endpoints.md @@ -156,7 +156,7 @@ In this guide, we demonstrate manual deployment using the [`rednote-hilab/dots.o ## Advanced Deployment Details -With the [transformers backend integration](https://blog.vllm.ai/2025/04/11/transformers-backend.html), vLLM now offers Day 0 support for any model compatible with `transformers`. This means you can deploy such models immediately, leveraging vLLM’s optimized inference without additional backend modifications. +With the [Transformers modeling backend integration](https://blog.vllm.ai/2025/04/11/transformers-backend.html), vLLM now offers Day 0 support for any model compatible with `transformers`. This means you can deploy such models immediately, leveraging vLLM’s optimized inference without additional backend modifications. Hugging Face Inference Endpoints provides a fully managed environment for serving models via vLLM. You can deploy models without configuring servers, installing dependencies, or managing clusters. Endpoints also support deployment across multiple cloud providers (AWS, Azure, GCP) without the need for separate accounts. 
@@ -167,4 +167,4 @@ The platform integrates seamlessly with the Hugging Face Hub, allowing you to de - Explore the [Inference Endpoints](https://endpoints.huggingface.co/catalog) model catalog - Read the Inference Endpoints [documentation](https://huggingface.co/docs/inference-endpoints/en/index) - Learn about [Inference Endpoints engines](https://huggingface.co/docs/inference-endpoints/en/engines/vllm) -- Understand the [transformers backend integration](https://blog.vllm.ai/2025/04/11/transformers-backend.html) +- Understand the [Transformers modeling backend integration](https://blog.vllm.ai/2025/04/11/transformers-backend.html) diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md index aac7b76eea265..66bf3b27d1f52 100644 --- a/docs/design/cuda_graphs.md +++ b/docs/design/cuda_graphs.md @@ -128,7 +128,7 @@ A [CUDAGraphWrapper][vllm.compilation.cuda_graph.CUDAGraphWrapper] instance wrap 3. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, the wrapper will perform CUDA Graphs capture (if key does not exist, create a new entry and cache it) or replay (if key exists in the cache). -The above steps are based on the assumption that the CUDA Graphs wrapper would directly trust what’s in the forward context (controlled by the dispatcher). This lets us simplify and cenralize the logic, reducing the complexity as well as the risk of mismatched state between the wrappers and the dispatcher. It also allows reusing the wrapper class for both `FULL` and `PIECEWISE` runtime modes. See the implementation [here](https://github.com/vllm-project/vllm/blob/f751e50b7a2aae3110d83ed0d88202fc91b3e78a/vllm/compilation/cuda_graph.py#L106). +The above steps are based on the assumption that the CUDA Graphs wrapper would directly trust what’s in the forward context (controlled by the dispatcher). This lets us simplify and centralize the logic, reducing the complexity as well as the risk of mismatched state between the wrappers and the dispatcher. 
It also allows reusing the wrapper class for both `FULL` and `PIECEWISE` runtime modes. See the implementation [here](https://github.com/vllm-project/vllm/blob/f751e50b7a2aae3110d83ed0d88202fc91b3e78a/vllm/compilation/cuda_graph.py#L106). #### Nested Wrapper design diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index 76df0d8d8a38f..e1a96be6c3445 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -19,9 +19,9 @@ The input activation format completely depends on the All2All Dispatch being use The FusedMoE operation is generally made of multiple operations, in both the Contiguous and Batched variants, as described in the diagrams below -![](../assets/design/fused_moe_modular_kernel/fused_moe_non_batched.png "FusedMoE Non-Batched") +![FusedMoE Non-Batched](../assets/design/fused_moe_modular_kernel/fused_moe_non_batched.png) -![](../assets/design/fused_moe_modular_kernel/fused_moe_batched.png "FusedMoE Batched") +![FusedMoE Batched](../assets/design/fused_moe_modular_kernel/fused_moe_batched.png) !!! note The main difference, in terms of operations, between the Batched and Non-Batched cases is the Permute / Unpermute operations. All other operations remain. @@ -57,7 +57,7 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions. The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. 
It is not required that this method is supported by all `FusedMoEPrepareAndFinalize` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section) -![](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png "FusedMoEPrepareAndFinalize Blocks") +![FusedMoEPrepareAndFinalize Blocks](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png) ### FusedMoEPermuteExpertsUnpermute @@ -88,7 +88,7 @@ The core FusedMoE implementation performs a series of operations. It would be in It is sometimes efficient to perform TopK weight application and Reduction inside the `FusedMoEPermuteExpertsUnpermute::apply()`. Find an example [here](https://github.com/vllm-project/vllm/pull/20228). We have a `TopKWeightAndReduce` abstract class to facilitate such implementations. Please refer to the TopKWeightAndReduce section. `FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl()` returns the `TopKWeightAndReduce` object that the implementation wants the `FusedMoEPrepareAndFinalize::finalize()` to use. -![](../assets/design/fused_moe_modular_kernel/fused_experts_blocks.png "FusedMoEPermuteExpertsUnpermute Blocks") +![FusedMoEPermuteExpertsUnpermute Blocks](../assets/design/fused_moe_modular_kernel/fused_experts_blocks.png) ### FusedMoEModularKernel diff --git a/docs/design/lora_resolver_plugins.md b/docs/design/lora_resolver_plugins.md new file mode 100644 index 0000000000000..bd0dc6dc9c7bb --- /dev/null +++ b/docs/design/lora_resolver_plugins.md @@ -0,0 +1,220 @@ +# LoRA Resolver Plugins + +This directory contains vLLM's LoRA resolver plugins built on the `LoRAResolver` framework. 
+They automatically discover and load LoRA adapters from a specified local storage path, eliminating the need for manual configuration or server restarts. + +## Overview + +LoRA Resolver Plugins provide a flexible way to dynamically load LoRA adapters at runtime. When vLLM +receives a request for a LoRA adapter that hasn't been loaded yet, the resolver plugins will attempt +to locate and load the adapter from their configured storage locations. This enables: + +- **Dynamic LoRA Loading**: Load adapters on-demand without server restarts +- **Multiple Storage Backends**: Support for filesystem, S3, and custom backends. The built-in `lora_filesystem_resolver` requires a local storage path, but custom resolvers can be implemented to fetch from any source. +- **Automatic Discovery**: Seamless integration with existing LoRA workflows +- **Scalable Deployment**: Centralized adapter management across multiple vLLM instances + +## Prerequisites + +Before using LoRA Resolver Plugins, ensure the following environment variables are configured: + +### Required Environment Variables + +1. **`VLLM_ALLOW_RUNTIME_LORA_UPDATING`**: Must be set to `true` or `1` to enable dynamic LoRA loading + ```bash + export VLLM_ALLOW_RUNTIME_LORA_UPDATING=true + ``` + +2. **`VLLM_PLUGINS`**: Must include the desired resolver plugins (comma-separated list) + ```bash + export VLLM_PLUGINS=lora_filesystem_resolver + ``` + +3. **`VLLM_LORA_RESOLVER_CACHE_DIR`**: Must be set to a valid directory path for filesystem resolver + ```bash + export VLLM_LORA_RESOLVER_CACHE_DIR=/path/to/lora/adapters + ``` + +### Optional Environment Variables + +- **`VLLM_PLUGINS`**: If not set, all available plugins will be loaded. If set to empty string, no plugins will be loaded. + +## Available Resolvers + +### lora_filesystem_resolver + +The filesystem resolver is installed with vLLM by default and enables loading LoRA adapters from a local directory structure. + +#### Setup Steps + +1. 
**Create the LoRA adapter storage directory**: + ```bash + mkdir -p /path/to/lora/adapters + ``` + +2. **Set environment variables**: + ```bash + export VLLM_ALLOW_RUNTIME_LORA_UPDATING=true + export VLLM_PLUGINS=lora_filesystem_resolver + export VLLM_LORA_RESOLVER_CACHE_DIR=/path/to/lora/adapters + ``` + +3. **Start vLLM server**: + Your base model can be `meta-llama/Llama-2-7b-hf`. Please make sure you set up the Hugging Face token in your env var `export HF_TOKEN=xxx235`. + ```bash + python -m vllm.entrypoints.openai.api_server \ + --model your-base-model \ + --enable-lora + ``` + +#### Directory Structure Requirements + +The filesystem resolver expects LoRA adapters to be organized in the following structure: + +```text +/path/to/lora/adapters/ +├── adapter1/ +│ ├── adapter_config.json +│ ├── adapter_model.bin +│ └── tokenizer files (if applicable) +├── adapter2/ +│ ├── adapter_config.json +│ ├── adapter_model.bin +│ └── tokenizer files (if applicable) +└── ... +``` + +Each adapter directory must contain: + +- **`adapter_config.json`**: Required configuration file with the following structure: + ```json + { + "peft_type": "LORA", + "base_model_name_or_path": "your-base-model-name", + "r": 16, + "lora_alpha": 32, + "target_modules": ["q_proj", "v_proj"], + "bias": "none", + "modules_to_save": null, + "use_rslora": false, + "use_dora": false + } + ``` + +- **`adapter_model.bin`**: The LoRA adapter weights file + +#### Usage Example + +1. **Prepare your LoRA adapter**: + ```bash + # Assuming you have a LoRA adapter in /tmp/my_lora_adapter + cp -r /tmp/my_lora_adapter /path/to/lora/adapters/my_sql_adapter + ``` + +2. **Verify the directory structure**: + ```bash + ls -la /path/to/lora/adapters/my_sql_adapter/ + # Should show: adapter_config.json, adapter_model.bin, etc. + ``` + +3. 
**Make a request using the adapter**: + ```bash + curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "my_sql_adapter", + "prompt": "Generate a SQL query for:", + "max_tokens": 50, + "temperature": 0.1 + }' + ``` + +#### How It Works + +1. When vLLM receives a request for a LoRA adapter named `my_sql_adapter` +2. The filesystem resolver checks if `/path/to/lora/adapters/my_sql_adapter/` exists +3. If found, it validates the `adapter_config.json` file +4. If the configuration matches the base model and is valid, the adapter is loaded +5. The request is processed normally with the newly loaded adapter +6. The adapter remains available for future requests + +## Advanced Configuration + +### Multiple Resolvers + +You can configure multiple resolver plugins to load adapters from different sources: + +'lora_s3_resolver' is an example of a custom resolver you would need to implement + +```bash +export VLLM_PLUGINS=lora_filesystem_resolver,lora_s3_resolver +``` + +All listed resolvers are enabled; at request time, vLLM tries them in order until one succeeds. + +### Custom Resolver Implementation + +To implement your own resolver plugin: + +1. **Create a new resolver class**: + ```python + from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry + from vllm.lora.request import LoRARequest + + class CustomResolver(LoRAResolver): + async def resolve_lora(self, base_model_name: str, lora_name: str) -> Optional[LoRARequest]: + # Your custom resolution logic here + pass + ``` + +2. **Register the resolver**: + ```python + def register_custom_resolver(): + resolver = CustomResolver() + LoRAResolverRegistry.register_resolver("Custom Resolver", resolver) + ``` + +## Troubleshooting + +### Common Issues + +1. **"VLLM_LORA_RESOLVER_CACHE_DIR must be set to a valid directory"** + - Ensure the directory exists and is accessible + - Check file permissions on the directory + +2. 
**"LoRA adapter not found"** + - Verify the adapter directory name matches the requested model name + - Check that `adapter_config.json` exists and is valid JSON + - Ensure `adapter_model.bin` exists in the directory + +3. **"Invalid adapter configuration"** + - Verify `peft_type` is set to "LORA" + - Check that `base_model_name_or_path` matches your base model + - Ensure `target_modules` is properly configured + +4. **"LoRA rank exceeds maximum"** + - Check that `r` value in `adapter_config.json` doesn't exceed `max_lora_rank` setting + +### Debugging Tips + +1. **Enable debug logging**: + ```bash + export VLLM_LOGGING_LEVEL=DEBUG + ``` + +2. **Verify environment variables**: + ```bash + echo $VLLM_ALLOW_RUNTIME_LORA_UPDATING + echo $VLLM_PLUGINS + echo $VLLM_LORA_RESOLVER_CACHE_DIR + ``` + +3. **Test adapter configuration**: + ```bash + python -c " + import json + with open('/path/to/lora/adapters/my_adapter/adapter_config.json') as f: + config = json.load(f) + print('Config valid:', config) + " + ``` diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index ee224e6922fbd..7663b82266f0b 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -68,7 +68,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes. ## Fused MoE Experts Kernels -The are a number of MoE experts kernel implementations for different quantization types and architectures. Most follow the general API of the base Triton [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts] function. Many have modular kernel adatpers so they can be used with compatible all2all backends. This table lists each experts kernel and its particular properties. +There are a number of MoE experts kernel implementations for different quantization types and architectures. Most follow the general API of the base Triton [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts] function.
Many have modular kernel adapters so they can be used with compatible all2all backends. This table lists each experts kernel and its particular properties. Each kernel must be provided with one of the supported input activation formats. Some flavors of kernels support both standard and batched formats through different entry points, e.g. `TritonExperts` and `BatchedTritonExperts`. Batched format kernels are currently only needed for matching with certain all2all backends, e.g. `pplx`, `DeepEPLLPrepareAndFinalize`. diff --git a/docs/features/custom_arguments.md b/docs/features/custom_arguments.md index 7a650d0e79c23..728a2c89901de 100644 --- a/docs/features/custom_arguments.md +++ b/docs/features/custom_arguments.md @@ -5,7 +5,7 @@ You can use vLLM *custom arguments* to pass in arguments which are not part of t Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code. !!! note - Make sure your custom logits processor have implemented `validate_params` for custom arguments. Otherwise invalid custom arguments can cause unexpected behaviour. + Make sure your custom logits processor has implemented `validate_params` for custom arguments. Otherwise, invalid custom arguments can cause unexpected behaviour. ## Offline Custom Arguments diff --git a/docs/features/custom_logitsprocs.md b/docs/features/custom_logitsprocs.md index 52fcc44efacc5..5ddef9db1611b 100644 --- a/docs/features/custom_logitsprocs.md +++ b/docs/features/custom_logitsprocs.md @@ -71,7 +71,7 @@ Logits processor `update_state()` implementations should assume the following mo * **"Condense" the batch to be contiguous:** starting with the lowest-index empty slot (which was caused by a Remove), apply a Unidirectional Move from the current highest non-empty slot in the batch to fill the empty slot.
Proceed with additional Unidirectional Move operations in order of increasing empty slot destination index and decreasing non-empty slot source index until the batch is contiguous - * **Shrink the batch:** a side-effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots + * **Shrink the batch:** a side effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots 5. Reorder the batch for improved efficiency. Depending on the attention backend implementation and the current characteristics of the batch, zero or more Swap Move operations may be applied to reorder the batch @@ -286,7 +286,7 @@ Once you have created a custom subclass (like `WrappedPerReqLogitsProcessor`) wh ## Ways to Load Your Custom Logits Processor in vLLM -Logits processors are loaded at initialization. Critically, the set of loaded logits processors cannot be modified after the vLLM engine finishes loading, and new logits logits processors cannot be loaded on-demand for individual requests. +Logits processors are loaded at initialization. Critically, the set of loaded logits processors cannot be modified after the vLLM engine finishes loading, and new logits processors cannot be loaded on-demand for individual requests. This section details different ways of making your logits processor visible to vLLM and triggering vLLM to load your logits processor. 
@@ -438,7 +438,7 @@ The examples below show how a user would pass a custom argument (`target_token`) ## Best Practices for Writing Custom Logits Processors -Once vLLM loads a logits processor during initialization, then vLLM will invoke `update_state()` and `apply()` against that logits processor in every engine step. Both methods operate on all requests which currently reside in the vLLM persistent batch. Thus it is important to implement these methods efficiently. +Once vLLM loads a logits processor during initialization, then vLLM will invoke `update_state()` and `apply()` against that logits processor in every engine step. Both methods operate on all requests which currently reside in the vLLM persistent batch. Thus, it is important to implement these methods efficiently. * Write efficient `apply()` and `update_state()` implementations in light of the fact that logits processors operate at batch granularity * For example, you may be able to use efficient vectorized operations to implement `apply()` or update internal state vectors in `update_state()` @@ -465,4 +465,4 @@ Once vLLM loads a logits processor during initialization, then vLLM will invoke * **Note:** for wrapped per-request logits processors, the `AdapterLogitsProcessor` base-class handles this by default -* `is_argmax_invariant()` can be hard-coded to `True` or `False` if the logits processor has consistent behavior. However the argmax invariance may also be determined programmatically (i.e. if your logits processor is user-customizable in some way that impacts whether the logits processor is argmax invariant). For this reason, `is_argmax_invariant()` is not a class method +* `is_argmax_invariant()` can be hard-coded to `True` or `False` if the logits processor has consistent behavior. However, the argmax invariance may also be determined programmatically (i.e. if your logits processor is user-customizable in some way that impacts whether the logits processor is argmax invariant). 
For this reason, `is_argmax_invariant()` is not a class method diff --git a/docs/features/interleaved_thinking.md b/docs/features/interleaved_thinking.md new file mode 100644 index 0000000000000..7343324b48494 --- /dev/null +++ b/docs/features/interleaved_thinking.md @@ -0,0 +1,118 @@ +# Interleaved Thinking + +## Introduction + +Interleaved thinking allows models to reason between tool calls, enabling more sophisticated decision-making after receiving tool results. This feature helps models chain multiple tool calls with reasoning steps in between and make nuanced decisions based on intermediate results. + +Important: Interleaved thinking increases token usage and response latency. Consider your budget and performance requirements when enabling this feature. + +## How Interleaved Thinking Works + +With interleaved thinking, the model can: + +- Reason about the results of a tool call before deciding what to do next +- Chain multiple tool calls with reasoning steps in between +- Make more nuanced decisions based on intermediate results +- Provide transparent reasoning for its tool selection process + +## Supported Models + +vLLM currently supports the following interleaved thinking models: + +| Model Series | Reasoning Parser Name | +|--------------|-----------------------| +| moonshotai/Kimi-K2-Thinking | kimi_k2 | +| MiniMaxAI/MiniMax-M2 | minimax_m2 | + +## Example Usage + +To use interleaved thinking with tool calls, specify a model that supports this feature and enable tool calls in your chat completion request. Here's an example: + +??? 
code + + ```python + """ + vllm serve MiniMaxAI/MiniMax-M2 \ + --tensor-parallel-size 4 \ + --tool-call-parser minimax_m2 \ + --reasoning-parser minimax_m2 \ + --enable-auto-tool-choice + """ + import json + + from openai import OpenAI + + client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy") + + + def get_current_weather(location: str, unit: "str"): + """Get the current weather in a given location""" + if unit == "celsius": + return f"The current temperature in {location} is 22°C." + else: + return f"The current temperature in {location} is 72°F." + + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City and state, e.g., 'San Francisco, CA'", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location", "unit"], + }, + }, + } + ] + messages = [{"role": "user", "content": "What's the weather in Fahrenheit like in San Francisco?"}] + response = client.chat.completions.create( + model=client.models.list().data[0].id, + messages=messages, + tools=tools, + tool_choice="auto", + ) + + tool_call = response.choices[0].message.tool_calls[0].function + + messages.append( + { + "role": "assistant", + "tool_calls": response.choices[0].message.tool_calls, + "reasoning": response.choices[0].message.reasoning, # append reasoning + } + ) + + # Simulate tool execution + available_tools = {"get_weather": get_current_weather} + + completion_tool_calls = response.choices[0].message.tool_calls + for call in completion_tool_calls: + tool_to_call = available_tools[call.function.name] + args = json.loads(call.function.arguments) + result = tool_to_call(**args) + messages.append( + { + "role": "tool", + "content": result, + "tool_call_id": call.id, + "name": call.function.name, + } + ) + response_2 = client.chat.completions.create( + 
model=client.models.list().data[0].id, + messages=messages, + tools=tools, + tool_choice="auto", + ) + print(response_2.choices[0].message.content) + ``` +This example demonstrates how to set up interleaved thinking with tool calls using a weather retrieval function. The model reasons about the tool results before generating the final response. diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index be0702f4c9e16..bd7bc186e13aa 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -298,7 +298,7 @@ There are two steps to generate and deploy a mixed precision model quantized wit Firstly, the layerwise mixed-precision configuration for a given LLM model is searched and then quantized using AMD Quark. We will provide a detailed tutorial with Quark APIs later. -As examples, we provide some ready-to-use quantized mixed precision model to show the usage in vLLM and the accuracy benifits. They are: +As examples, we provide some ready-to-use quantized mixed precision models to show the usage in vLLM and the accuracy benefits. They are: - amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8 - amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8 diff --git a/docs/getting_started/installation/cpu.apple.inc.md b/docs/getting_started/installation/cpu.apple.inc.md index 7e2ed55008a57..4dc707d5f9a14 100644 --- a/docs/getting_started/installation/cpu.apple.inc.md +++ b/docs/getting_started/installation/cpu.apple.inc.md @@ -28,10 +28,15 @@ After installation of XCode and the Command Line Tools, which include Apple Clan ```bash git clone https://github.com/vllm-project/vllm.git cd vllm -uv pip install -r requirements/cpu.txt +uv pip install -r requirements/cpu.txt --index-strategy unsafe-best-match uv pip install -e . ``` +!!! tip + The `--index-strategy unsafe-best-match` flag is needed to resolve dependencies across multiple package indexes (PyTorch CPU index and PyPI).
Without this flag, you may encounter `typing-extensions` version conflicts. + + The term "unsafe" refers to the package resolution strategy, not security. By default, `uv` only searches the first index where a package is found to prevent dependency confusion attacks. This flag allows `uv` to search all configured indexes to find the best compatible versions. Since both PyTorch and PyPI are trusted package sources, using this strategy is safe and appropriate for vLLM installation. + !!! note On macOS the `VLLM_TARGET_DEVICE` is automatically set to `cpu`, which is currently the only supported device. diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index e8bfca0e5e88f..be99cef3723e6 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -104,7 +104,7 @@ Currently, there are no pre-built CPU wheels. ### Which `dtype` should be used? -- Currently vLLM CPU uses model default settings as `dtype`. However, due to unstable float16 support in torch CPU, it is recommended to explicitly set `dtype=bfloat16` if there are any performance or accuracy problem. +- Currently, vLLM CPU uses model default settings as `dtype`. However, due to unstable float16 support in torch CPU, it is recommended to explicitly set `dtype=bfloat16` if there are any performance or accuracy problem. ### How to launch a vLLM service on CPU? diff --git a/docs/getting_started/installation/cpu.s390x.inc.md b/docs/getting_started/installation/cpu.s390x.inc.md index 442c2b4ec64e8..c2163139a7c5d 100644 --- a/docs/getting_started/installation/cpu.s390x.inc.md +++ b/docs/getting_started/installation/cpu.s390x.inc.md @@ -2,7 +2,7 @@ vLLM has experimental support for s390x architecture on IBM Z platform. For now, users must build from source to natively run on IBM Z platform. -Currently the CPU implementation for s390x architecture supports FP32 datatype only. 
+Currently, the CPU implementation for s390x architecture supports FP32 datatype only. !!! warning There are no pre-built wheels or images for this device, so you must build vLLM from source. diff --git a/docs/getting_started/installation/cpu.x86.inc.md b/docs/getting_started/installation/cpu.x86.inc.md index 00f3b726b1a0e..310f179cb89ca 100644 --- a/docs/getting_started/installation/cpu.x86.inc.md +++ b/docs/getting_started/installation/cpu.x86.inc.md @@ -83,7 +83,7 @@ uv pip install dist/*.whl !!! example "Troubleshooting" - **NumPy ≥2.0 error**: Downgrade using `pip install "numpy<2.0"`. - **CMake picks up CUDA**: Add `CMAKE_DISABLE_FIND_PACKAGE_CUDA=ON` to prevent CUDA detection during CPU builds, even if CUDA is installed. - - `AMD` requies at least 4th gen processors (Zen 4/Genoa) or higher to support [AVX512](https://www.phoronix.com/review/amd-zen4-avx512) to run vLLM on CPU. + - `AMD` requires at least 4th gen processors (Zen 4/Genoa) or higher to support [AVX512](https://www.phoronix.com/review/amd-zen4-avx512) to run vLLM on CPU. - If you receive an error such as: `Could not find a version that satisfies the requirement torch==X.Y.Z+cpu+cpu`, consider updating [pyproject.toml](https://github.com/vllm-project/vllm/blob/main/pyproject.toml) to help pip resolve the dependency. 
```toml title="pyproject.toml" [build-system] diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index ce1c5c53cf35a..735074c08b8c8 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import importlib +import importlib.metadata +import importlib.util import logging import sys import traceback -from argparse import SUPPRESS, HelpFormatter +from argparse import SUPPRESS, Action, HelpFormatter +from collections.abc import Iterable +from importlib.machinery import ModuleSpec from pathlib import Path -from typing import Literal +from typing import TYPE_CHECKING, Literal from unittest.mock import MagicMock, patch from pydantic_core import core_schema @@ -19,6 +22,11 @@ ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse" sys.path.insert(0, str(ROOT_DIR)) +def mock_if_no_torch(mock_module: str, mock: MagicMock): + if not importlib.util.find_spec("torch"): + sys.modules[mock_module] = mock + + # Mock custom op code class MockCustomOp: @staticmethod @@ -29,18 +37,21 @@ class MockCustomOp: return decorator -noop = lambda *a, **k: None -sys.modules["vllm._C"] = MagicMock() -sys.modules["vllm.model_executor.custom_op"] = MagicMock(CustomOp=MockCustomOp) -sys.modules["vllm.utils.torch_utils"] = MagicMock(direct_register_custom_op=noop) +mock_if_no_torch("vllm._C", MagicMock()) +mock_if_no_torch("vllm.model_executor.custom_op", MagicMock(CustomOp=MockCustomOp)) +mock_if_no_torch( + "vllm.utils.torch_utils", MagicMock(direct_register_custom_op=lambda *a, **k: None) +) + # Mock any version checks by reading from compiled CI requirements with open(ROOT_DIR / "requirements/test.txt") as f: VERSIONS = dict(line.strip().split("==") for line in f if "==" in line) importlib.metadata.version = lambda name: VERSIONS.get(name) or "0.0.0" + # Make torch.nn.Parameter safe to inherit from 
-sys.modules["torch.nn"] = MagicMock(Parameter=object) +mock_if_no_torch("torch.nn", MagicMock(Parameter=object)) class PydanticMagicMock(MagicMock): @@ -49,31 +60,34 @@ class PydanticMagicMock(MagicMock): def __init__(self, *args, **kwargs): name = kwargs.pop("name", None) super().__init__(*args, **kwargs) - self.__spec__ = importlib.machinery.ModuleSpec(name, None) + self.__spec__ = ModuleSpec(name, None) def __get_pydantic_core_schema__(self, source_type, handler): return core_schema.any_schema() -def auto_mock(module, attr, max_mocks=100): +def auto_mock(module_name: str, attr: str, max_mocks: int = 100): """Function that automatically mocks missing modules during imports.""" - logger.info("Importing %s from %s", attr, module) + logger.info("Importing %s from %s", attr, module_name) + for _ in range(max_mocks): try: + module = importlib.import_module(module_name) + # First treat attr as an attr, then as a submodule - return getattr( - importlib.import_module(module), - attr, - importlib.import_module(f"{module}.{attr}"), - ) + if hasattr(module, attr): + return getattr(module, attr) + + return importlib.import_module(f"{module_name}.{attr}") except ModuleNotFoundError as e: + assert e.name is not None logger.info("Mocking %s for argparse doc generation", e.name) sys.modules[e.name] = PydanticMagicMock(name=e.name) - except Exception as e: - logger.warning("Failed to import %s.%s: %s", module, attr, e) + except Exception: + logger.exception("Failed to import %s.%s", module_name, attr) raise ImportError( - f"Failed to import {module}.{attr} after mocking {max_mocks} imports" + f"Failed to import {module_name}.{attr} after mocking {max_mocks} imports" ) @@ -91,21 +105,26 @@ ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand") CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand") openai_cli_args = auto_mock("vllm.entrypoints.openai", "cli_args") openai_run_batch = auto_mock("vllm.entrypoints.openai", "run_batch")
-FlexibleArgumentParser = auto_mock( - "vllm.utils.argparse_utils", "FlexibleArgumentParser" -) + +if TYPE_CHECKING: + from vllm.utils.argparse_utils import FlexibleArgumentParser +else: + FlexibleArgumentParser = auto_mock( + "vllm.utils.argparse_utils", "FlexibleArgumentParser" + ) class MarkdownFormatter(HelpFormatter): """Custom formatter that generates markdown for argument groups.""" - def __init__(self, prog, starting_heading_level=3): - super().__init__(prog, max_help_position=float("inf"), width=float("inf")) + def __init__(self, prog: str, starting_heading_level: int = 3): + super().__init__(prog, max_help_position=sys.maxsize, width=sys.maxsize) + self._section_heading_prefix = "#" * starting_heading_level self._argument_heading_prefix = "#" * (starting_heading_level + 1) self._markdown_output = [] - def start_section(self, heading): + def start_section(self, heading: str): if heading not in {"positional arguments", "options"}: heading_md = f"\n{self._section_heading_prefix} {heading}\n\n" self._markdown_output.append(heading_md) @@ -113,14 +132,14 @@ class MarkdownFormatter(HelpFormatter): def end_section(self): pass - def add_text(self, text): + def add_text(self, text: str): if text: self._markdown_output.append(f"{text.strip()}\n\n") def add_usage(self, usage, actions, groups, prefix=None): pass - def add_arguments(self, actions): + def add_arguments(self, actions: Iterable[Action]): for action in actions: if len(action.option_strings) == 0 or "--help" in action.option_strings: continue @@ -169,7 +188,7 @@ def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser: # Auto-mock runtime imports if tb_list := traceback.extract_tb(e.__traceback__): path = Path(tb_list[-1].filename).relative_to(ROOT_DIR) - auto_mock(module=".".join(path.parent.parts), attr=path.stem) + auto_mock(module_name=".".join(path.parent.parts), attr=path.stem) return create_parser(add_cli_args, **kwargs) else: raise e @@ -209,7 +228,7 @@ def on_startup(command: 
Literal["build", "gh-deploy", "serve"], dirty: bool): # Generate documentation for each parser for stem, parser in parsers.items(): - doc_path = ARGPARSE_DOC_DIR / f"{stem}.md" + doc_path = ARGPARSE_DOC_DIR / f"{stem}.inc.md" # Specify encoding for building on Windows with open(doc_path, "w", encoding="utf-8") as f: f.write(super(type(parser), parser).format_help()) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index c1eb207efcd18..bd14bbb9ab662 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -15,9 +15,9 @@ These models are what we list in [supported text models](#list-of-text-only-lang ### Transformers -vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <5% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend". +vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <5% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers modeling backend". -Currently, the Transformers backend works for the following: +Currently, the Transformers modeling backend works for the following: - Modalities: embedding models, language models and vision-language models* - Architectures: encoder-only, decoder-only, mixture-of-experts @@ -25,7 +25,7 @@ Currently, the Transformers backend works for the following: _*Vision-language models currently accept only image inputs. 
Support for video inputs will be added in a future release._ -If the Transformers model implementation follows all the steps in [writing a custom model](#writing-custom-models) then, when used with the Transformers backend, it will be compatible with the following features of vLLM: +If the Transformers model implementation follows all the steps in [writing a custom model](#writing-custom-models) then, when used with the Transformers modeling backend, it will be compatible with the following features of vLLM: - All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature) - Any combination of the following vLLM parallelisation schemes: @@ -44,7 +44,7 @@ llm.apply_model(lambda model: print(type(model))) If the printed type starts with `Transformers...` then it's using the Transformers model implementation! -If a model has a vLLM implementation but you would prefer to use the Transformers implementation via the Transformers backend, set `model_impl="transformers"` for [offline inference](../serving/offline_inference.md) or `--model-impl transformers` for the [online serving](../serving/openai_compatible_server.md). +If a model has a vLLM implementation but you would prefer to use the Transformers implementation via the Transformers modeling backend, set `model_impl="transformers"` for [offline inference](../serving/offline_inference.md) or `--model-impl transformers` for the [online serving](../serving/openai_compatible_server.md). !!! note For vision-language models, if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance. @@ -53,12 +53,12 @@ If a model has a vLLM implementation but you would prefer to use the Transformer If a model is neither supported natively by vLLM nor Transformers, it can still be used in vLLM! 
-For a model to be compatible with the Transformers backend for vLLM it must: +For a model to be compatible with the Transformers modeling backend for vLLM it must: - be a Transformers compatible custom model (see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)): - The model directory must have the correct structure (e.g. `config.json` is present). - `config.json` must contain `auto_map.AutoModel`. -- be a Transformers backend for vLLM compatible model (see [Writing custom models](#writing-custom-models)): +- be a Transformers modeling backend for vLLM compatible model (see [Writing custom models](#writing-custom-models)): - Customisation should be done in the base model (e.g. in `MyModel`, not `MyModelForCausalLM`). If the compatible model is: @@ -66,13 +66,13 @@ If the compatible model is: - on the Hugging Face Model Hub, simply set `trust_remote_code=True` for [offline-inference](../serving/offline_inference.md) or `--trust-remote-code` for the [openai-compatible-server](../serving/openai_compatible_server.md). - in a local directory, simply pass directory path to `model=` for [offline-inference](../serving/offline_inference.md) or `vllm serve ` for the [openai-compatible-server](../serving/openai_compatible_server.md). -This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM! +This means that, with the Transformers modeling backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM! #### Writing custom models -This section details the necessary modifications to make to a Transformers compatible custom model that make it compatible with the Transformers backend for vLLM. (We assume that a Transformers compatible custom model has already been created, see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)). 
+This section details the necessary modifications to make to a Transformers compatible custom model that make it compatible with the Transformers modeling backend for vLLM. (We assume that a Transformers compatible custom model has already been created, see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)). -To make your model compatible with the Transformers backend, it needs: +To make your model compatible with the Transformers modeling backend, it needs: 1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`. - If your model is encoder-only: @@ -134,7 +134,7 @@ Here is what happens in the background when this model is loaded: 1. The config is loaded. 2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`. -3. `MyModel` is loaded into one of the Transformers backend classes in [vllm/model_executor/models/transformers](../../vllm/model_executor/models/transformers) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. +3. `MyModel` is loaded into one of the Transformers modeling backend classes in [vllm/model_executor/models/transformers](../../vllm/model_executor/models/transformers) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. That's it! @@ -182,7 +182,7 @@ To determine whether a given model is natively supported, you can check the `con If the `"architectures"` field contains a model architecture listed below, then it should be natively supported. Models do not _need_ to be natively supported to be used in vLLM. -The [Transformers backend](#transformers) enables you to run models directly using their Transformers implementation (or even remote code on the Hugging Face Model Hub!). 
+The [Transformers modeling backend](#transformers) enables you to run models directly using their Transformers implementation (or even remote code on the Hugging Face Model Hub!). !!! tip The easiest way to check if your model is really supported at runtime is to run the program below: @@ -351,6 +351,7 @@ th { | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|-------------------|----------------------|---------------------------| +| `AfmoeForCausalLM` | Afmoe | TBA | ✅︎ | ✅︎ | | `ApertusForCausalLM` | Apertus | `swiss-ai/Apertus-8B-2509`, `swiss-ai/Apertus-70B-Instruct-2509`, etc. | ✅︎ | ✅︎ | | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | | `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | @@ -451,7 +452,7 @@ th { | `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | | `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ | ✅︎ | -Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it! +Some models are supported only via the [Transformers modeling backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers modeling backend is being used, and you will see no warning that this is fallback behaviour. 
This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it! | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|-------------------|----------------------|---------------------------| @@ -669,7 +670,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I+ | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ | | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | -| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | +| `Gemma3ForConditionalGeneration` | Gemma 3 | T + IE+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | | `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. 
| ✅︎ | ✅︎ | @@ -684,7 +685,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | | `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | | `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I+ | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ | -| `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | +| `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ | | `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + IE+ | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | | `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + IE+ | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | | `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + IE+ | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | @@ -720,7 +721,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | | `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | -Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. 
The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it! +Some models are supported only via the [Transformers modeling backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers modeling backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it! | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------| @@ -785,6 +786,9 @@ Speech2Text models trained specifically for Automatic Speech Recognition. | `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | | `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-speech-3.3-2b`, `ibm-granite/granite-speech-3.3-8b`, etc. | ✅︎ | ✅︎ | +!!! note + `VoxtralForConditionalGeneration` requires `mistral-common[audio]` to be installed. + ### Pooling Models See [this page](./pooling_models.md) for more information on how to use pooling models. diff --git a/docs/usage/README.md b/docs/usage/README.md index 0c63d01f0f99f..4e8ece2c06052 100644 --- a/docs/usage/README.md +++ b/docs/usage/README.md @@ -1,6 +1,6 @@ # Using vLLM -First, vLLM must be [installed](../getting_started/installation/) for your chosen device in either a Python or Docker environment. 
+First, vLLM must be [installed](../getting_started/installation/README.md) for your chosen device in either a Python or Docker environment. Then, vLLM supports the following usage patterns: diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 53d69bbdbdc7d..04e6f99f8957e 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -43,6 +43,7 @@ class ModelRequestData(NamedTuple): # Voxtral +# Make sure to install mistral-common[audio]. def run_voxtral(question: str, audio_count: int) -> ModelRequestData: from mistral_common.audio import Audio from mistral_common.protocol.instruct.chunk import ( diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 371cf6309a678..624de2a2debc3 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1536,7 +1536,7 @@ def run_qwen2_5_omni(questions: list[str], modality: str): mm_processor_kwargs={ "min_pixels": 28 * 28, "max_pixels": 1280 * 28 * 28, - "fps": [1], + "fps": 1, }, limit_mm_per_prompt={modality: 1}, ) diff --git a/examples/online_serving/openai_chat_completion_client_with_tools.py b/examples/online_serving/openai_chat_completion_client_with_tools.py index 41dbb3236297c..0bd1d05322f81 100644 --- a/examples/online_serving/openai_chat_completion_client_with_tools.py +++ b/examples/online_serving/openai_chat_completion_client_with_tools.py @@ -161,6 +161,7 @@ def main(): { "role": "assistant", "tool_calls": chat_completion.choices[0].message.tool_calls, + "reasoning": chat_completion.choices[0].message.reasoning, } ) diff --git a/examples/online_serving/token_generation_client.py b/examples/online_serving/token_generation_client.py new file mode 100644 index 0000000000000..88ee43c5d9cdf --- /dev/null +++ b/examples/online_serving/token_generation_client.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import httpx +from transformers import AutoTokenizer + +GEN_ENDPOINT = "http://localhost:8000/inference/v1/generate" +DUMMY_API_KEY = "empty" +MODEL_NAME = "Qwen/Qwen3-0.6B" + +transport = httpx.HTTPTransport() +headers = {"Authorization": f"Bearer {DUMMY_API_KEY}"} +client = httpx.Client( + transport=transport, + base_url=GEN_ENDPOINT, + timeout=600, + headers=headers, +) +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "How many countries are in the EU?"}, +] + + +def main(client): + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + token_ids = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + enable_thinking=False, + ) + payload = { + "model": MODEL_NAME, + "token_ids": token_ids, + "sampling_params": {"max_tokens": 24, "temperature": 0.2, "detokenize": False}, + "stream": False, + } + resp = client.post(GEN_ENDPOINT, json=payload) + resp.raise_for_status() + data = resp.json() + print(data) + print("-" * 50) + print("Token generation results:") + res = tokenizer.decode(data["choices"][0]["token_ids"]) + print(res) + print("-" * 50) + + +if __name__ == "__main__": + main(client) diff --git a/requirements/common.txt b/requirements/common.txt index 90efb79a845d3..ad92ba3ad8278 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs pyzmq >= 25.0.0 msgspec gguf >= 0.13.0 -mistral_common[image,audio] >= 1.8.5 +mistral_common[image] >= 1.8.5 opencv-python-headless >= 4.11.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 605ce73bff9ce..d11787df4d92b 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -8,7 +8,7 @@ packaging>=24.2 
setuptools>=77.0.3,<81.0.0 --extra-index-url https://download.pytorch.org/whl/cpu torch==2.8.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x" -torch==2.8.0; platform_system == "Darwin" +torch==2.9.0; platform_system == "Darwin" torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 76874cbd2f482..d63fe9e1e77c1 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -9,6 +9,6 @@ torch==2.9.0 torchaudio==2.9.0 # These must be updated alongside torch torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -xformers==0.0.33; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9 +xformers==0.0.33.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9 # FlashInfer should be updated together with the Dockerfile flashinfer-python==0.5.2 diff --git a/requirements/docs.txt b/requirements/docs.txt index 0fd6dbe22c512..32e004b2b64ba 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -10,3 +10,7 @@ mkdocs-minify-plugin regex ruff pydantic + +# For generating argparse docs. +# Adding requirements here should only be used as a last resort. 
+msgspec # Needed for multiple inheritance involving msgspec.Struct \ No newline at end of file diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index 51f58e57a7851..b977e80be067f 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -13,5 +13,5 @@ setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 wheel jinja2>=3.1.6 -amdsmi==6.2.4 +amdsmi==6.4.3 timm>=1.0.17 diff --git a/setup.py b/setup.py index 0934a8608eb12..e9b36e2a2e037 100644 --- a/setup.py +++ b/setup.py @@ -545,7 +545,9 @@ def get_vllm_version() -> str: # Allow overriding the version. This is useful to build platform-specific # wheels (e.g. CPU, TPU) without modifying the source. if env_version := os.getenv("VLLM_VERSION_OVERRIDE"): - return env_version + print(f"Overriding VLLM version with {env_version} from VLLM_VERSION_OVERRIDE") + os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"] = env_version + return get_version(write_to="vllm/_version.py") version = get_version(write_to="vllm/_version.py") sep = "+" if "+" not in version else "." # dev versions might contain + diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index 64d626bae483d..6d3788af9de0d 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -22,6 +22,8 @@ from vllm.config import ( from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils.torch_utils import is_torch_equal_or_newer +from ...utils import create_new_process_for_each_test + # This import automatically registers `torch.ops.silly.attention` from ..
import silly_attention # noqa: F401 @@ -193,7 +195,14 @@ def run_model( @pytest.mark.parametrize("use_inductor_graph_partition", [False, True]) -def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool): +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) +@create_new_process_for_each_test("spawn") +def test_multi_graph_piecewise_compile( + use_inductor_graph_partition: bool, use_bytecode_hook: bool, monkeypatch +): + # Set the environment variable for this test + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index a48af8a8952ad..e258133ab50a7 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -21,6 +21,8 @@ from vllm.config import ( from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils.torch_utils import is_torch_equal_or_newer +from ...utils import create_new_process_for_each_test + # This import automatically registers `torch.ops.silly.attention` from ..silly_attention import get_global_counter, reset_global_counter @@ -124,6 +126,7 @@ def _run_simple_model( @pytest.mark.parametrize("use_inductor", [True, False]) @torch.inference_mode() +@create_new_process_for_each_test("spawn") def test_simple_piecewise_compile(use_inductor): _run_simple_model( splitting_ops=["silly::attention"], diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 92998ede16992..915fbc6ce7f39 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -29,6 +29,8 @@ from vllm.config import ( from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils.torch_utils import 
is_torch_equal_or_newer +from ...utils import create_new_process_for_each_test + # This import automatically registers `torch.ops.silly.attention` from .. import silly_attention # noqa: F401 @@ -334,6 +336,7 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor: ("inductor", True), # Inductor, Inductor partition ], ) +@create_new_process_for_each_test("spawn") def test_toy_llama( backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path ): @@ -513,4 +516,8 @@ def benchmark(): if __name__ == "__main__": - benchmark() + # Protect against subprocess reimport when using spawn_new_process_for_each_test + import os + + if os.environ.get("RUNNING_IN_SUBPROCESS") != "1": + benchmark() diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index bb66ef5529b12..1e8a882a7f3eb 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -15,6 +15,9 @@ from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.utils.torch_utils import _is_torch_equal_or_newer +# This import automatically registers `torch.ops.silly.attention` +from . 
import silly_attention # noqa: F401 + def test_version(): # Test the version comparison logic using the private function @@ -257,15 +260,6 @@ def test_should_split(): splitting_ops = ["aten::add.Tensor"] assert not should_split(node, splitting_ops) - @torch.library.custom_op( - "silly::attention", - mutates_args=["out"], - ) - def attention( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor - ) -> None: - out.copy_(q + k + v) - q, k, v, out = [torch.randn(1)] * 4 # supports custom ops as OpOverloadPacket diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index e1560efb3f247..f22d60ef000b2 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -20,13 +20,22 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer from ..utils import flat_product, multi_gpu_test +is_blackwell = lambda: current_platform.is_device_capability(100) +"""Are we running on Blackwell, a lot of tests depend on it""" + + +class Matches(NamedTuple): + attention_fusion: int = 0 + allreduce_fusion: int = 0 + sequence_parallel: int = 0 + async_tp: int = 0 + class ModelBackendTestCase(NamedTuple): model_name: str model_kwargs: dict[str, Any] backend: AttentionBackendEnum - attention_fusions: int - allreduce_fusions: int | None = None + matches: Matches MODELS_FP8: list[ModelBackendTestCase] = [] @@ -38,17 +47,33 @@ if current_platform.is_cuda(): ModelBackendTestCase( # Use smaller model for L40s in CI model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", - model_kwargs=dict(max_model_len=1024), - backend=AttentionBackendEnum.TRITON_ATTN, - attention_fusions=32, - allreduce_fusions=65, + # TODO while llama4 is broken, use FLASHINFER for llama3 on Blackwell + # so FI attention+fp8_quant is at least tested once + model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), + backend=AttentionBackendEnum.FLASHINFER + if is_blackwell() + else AttentionBackendEnum.TRITON_ATTN, + matches=Matches( + 
attention_fusion=32, + allreduce_fusion=65, + sequence_parallel=65, + async_tp=128, + ), ), ModelBackendTestCase( model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), - backend=AttentionBackendEnum.FLASHINFER, - attention_fusions=48, - allreduce_fusions=96, + # TODO FlashInfer attn broken on Hopper with kvcache=fp8: + # https://github.com/vllm-project/vllm/issues/28568 + # TODO FlashInfer attn broken on Blackwell for llama4: + # https://github.com/vllm-project/vllm/issues/28604 + backend=AttentionBackendEnum.TRITON_ATTN, + matches=Matches( + attention_fusion=48, + allreduce_fusion=96, + sequence_parallel=96, + async_tp=95, # mlp is moe, no fusion there + ), ), ] @@ -57,8 +82,12 @@ if current_platform.is_cuda(): model_name="nvidia/Llama-3.1-8B-Instruct-FP4", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), backend=AttentionBackendEnum.FLASHINFER, - attention_fusions=32, - allreduce_fusions=65, + matches=Matches( + attention_fusion=32, + allreduce_fusion=65, + sequence_parallel=65, + async_tp=128, + ), ), ] @@ -68,15 +97,23 @@ if current_platform.is_cuda(): model_name="meta-llama/Llama-3.1-8B-Instruct", model_kwargs=dict(max_model_len=1024), backend=AttentionBackendEnum.TRITON_ATTN, - attention_fusions=0, - allreduce_fusions=65, + matches=Matches( + attention_fusion=0, + allreduce_fusion=65, + sequence_parallel=65, + async_tp=128, + ), ), ModelBackendTestCase( model_name="Qwen/Qwen3-30B-A3B", model_kwargs=dict(max_model_len=1024), backend=AttentionBackendEnum.TRITON_ATTN, - attention_fusions=0, - allreduce_fusions=97, + matches=Matches( + attention_fusion=0, + allreduce_fusion=97, + sequence_parallel=97, + async_tp=96, # MLP is MoE, half the fusions of dense + ), ), ] @@ -86,19 +123,19 @@ elif current_platform.is_rocm(): model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), backend=AttentionBackendEnum.TRITON_ATTN, - attention_fusions=32, + 
matches=Matches(attention_fusion=32), ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), backend=AttentionBackendEnum.ROCM_ATTN, - attention_fusions=32, + matches=Matches(attention_fusion=32), ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, - attention_fusions=32, + matches=Matches(attention_fusion=32), ), ] @@ -106,8 +143,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"] @pytest.mark.parametrize( - "model_name, model_kwargs, backend, " - "attention_fusions, allreduce_fusions, custom_ops", + "model_name, model_kwargs, backend, matches, custom_ops", # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) # quant_fp4 only has the custom impl @@ -118,15 +154,14 @@ def test_attn_quant( model_name: str, model_kwargs: dict[str, Any], backend: AttentionBackendEnum, - attention_fusions: int, - allreduce_fusions: int, + matches: Matches, custom_ops: str, inductor_graph_partition: bool, caplog_mp_spawn, monkeypatch, ): if backend == AttentionBackendEnum.FLASHINFER and ( - not current_platform.is_device_capability((10, 0)) or not has_flashinfer() + not is_blackwell() or not has_flashinfer() ): pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): @@ -169,12 +204,12 @@ def test_attn_quant( with caplog_mp_spawn(logging.DEBUG) as log_holder: run_model(compilation_config, model_name, **model_kwargs) - matches = re.findall( + log_matches = re.findall( r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) - assert len(matches) == 1, log_holder.text - assert int(matches[0]) == attention_fusions + assert len(log_matches) == 1, log_holder.text + assert int(log_matches[0]) == matches.attention_fusion CUSTOM_OPS_RMS_NORM = 
["-rms_norm", "+rms_norm"] @@ -187,8 +222,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "model_name, model_kwargs, backend, " - "attention_fusions, allreduce_fusions, custom_ops", + "model_name, model_kwargs, backend, matches, custom_ops", # Toggle RMSNorm and QuantFP8 for FP8 models list( flat_product( @@ -209,8 +243,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm( model_name: str, model_kwargs: dict, backend: AttentionBackendEnum, - attention_fusions: int, - allreduce_fusions: int, + matches: Matches, custom_ops: str, inductor_graph_partition: bool, caplog_mp_spawn, @@ -219,6 +252,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm( if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): pytest.skip("Inductor graph partition requires torch>=2.9") + if "fp4" in model_name.lower() and not is_blackwell(): + pytest.skip("NVFP4 quant requires Blackwell") + + if backend == AttentionBackendEnum.FLASHINFER and not is_blackwell(): + # FlashInfer attn fusion requires Blackwell + matches = matches._replace(attention_fusion=0) + custom_ops_list = custom_ops.split(",") if custom_ops else [] if inductor_graph_partition: @@ -258,23 +298,135 @@ def test_tp2_attn_quant_allreduce_rmsnorm( run_model( compilation_config, model_name, tensor_parallel_size=2, **model_kwargs ) - matches = re.findall( + log_matches = re.findall( r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) - assert len(matches) == 2, log_holder.text + assert len(log_matches) == 2, log_holder.text - assert int(matches[0]) == attention_fusions - assert int(matches[1]) == attention_fusions + assert int(log_matches[0]) == matches.attention_fusion + assert int(log_matches[1]) == matches.attention_fusion - matches = re.findall( + log_matches = re.findall( r"collective_fusion.py:\d+] Replaced (\d+) patterns", log_holder.text, ) - assert len(matches) == 2, log_holder.text + assert 
len(log_matches) == 2, log_holder.text - assert int(matches[0]) == allreduce_fusions - assert int(matches[1]) == allreduce_fusions + assert int(log_matches[0]) == matches.allreduce_fusion + assert int(log_matches[1]) == matches.allreduce_fusion + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, matches, custom_ops", + # Toggle RMSNorm and QuantFP8 for FP8 models + list( + flat_product( + MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) + ) + ) + # Toggle RMSNorm for FP4 models and unquant models + + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), +) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="sequence parallel only tested on CUDA", +) +def test_tp2_attn_quant_async_tp( + model_name: str, + model_kwargs: dict, + backend: AttentionBackendEnum, + matches: Matches, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + if is_blackwell(): + # TODO: https://github.com/vllm-project/vllm/issues/27893 + pytest.skip("Blackwell is not supported for AsyncTP pass") + + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + + if "fp4" in model_name.lower() and not is_blackwell(): + pytest.skip("NVFP4 quant requires Blackwell") + + if backend == AttentionBackendEnum.FLASHINFER: + if not has_flashinfer(): + pytest.skip("FlashInfer backend requires flashinfer installed") + if not is_blackwell(): + # FlashInfer attn fusion requires Blackwell + matches = matches._replace(attention_fusion=0) + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: list[str] | None = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. 
+ # Otherwise, we can't verify fusion happened through the logs. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + custom_ops=custom_ops_list, + splitting_ops=splitting_ops, + # Common + level=CompilationMode.VLLM_COMPILE, + pass_config=PassConfig( + enable_attn_fusion=True, + enable_noop=True, + enable_sequence_parallelism=True, + enable_async_tp=True, + ), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model( + compilation_config, model_name, tensor_parallel_size=2, **model_kwargs + ) + log_matches = re.findall( + r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", + log_holder.text, + ) + assert len(log_matches) == 2, log_holder.text + + assert int(log_matches[0]) == matches.attention_fusion + assert int(log_matches[1]) == matches.attention_fusion + + log_matches = re.findall( + r"sequence_parallelism.py:\d+] Replaced (\d+) patterns", + log_holder.text, + ) + assert len(log_matches) == 2, log_holder.text + + assert int(log_matches[0]) == matches.sequence_parallel + assert int(log_matches[1]) == matches.sequence_parallel + + log_matches = re.findall( + r"collective_fusion.py:\d+] Replaced (\d+) patterns", + log_holder.text, + ) + assert len(log_matches) == 2, log_holder.text + + assert int(log_matches[0]) == matches.async_tp + assert int(log_matches[1]) == matches.async_tp def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs): diff --git a/tests/compile/test_graph_partition.py 
b/tests/compile/test_graph_partition.py new file mode 100644 index 0000000000000..1cd783843a626 --- /dev/null +++ b/tests/compile/test_graph_partition.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import operator + +import pytest +import torch +from torch.fx.experimental.proxy_tensor import make_fx + +from vllm.compilation.backends import split_graph + + +def test_getitem_moved_to_producer_subgraph(): + """ + Test that getitem operations are moved to the same subgraph as their input, + preventing tuple inputs to submodules. + """ + + def model_fn(x: torch.Tensor) -> torch.Tensor: + # torch.split returns a tuple, creating real getitem operations + # Should become first submodule that produces tuple + chunks = torch.split(x, x.shape[0] // 2, dim=0) + + # Following ops should become second submodule that consumes tuple + result_0 = torch.relu(chunks[0]) + result_1 = torch.relu(chunks[1]) + return torch.cat([result_0, result_1], dim=0) + + x = torch.randn(4, 3) + gm = make_fx(model_fn)(x) + + has_getitem = any( + node.op == "call_function" and node.target == operator.getitem + for node in gm.graph.nodes + ) + assert has_getitem, "Test setup failed: graph should contain getitem operations" + + # Split on tuple producer aten::split + split_ops = ["aten::split.Tensor"] + split_gm, split_items = split_graph(gm, split_ops) + assert len(split_items) == 2, "Graph should be split into 2 submodules" + + for split_item in split_items: + submodule = split_item.graph + + getitem_on_placeholder = [] + for node in submodule.graph.nodes: + if ( + node.op == "call_function" + and node.target == operator.getitem + and node.args[0].op == "placeholder" + ): + getitem_on_placeholder.append(node) + + assert len(getitem_on_placeholder) == 0, ( + f"Submodule {split_item.submod_name} has getitem operations on " + f"placeholder nodes: {[n.name for n in getitem_on_placeholder]}. 
" + "This means tuple inputs were not properly eliminated." + ) + + new_x = torch.randn(4, 3) + output_original = gm(new_x) + output_split = split_gm(new_x) + + assert torch.allclose(output_original, output_split), "Output mismatch" + + +def test_no_tuple_inputs_with_multiple_consumers(): + """ + Test that when a tuple is consumed by multiple split operations, + getitem operations are properly moved to avoid tuple inputs. + """ + + def model_fn(x: torch.Tensor) -> torch.Tensor: + # torch.split returns a tuple, creating real getitem operations + # Should become first submodule that produces tuple + chunks = torch.split(x, x.shape[0] // 2, dim=0) + + # These should become second submodule consuming tuple + result_1 = torch.relu(chunks[0]) + result_2 = torch.relu(chunks[1]) + + # Artificial graph splitting point to create another + # independent submodule that consumes tuple later + # This would become the third submodule + result_1 = torch.sigmoid(result_1) + + # Fourth submodule that consumes tuple + result = torch.cat([chunks[0], chunks[1], result_1, result_2]) + return result + + x = torch.randn(4, 3) + gm = make_fx(model_fn)(x) + + has_getitem = any( + node.op == "call_function" and node.target == operator.getitem + for node in gm.graph.nodes + ) + assert has_getitem, "Test setup failed: graph should contain getitem operations" + + split_ops = ["aten::split.Tensor", "aten::sigmoid"] + split_gm, split_items = split_graph(gm, split_ops) + assert len(split_items) == 4, "Graph should be split into 4 submodules" + + for split_item in split_items: + submodule = split_item.graph + + for node in submodule.graph.nodes: + if ( + node.op == "call_function" + and node.target == operator.getitem + and node.args[0].op == "placeholder" + ): + pytest.fail( + f"Submodule {split_item.submod_name} has getitem on " + f"placeholder {node.args[0].name}, indicating it receives " + "a tuple input" + ) + + new_x = torch.randn(4, 3) + output_original = gm(new_x) + output_split = 
split_gm(new_x) + + assert torch.allclose(output_original, output_split), "Output mismatch after split" diff --git a/tests/compile/test_multimodal_compile.py b/tests/compile/test_multimodal_compile.py index b76c29819a2df..621f6a51a918f 100644 --- a/tests/compile/test_multimodal_compile.py +++ b/tests/compile/test_multimodal_compile.py @@ -10,8 +10,8 @@ from vllm.platforms import current_platform def test_compile(): vllm_config = VllmConfig() - # Default configuration compiles mm encoder - assert vllm_config.compilation_config.compile_mm_encoder + # Default configuration does not compile mm encoder + assert not vllm_config.compilation_config.compile_mm_encoder # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 @@ -39,7 +39,10 @@ def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch): "Qwen/Qwen2.5-VL-3B-Instruct", max_model_len=2048, gpu_memory_utilization=0.8, - compilation_config={"mode": CompilationMode.VLLM_COMPILE}, + compilation_config={ + "mode": CompilationMode.VLLM_COMPILE, + "compile_mm_encoder": True, + }, ) as _, ): pass diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index fc4d38c8f8374..fd6aeb6b44389 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -5,15 +5,15 @@ import pytest import torch import vllm.envs as envs -from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fusion import RMSNormQuantFusionPass -from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func +from vllm.compilation.fx_utils import find_auto_fn from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import ( CompilationConfig, + CUDAGraphMode, DeviceConfig, 
ModelConfig, PassConfig, @@ -30,6 +30,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables @@ -45,169 +46,157 @@ prompts = [ ] -class TestModel(torch.nn.Module): - def __init__(self, hidden_size=16, intermediate_size=32): +class TestAllReduceRMSNormModel(torch.nn.Module): + def __init__(self, hidden_size=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size)) - ) - self.norm = RMSNorm(intermediate_size, 1e-05) - # Initialize weights - torch.nn.init.normal_(self.gate_proj, std=0.02) + self.eps = eps + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] - def forward(self, hidden_states, residual): - """ - Forward pass implementing the operations in the FX graph + def forward(self, x): + z = torch.relu(x) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) - Args: - hidden_states: Input tensor - residual: Residual tensor from previous layer + z2 = torch.mm(y, self.w[0]) + x2 = tensor_model_parallel_all_reduce(z2) - Returns: - Tuple containing the output tensor - """ - # Reshape input - view = hidden_states.reshape(-1, self.hidden_size) + y2, resid = self.norm[1](x2, resid) - # matrix multiplication - permute = self.gate_proj.permute(1, 0) - mm = torch.mm(view, permute) + z3 = torch.mm(y2, self.w[1]) + x3 = tensor_model_parallel_all_reduce(z3) - # Tensor parallel all-reduce - all_reduce = tensor_model_parallel_all_reduce(mm) + y3, resid = self.norm[2](x3, resid) - # layer normalization - norm_output, residual_output = self.norm(all_reduce, 
residual) + z4 = torch.mm(y3, self.w[2]) + x4 = tensor_model_parallel_all_reduce(z4) - return norm_output, residual_output + y4, resid = self.norm[3](x4, resid) + return y4 def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] def ops_in_model_after(self): return [ - torch.ops.vllm.reduce_scatter.default, torch.ops.vllm.all_gather.default, + torch.ops.vllm.reduce_scatter.default, ] def ops_in_model(self): - return [torch.ops._C.fused_add_rms_norm.default] + if RMSNorm.enabled(): + return [ + torch.ops._C.rms_norm.default, + torch.ops._C.fused_add_rms_norm.default, + ] + else: + return [] -class TestQuantModel(torch.nn.Module): +class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): quant_key = kFp8StaticTensorSym - def __init__(self, hidden_size=16, intermediate_size=32): + def __init__(self, hidden_size=16, eps=1e-6): super().__init__() - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size self.vllm_config = get_current_vllm_config() - self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size)), requires_grad=False - ) - self.norm = RMSNorm(intermediate_size, 1e-05) - # Initialize weights - torch.nn.init.normal_(self.gate_proj, std=0.02) + self.hidden_size = hidden_size + self.eps = eps + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.w = [ + torch.rand(hidden_size, hidden_size) + .to(dtype=current_platform.fp8_dtype()) + .t() + for _ in range(3) + ] + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - self.scale = torch.rand(1, dtype=torch.float32) - # Create a weight that is compatible with torch._scaled_mm, - # which expects a column-major layout. 
- self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() - self.wscale = torch.rand(1, dtype=torch.float32) - self.fp8_linear = TestFP8Layer( - self.quant_key, self.quant_key, self.w, self.wscale, self.scale - ) - - def forward(self, hidden_states, residual): - """ - Forward pass implementing the operations in the FX graph - - Args: - hidden_states: Input tensor - residual: Residual tensor from previous layer - - Returns: - Tuple containing the output tensor - """ - # Reshape input - view = hidden_states.reshape(-1, self.hidden_size) - - # matrix multiplication - permute = self.gate_proj.permute(1, 0) - mm = torch.mm(view, permute) - - # Tensor parallel all-reduce - all_reduce = tensor_model_parallel_all_reduce(mm) - - # layer normalization - norm_output, residual_output = self.norm(all_reduce, residual) - # scaled_mm with static input quantization - fp8_linear_result = self.fp8_linear(norm_output) - - return fp8_linear_result, residual_output - - def ops_in_model_before(self): - ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP - # The following are only removed if fusion happens - if ( - self.vllm_config - and self.vllm_config.compilation_config.pass_config.enable_fusion - ): - ops_to_remove.extend( - [ - torch.ops._C.fused_add_rms_norm.default, - torch.ops._C.static_scaled_fp8_quant.default, - ] + self.fp8_linear_layers = [ + TestFP8Layer( + self.quant_key, self.quant_key, self.w[i], self.wscale[i], self.scale[i] ) - return ops_to_remove + for i in range(3) + ] + + + def forward(self, hidden_states): + # avoid having graph input be an arg to a pattern directly + z = torch.relu(hidden_states) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) + + z2 = self.fp8_linear_layers[0](y) + + x2 = tensor_model_parallel_all_reduce(z2) + y2, resid = self.norm[1](x2, resid) + + z3 = self.fp8_linear_layers[1](y2) + + x3 = tensor_model_parallel_all_reduce(z3) + y3, resid = self.norm[2](x3, resid) # use resid 
here + + z4 = self.fp8_linear_layers[2](y3) + x4 = tensor_model_parallel_all_reduce(z4) + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): - ops_to_add = [ - torch.ops.vllm.reduce_scatter.default, + return [ torch.ops.vllm.all_gather.default, + torch.ops.vllm.reduce_scatter.default, + ] + + def ops_in_model_before(self): + return [ + torch.ops.vllm.all_reduce.default, ] - # The following is only added if fusion happens - if ( - self.vllm_config - and self.vllm_config.compilation_config.pass_config.enable_fusion - ): - ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) - return ops_to_add def ops_in_model(self): - if ( - self.vllm_config - and self.vllm_config.compilation_config.pass_config.enable_fusion - ): - # If fusion happens, the fused op is the one - # we check for (de)functionalization + if self.vllm_config.compilation_config.pass_config.enable_fusion: return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] - else: - # If no fusion, the original ops are checked + elif RMSNorm.enabled(): return [ torch.ops._C.fused_add_rms_norm.default, - # TODO functionalization pass does not handle this yet - # torch.ops._C.static_scaled_fp8_quant.default, ] + elif any( + layer.is_quant_fp8_enabled() for layer in self.fp8_linear_layers + ): + return [ + torch.ops._C.static_scaled_fp8_quant.default, + ] + else: + return [] @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel]) +@pytest.mark.parametrize( + "test_model_cls, custom_ops", + [ + (TestAllReduceRMSNormModel, "+rms_norm"), + (TestAllReduceRMSNormModel, "-rms_norm"), + (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"), + (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"), + (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"), + (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"), + ], +) @pytest.mark.parametrize("batch_size", [8]) 
@pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("enable_fusion", [True, False]) +@pytest.mark.parametrize("dynamic", [False, True]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") def test_sequence_parallelism_pass( test_model_cls: type[torch.nn.Module], + custom_ops: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype, enable_fusion: bool, + dynamic: bool, ): num_processes = 2 @@ -219,11 +208,13 @@ def test_sequence_parallelism_pass( args=( num_processes, test_model_cls, + custom_ops, batch_size, seq_len, hidden_size, dtype, enable_fusion, + dynamic, ), nprocs=nprocs, ) @@ -235,11 +226,13 @@ def sequence_parallelism_pass_on_test_model( local_rank: int, world_size: int, test_model_cls: type[torch.nn.Module], + custom_ops: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype, enable_fusion: bool, + dynamic: bool, ): current_platform.seed_everything(0) @@ -263,12 +256,16 @@ def sequence_parallelism_pass_on_test_model( initialize_model_parallel(tensor_model_parallel_size=world_size) # configure vllm config for SequenceParallelismPass + custom_ops_list = custom_ops.split(",") if custom_ops else [] compilation_config = CompilationConfig( + splitting_ops=[], # avoid automatic rms_norm enablement + cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings + custom_ops=custom_ops_list, pass_config=PassConfig( enable_sequence_parallelism=True, enable_fusion=enable_fusion, enable_noop=True, - ) + ), ) # NoOp needed for fusion device_config = DeviceConfig(device=torch.device("cuda")) @@ -288,7 +285,6 @@ def sequence_parallelism_pass_on_test_model( with set_current_vllm_config(vllm_config): noop_pass = NoOpEliminationPass(vllm_config) sequence_parallelism_pass = SequenceParallelismPass(vllm_config) - func_pass = FixFunctionalizationPass(vllm_config) cleanup_pass 
= PostCleanupPass(vllm_config) assert ( sequence_parallelism_pass.compilation_config.splitting_ops @@ -309,38 +305,29 @@ def sequence_parallelism_pass_on_test_model( passes_for_backend.append(cleanup_pass) - backend_no_func = TestBackend(*passes_for_backend) - backend_func = TestBackend(*passes_for_backend, func_pass) + backend = TestBackend(*passes_for_backend) - model = test_model_cls(hidden_size, hidden_size * 2) + model = test_model_cls(hidden_size) hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - compiled_model_no_func = torch.compile(model, backend=backend_no_func) - compiled_model_no_func(hidden_states, residual) - compiled_model_func = torch.compile(model, backend=backend_func) - compiled_model_func(hidden_states, residual) + if dynamic: + torch._dynamo.mark_dynamic(hidden_states, 0) - assert sequence_parallelism_pass.matched_count == 1 + compiled_model = torch.compile(model, backend=backend) + compiled_model(hidden_states) + + assert sequence_parallelism_pass.matched_count == 4 # In pre-nodes, all reduce should be there, # reduce scatter and all gather should not - backend_no_func.check_before_ops(model.ops_in_model_before()) + for op in model.ops_in_model_before(): + assert backend.op_count(op, before=True) == 4 # In post-nodes, reduce scatter and all gather should be there, # all reduce should not - backend_no_func.check_after_ops(model.ops_in_model_after()) + for op in model.ops_in_model_after(): + assert backend.op_count(op, before=False) == 4 - # check if the functionalization pass is applied for op in model.ops_in_model(): - find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None - - # make sure the ops were all de-functionalized - found = dict() - for node in backend_func.graph_post_pass.nodes: - for op in model.ops_in_model(): - if is_func(node, op): - found[op] = True 
- assert all(found[op] for op in model.ops_in_model()) + find_auto_fn(backend.graph_post_pass.nodes, op) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index da0afd9eaa49f..356cac7af258b 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -2,59 +2,134 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + +import pytest import torch -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import CompilationMode +from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper +from vllm.config import ( + CompilationConfig, + CompilationMode, + VllmConfig, + set_current_vllm_config, +) class MyMod(torch.nn.Module): def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None): - if cache is not None: - return x + cache - return x * 2 + if x.size()[0] >= 4: + return x * 2 + else: + return x * 100 -class MyWrapper(TorchCompileWrapperWithCustomDispatcher): +class MyWrapper(TorchCompileWithNoGuardsWrapper): def __init__(self, model): self.model = model - compiled_callable = torch.compile(self.forward, backend="eager") - super().__init__( - compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE + super().__init__() + + def forward(self, x: torch.Tensor): # type: ignore[override] + # this is the function to be compiled + return self.model(x) + + +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) +def test_torch_compile_wrapper(use_bytecode_hook, monkeypatch): + """Test basic functionality of TorchCompileWithNoGuardsWrapper.""" + # Set the environment variable for this test + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") + + # Create a proper vLLM config instead of mocking + vllm_config = VllmConfig() + vllm_config.compilation_config = CompilationConfig() + vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE + vllm_config.compilation_config.backend = 
"inductor" + + # Test DYNAMO_TRACE_ONCE + with set_current_vllm_config(vllm_config): + torch._dynamo.reset() + mod = MyMod() + wrapper = MyWrapper(mod) + + # First call should trigger compilation + x = torch.tensor([1, 2, 3, 4]) + torch._dynamo.mark_dynamic(x, 0) + + result1 = wrapper(x) + expected1 = torch.tensor([2, 4, 6, 8]) + assert torch.allclose(result1, expected1), ( + f"Expected {expected1}, got {result1}" ) - def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None): - # this is the function to be compiled - return self.model(x, cache) + # Second call should use compiled code + x2 = torch.tensor([1, 2, 3]) + result2 = wrapper(x2) + expected2 = torch.tensor([2, 4, 6]) + assert torch.allclose(result2, expected2), ( + f"Expected {expected2}, got {result2}" + ) - def __call__(self, x: torch.Tensor, cache: torch.Tensor | None = None): - # let torch.compile compile twice - if len(self.compiled_codes) == 2: - dispatch_id = 0 if cache is None else 1 - with self.dispatch_to_code(dispatch_id): - return self.forward(x, cache) - else: - return self.compiled_callable(x, cache) + # without the wrapper result would be different. + result3 = mod(x2) + expected3 = torch.tensor([100, 200, 300]) + assert torch.allclose(result3, expected3), ( + f"Expected {result3}, got {expected3}" + ) -def test_torch_compile_wrapper(): - mod = MyMod() - wrappers = [] - for i in range(3): - torch._dynamo.reset() + # with STOCK_TORCH_COMPILE we do not remove guards. 
+ vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE + torch._dynamo.reset() + with set_current_vllm_config(vllm_config): + mod = MyMod() wrapper = MyWrapper(mod) - wrappers.append(wrapper) - x = torch.tensor([1]) - wrapper(x, None) # profile run, compile - # create a cache tensor - cache = torch.tensor([2]) - wrapper(x, cache) # warm up with cache, recompile - # for new input, dispatch to the compiled code directly - new_x = torch.tensor([3]) - assert wrapper(new_x, None).item() == 6 # dispatch to the first compiled code - assert wrapper(new_x, cache).item() == 5 # dispatch to the second compiled code + # First call should trigger compilation + x = torch.tensor([1, 2, 3, 4]) + torch._dynamo.mark_dynamic(x, 0) - for wrapper in wrappers: - # make sure they have independent compiled codes - assert len(wrapper.compiled_codes) == 2 + result1 = wrapper(x) + expected1 = torch.tensor([2, 4, 6, 8]) + assert torch.allclose(result1, expected1), ( + f"Expected {expected1}, got {result1}" + ) + + # Second call should triger another compilation + x2 = torch.tensor([1, 2, 3]) + result2 = wrapper(x2) + expected2 = torch.tensor([100, 200, 300]) + assert torch.allclose(result2, expected2), ( + f"Expected {expected2}, got {result2}" + ) + + # NO_COMPILATION level not supported. 
+ vllm_config.compilation_config.mode = None + torch._dynamo.reset() + with set_current_vllm_config(vllm_config): + torch._dynamo.reset() + mod = MyMod() + + try: + wrapper = MyWrapper(mod) + except Exception: + return + raise AssertionError("expected an exception to be raised") + + +if __name__ == "__main__": + # Run with both parameter values + + class MockMonkeypatch: + def setenv(self, name, value): + os.environ[name] = value + + mp = MockMonkeypatch() + + print("Testing with VLLM_USE_BYTECODE_HOOK=False") + test_torch_compile_wrapper(False, mp) + + print("Testing with VLLM_USE_BYTECODE_HOOK=True") + test_torch_compile_wrapper(True, mp) + + print("All tests passed!") diff --git a/tests/conftest.py b/tests/conftest.py index 5e127e4e939e6..b17081352edcf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1384,3 +1384,16 @@ def image_urls(request, local_asset_server) -> list[str]: """Indirect fixture: takes a list of names, returns list of full URLs.""" names: list[str] = request.param return [local_asset_server.url_for(name) for name in names] + + +@pytest.fixture +def disable_deepgemm_ue8m0(monkeypatch): + from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used + + with monkeypatch.context() as monkeypatch_ctx: + monkeypatch_ctx.setenv("VLLM_USE_DEEP_GEMM_E8M0", "0") + is_deep_gemm_e8m0_used.cache_clear() + yield + # Clear cache so the next time it is used it is processed with the + # default VLLM_USE_DEEP_GEMM_E8M0 setting. 
+ is_deep_gemm_e8m0_used.cache_clear() diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 3576efca591cf..b16fd0d06b145 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -39,6 +39,7 @@ class ParallelSetup(NamedTuple): class CPTestOptions(NamedTuple): multi_node_only: bool load_format: str | None = None + attn_backend: str | None = None @dataclass @@ -58,6 +59,7 @@ class CPTestSettings: multi_node_only: bool = False, runner: RunnerOption = "auto", load_format: str | None = None, + attn_backend: str | None = None, ): parallel_setups = [] for eager_mode_val in [False]: @@ -79,7 +81,9 @@ class CPTestSettings: distributed_backends=["mp"], runner=runner, test_options=CPTestOptions( - multi_node_only=multi_node_only, load_format=load_format + multi_node_only=multi_node_only, + load_format=load_format, + attn_backend=attn_backend, ), ) @@ -117,7 +121,7 @@ def _compare_cp_with_tp( chunked_prefill, ) = parallel_setup - multi_node_only, load_format = test_options + multi_node_only, load_format, attn_backend = test_options model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_transformers_version(on_fail="skip") @@ -177,6 +181,13 @@ def _compare_cp_with_tp( if hf_overrides: common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + if not attn_backend: + cp_env = tp_env = {} + else: + cp_env = tp_env = { + "VLLM_ATTENTION_BACKEND": attn_backend, + } + cp_args = [ *common_args, "--tensor-parallel-size", @@ -205,6 +216,8 @@ def _compare_cp_with_tp( model_id, cp_args, tp_args, + cp_env, + tp_env, method=method, max_wait_seconds=720, ) diff --git a/tests/distributed/test_multiproc_executor.py b/tests/distributed/test_multiproc_executor.py new file mode 100644 index 0000000000000..e741a79bc4ed9 --- /dev/null +++ b/tests/distributed/test_multiproc_executor.py @@ -0,0 +1,437 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project + +""" +Integration tests for MultiprocExecutor at the executor level. +This test directly tests the executor without going through the LLM interface, +focusing on executor initialization, RPC calls, and distributed execution. +""" + +import multiprocessing +import os + +from tests.utils import multi_gpu_test +from vllm.config import VllmConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import get_open_port +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.executor.multiproc_executor import MultiprocExecutor + +MODEL = "facebook/opt-125m" + + +def create_vllm_config( + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: int = 256, + gpu_memory_utilization: float = 0.3, + distributed_executor_backend: str = "mp", + nnodes: int = 1, + node_rank: int = 0, + master_port: int = 0, +) -> VllmConfig: + """Create a VllmConfig for testing using EngineArgs.""" + engine_args = EngineArgs( + model=MODEL, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + ) + vllm_config = engine_args.create_engine_config() + + # Override distributed node settings if needed + if nnodes > 1 or node_rank > 0: + vllm_config.parallel_config.nnodes = nnodes + vllm_config.parallel_config.node_rank = node_rank + vllm_config.parallel_config.master_port = master_port + if nnodes > 1: + vllm_config.parallel_config.disable_custom_all_reduce = True + + return vllm_config + + +def create_test_scheduler_output(num_requests: int = 1) -> SchedulerOutput: + """Create a minimal SchedulerOutput for testing.""" + # This is a simplified version - in practice you'd need proper + # SchedulerOutput construction based on the actual vLLM v1 API + return SchedulerOutput( + scheduled_new_reqs=[], + 
scheduled_resumed_reqs=[], + scheduled_running_reqs=[], + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + ) + + +def test_multiproc_executor_initialization(): + """Test that MultiprocExecutor can be initialized with proper config.""" + vllm_config = create_vllm_config( + tensor_parallel_size=1, + pipeline_parallel_size=1, + ) + + # Create executor - this should initialize workers + executor = MultiprocExecutor(vllm_config=vllm_config) + + # Verify executor properties + assert executor.world_size == 1, "World size should be 1 for single GPU" + assert executor.local_world_size == 1, "Local world size should be 1" + assert hasattr(executor, "workers"), "Executor should have workers" + assert len(executor.workers) == 1, "Should have 1 worker for single GPU" + + # Clean up + executor.shutdown() + + +@multi_gpu_test(num_gpus=2) +def test_multiproc_executor_initialization_tensor_parallel(): + """Test MultiprocExecutor initialization with tensor parallelism.""" + vllm_config = create_vllm_config( + tensor_parallel_size=2, + pipeline_parallel_size=1, + ) + + # Create executor + executor = MultiprocExecutor(vllm_config=vllm_config) + + # Verify executor properties + assert executor.world_size == 2, "World size should be 2 for TP=2" + assert executor.local_world_size == 2, "Local world size should be 2" + assert len(executor.workers) == 2, "Should have 2 workers for TP=2" + + # Verify output rank calculation + output_rank = executor._get_output_rank() + assert output_rank == 0, "Output rank should be 0 for TP=2, PP=1" + + # Clean up + executor.shutdown() + + +@multi_gpu_test(num_gpus=2) +def test_multiproc_executor_collective_rpc(): + """Test collective RPC calls to all workers.""" + vllm_config = create_vllm_config( + tensor_parallel_size=2, + pipeline_parallel_size=1, + ) + + # Create executor + executor = MultiprocExecutor(vllm_config=vllm_config) + + try: + # Test check_health RPC - should work without errors + executor.check_health() + + # Test that RPC works 
correctly + # Note: We're just testing that the RPC mechanism works, + # not testing actual model execution here + assert not executor.is_failed, "Executor should not be in failed state" + + finally: + # Clean up + executor.shutdown() + + +def test_multiproc_executor_failure_callback(): + """Test failure callback registration and invocation.""" + vllm_config = create_vllm_config( + tensor_parallel_size=1, + pipeline_parallel_size=1, + ) + + executor = MultiprocExecutor(vllm_config=vllm_config) + + try: + # Test callback registration + callback_invoked = [] + + def test_callback(): + callback_invoked.append(True) + + # Register callback + executor.register_failure_callback(test_callback) + + # Callback should not be invoked yet + assert len(callback_invoked) == 0, "Callback should not be invoked immediately" + + # Simulate failure + executor.is_failed = True + + # Register another callback - should be invoked immediately + executor.register_failure_callback(test_callback) + assert len(callback_invoked) == 1, ( + "Callback should be invoked when executor is failed" + ) + + finally: + # Clean up + executor.shutdown() + + +@multi_gpu_test(num_gpus=2) +def test_multiproc_executor_worker_monitor(): + """Test that worker monitor is set up correctly.""" + vllm_config = create_vllm_config( + tensor_parallel_size=2, + pipeline_parallel_size=1, + ) + + executor = MultiprocExecutor(vllm_config=vllm_config) + + try: + # Verify all worker processes are alive + for worker in executor.workers: + assert worker.proc.is_alive(), f"Worker rank {worker.rank} should be alive" + + # Verify executor is not in failed state + assert not executor.is_failed, "Executor should not be in failed state" + + finally: + # Clean up + executor.shutdown() + + # After shutdown, workers should be terminated + import time + + time.sleep(0.5) # Give processes time to terminate + for worker in executor.workers: + assert not worker.proc.is_alive(), ( + f"Worker rank {worker.rank} should terminate after 
shutdown" + ) + + +@multi_gpu_test(num_gpus=2) +def test_multiproc_executor_get_response_message_queues(): + """Test message queue retrieval for different ranks.""" + vllm_config = create_vllm_config( + tensor_parallel_size=2, + pipeline_parallel_size=1, + ) + + executor = MultiprocExecutor(vllm_config=vllm_config) + + try: + # Get all message queues + all_queues = executor.get_response_mqs() + assert len(all_queues) == 2, "Should have 2 message queues for 2 workers" + + # Get message queue for specific rank + rank0_queue = executor.get_response_mqs(unique_reply_rank=0) + assert len(rank0_queue) == 1, "Should have 1 message queue for rank 0" + + rank1_queue = executor.get_response_mqs(unique_reply_rank=1) + assert len(rank1_queue) == 1, "Should have 1 message queue for rank 1" + + finally: + # Clean up + executor.shutdown() + + +def test_multiproc_executor_shutdown_cleanup(): + """Test that shutdown properly cleans up resources.""" + vllm_config = create_vllm_config( + tensor_parallel_size=1, + pipeline_parallel_size=1, + ) + + executor = MultiprocExecutor(vllm_config=vllm_config) + + # Verify executor is set up + assert hasattr(executor, "workers"), "Executor should have workers" + assert len(executor.workers) > 0, "Should have at least one worker" + + # Shutdown + executor.shutdown() + + # Verify cleanup + import time + + time.sleep(0.5) # Give processes time to terminate + + for worker in executor.workers: + assert not worker.proc.is_alive(), "Worker processes should be terminated" + + # Verify shutdown event is set + assert executor.shutdown_event.is_set(), "Shutdown event should be set" + + # Multiple shutdowns should be safe (idempotent) + executor.shutdown() + executor.shutdown() + + +@multi_gpu_test(num_gpus=4) +def test_multiproc_executor_pipeline_parallel(): + """Test MultiprocExecutor with pipeline parallelism.""" + vllm_config = create_vllm_config( + tensor_parallel_size=2, + pipeline_parallel_size=2, + ) + + executor = 
MultiprocExecutor(vllm_config=vllm_config) + + try: + # Verify executor properties + assert executor.world_size == 4, "World size should be 4 for TP=2, PP=2" + assert len(executor.workers) == 4, "Should have 4 workers" + + # Verify output rank calculation + # For TP=2, PP=2: output should be from the last PP stage (ranks 2-3) + # Specifically rank 2 (first rank of last PP stage) + output_rank = executor._get_output_rank() + assert output_rank == 2, "Output rank should be 2 (first rank of last PP stage)" + + # Verify max_concurrent_batches for pipeline parallel + assert executor.max_concurrent_batches == 2, ( + "Max concurrent batches should equal PP size" + ) + + finally: + # Clean up + executor.shutdown() + + +def test_multiproc_executor_properties(): + """Test various executor properties and configurations.""" + vllm_config = create_vllm_config( + tensor_parallel_size=1, + pipeline_parallel_size=1, + ) + + executor = MultiprocExecutor(vllm_config=vllm_config) + + try: + # Test supports_pp property + assert MultiprocExecutor.supports_pp is True, ( + "MultiprocExecutor should support pipeline parallelism" + ) + + # Test world_size calculation + assert executor.world_size == ( + executor.parallel_config.tensor_parallel_size + * executor.parallel_config.pipeline_parallel_size + ), "World size should equal TP * PP" + + # Test local_world_size calculation + assert executor.local_world_size == ( + executor.parallel_config.world_size // executor.parallel_config.nnodes + ), "Local world size should be world_size / nnodes" + + finally: + # Clean up + executor.shutdown() + + +@multi_gpu_test(num_gpus=4) +def test_multiproc_executor_multi_node(): + """ + Test MultiprocExecutor with multi-node configuration. 
+ This simulates 2 nodes with TP=4: + - Node 0 (rank 0): Uses GPUs 0,1 (CUDA_VISIBLE_DEVICES=0,1) with TP=2 + - Node 1 (rank 1): Uses GPUs 2,3 (CUDA_VISIBLE_DEVICES=2,3) with TP=2 + Total world_size = 4, nnodes = 2 + """ + port = get_open_port() + # symm_mem does not work for simulating multi instance in single node + os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" + + def run_node(node_rank: int, result_queue: multiprocessing.Queue, port: int): + """Run a single node's executor.""" + executor = None + try: + # Set CUDA_VISIBLE_DEVICES for this node + if node_rank == 0: + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" + else: + os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" + + # Create config for this node + vllm_config = create_vllm_config( + tensor_parallel_size=4, # Total TP across all nodes + pipeline_parallel_size=1, + nnodes=2, # 2 nodes + node_rank=node_rank, + master_port=port, # same port + ) + + # Create executor for this node + executor = MultiprocExecutor(vllm_config=vllm_config) + + # Verify node-specific properties + assert executor.world_size == 4, ( + f"World size should be 4 on node {node_rank}" + ) + assert executor.local_world_size == 2, ( + f"Local world size should be 2 on node {node_rank}" + ) + assert len(executor.workers) == 2, ( + f"Should have 2 local workers on node {node_rank}" + ) + + # Verify worker ranks are correct for this node + expected_ranks = [node_rank * 2, node_rank * 2 + 1] + actual_ranks = sorted([w.rank for w in executor.workers]) + assert actual_ranks == expected_ranks, ( + f"Node {node_rank} should have workers " + f"with ranks {expected_ranks}, got {actual_ranks}" + ) + # Verify all workers are alive + for worker in executor.workers: + assert worker.proc.is_alive(), ( + f"Worker rank {worker.rank} should be alive on node {node_rank}" + ) + # executor.gen + # Put success result in queue BEFORE shutdown to avoid hanging + result_queue.put({"node": node_rank, "success": True}) + import time + + time.sleep(2) + executor.shutdown() 
+ except Exception as e: + # Put failure result in queue + result_queue.put({"node": node_rank, "success": False, "error": str(e)}) + raise e + finally: + if executor is not None: + executor.shutdown() + + # Create a queue to collect results from both processes + result_queue: multiprocessing.Queue[dict[str, int | bool]] = multiprocessing.Queue() + + # Start both node processes + processes = [] + for node_rank in range(2): + p = multiprocessing.Process( + target=run_node, + args=(node_rank, result_queue, port), + name=f"Node{node_rank}", + ) + p.start() + processes.append(p) + + # Wait for both processes to complete + all_completed = True + for p in processes: + p.join(timeout=60) + if p.is_alive(): + p.terminate() + p.join(timeout=20) + if p.is_alive(): + p.kill() + p.join() + all_completed = False + + # Check results from both nodes + results: list[dict[str, int | bool]] = [] + while len(results) < 2: + try: + result = result_queue.get(timeout=1) + results.append(result) + except Exception: + pass + assert all_completed, "Not all processes completed successfully" + assert len(results) == 2, f"Expected 2 results, got {len(results)}" + assert results[0]["success"], f"Node 0 failed: {results[0]}" + assert results[1]["success"], f"Node 1 failed: {results[1]}" diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 94b2b51211a64..f38c509775ed5 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -18,6 +18,7 @@ import pytest from vllm.config.compilation import CompilationMode from vllm.config.model import RunnerOption from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils.torch_utils import is_torch_equal_or_newer from ..models.registry import HF_EXAMPLE_MODELS @@ -161,6 +162,7 @@ def _compare_sp( test_options: SPTestOptions, num_gpus_available: int, use_inductor_graph_partition: bool, + enable_async_tp: bool, *, 
method: Literal["generate", "encode"], is_multimodal: bool, @@ -244,10 +246,10 @@ def _compare_sp( compilation_config = { "mode": CompilationMode.VLLM_COMPILE, - "custom_ops": ["+rms_norm"], "compile_sizes": [4, 8], "pass_config": { "enable_sequence_parallelism": True, + "enable_async_tp": enable_async_tp, "enable_fusion": enable_fusion, "enable_noop": True, }, @@ -307,6 +309,7 @@ SP_TEST_MODELS = [ ], ) @pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) +@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP @create_new_process_for_each_test() def test_tp_sp_generation( model_id: str, @@ -316,10 +319,19 @@ def test_tp_sp_generation( test_options: SPTestOptions, num_gpus_available, use_inductor_graph_partition: bool, + enable_async_tp: bool, ): if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + # Skip FP8 SP-only test on sm89 (compute capability 8.9) + if ( + "fp8" in model_id.lower() + and current_platform.get_device_capability() < (9, 0) + and (not enable_async_tp) + ): + pytest.skip("FP8 reduction support begins with sm90 capable devices.") + _compare_sp( model_id, parallel_setup, @@ -328,6 +340,7 @@ def test_tp_sp_generation( test_options, num_gpus_available, use_inductor_graph_partition, + enable_async_tp=enable_async_tp, method="generate", is_multimodal=False, ) diff --git a/tests/entrypoints/openai/test_serving_tokens.py b/tests/entrypoints/openai/test_serving_tokens.py new file mode 100644 index 0000000000000..62d843e35b86f --- /dev/null +++ b/tests/entrypoints/openai/test_serving_tokens.py @@ -0,0 +1,262 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import httpx +import pytest +import pytest_asyncio +from transformers import AutoTokenizer + +from vllm.config import ModelConfig +from vllm.v1.engine.detokenizer import check_stop_strings + +from 
...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" +GEN_ENDPOINT = "/inference/v1/generate" + + +def get_vocab_size(model_name): + config = ModelConfig( + model=model_name, + seed=0, + dtype="bfloat16", + ) + return config.get_vocab_size() + + +@pytest.fixture(scope="module") +def tokenizer(): + return AutoTokenizer.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def messages(): + return [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "How many countries are in the EU?"}, + ] + + +@pytest.fixture(scope="module") +def server(request): + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "1024", + "--enforce-eager", + ] + + extra_args = getattr(request, "param", None) + if extra_args is not None: + args = args + ( + list(extra_args) + if isinstance(extra_args, (list, tuple)) + else [str(extra_args)] + ) + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server: RemoteOpenAIServer): + transport = httpx.AsyncHTTPTransport(uds=server.uds) if server.uds else None + headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"} + async with httpx.AsyncClient( + transport=transport, + base_url=server.url_root, + timeout=600, + headers=headers, + ) as c: + yield c + + +@pytest.mark.asyncio +async def test_generate_endpoint(client): + payload = { + "model": MODEL_NAME, + "token_ids": [1, 2, 3], + "sampling_params": {"max_tokens": 5}, + "stream": False, + } + resp = await client.post(GEN_ENDPOINT, json=payload) + resp.raise_for_status() + data = resp.json() + assert "choices" in data + + +@pytest.mark.asyncio +async def test_same_response_as_chat_completions(client, tokenizer, messages): + token_ids = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + enable_thinking=False, # default with Qwen3 + ) + for ignore_eos in [True, False]: + payload = { + "model": MODEL_NAME, + 
"token_ids": token_ids, + "sampling_params": { + "max_tokens": 24, + "temperature": 0.0, + # NOTE coordinator will set this to skip detokenization + "detokenize": False, + "ignore_eos": ignore_eos, + }, + "stream": False, + } + generate_resp = await client.post(GEN_ENDPOINT, json=payload) + generate_data = generate_resp.json() + generate_res = tokenizer.decode( + generate_data["choices"][0]["token_ids"], skip_special_tokens=True + ) + + payload = { + "model": MODEL_NAME, + "messages": messages, + "max_tokens": 24, + "temperature": 0.0, + "stream": False, + "ignore_eos": ignore_eos, + "chat_template_kwargs": dict(enable_thinking=False), + } + completions_resp = await client.post("/v1/chat/completions", json=payload) + completions_data = completions_resp.json() + completions_res = completions_data["choices"][0]["message"]["content"] + + assert generate_res == completions_res + + +@pytest.mark.asyncio +async def test_stop_string_workflow(client, tokenizer, messages): + token_ids = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + enable_thinking=False, # default with Qwen3 + ) + payload = { + "model": MODEL_NAME, + "token_ids": token_ids, + "sampling_params": { + "max_tokens": 24, + "temperature": 0.0, + "detokenize": False, + # stop strings are only supported when detokenize is True. 
+ "stop": ["27 member"], + }, + # TODO stream test is much more interesting + "stream": False, + } + with pytest.raises(httpx.HTTPStatusError): + generate_resp = await client.post(GEN_ENDPOINT, json=payload) + generate_resp.raise_for_status() + + payload["sampling_params"]["stop"] = None + generate_resp = await client.post( + GEN_ENDPOINT, json=payload, headers={"X-Request-Id": "42"} + ) + generate_data = generate_resp.json() + generate_res = tokenizer.decode( + generate_data["choices"][0]["token_ids"], skip_special_tokens=True + ) + + # NOTE This is under the responsibility of the coordinator + # stop_checker = StopChecker( + # max_model_len=1024, get_tokenizer_for_seq=lambda _: tokenizer + # ) + stop_str, truncate_to = check_stop_strings( + generate_res, len(generate_res), ["27 member"], False + ) + assert stop_str == "27 member" + # abort request that hit stop string (requires tokens-only mode) + # res = await client.post("/abort_requests", json={"request_ids": ["generate-tokens-42"]}) # noqa: E501 + # res.raise_for_status() + generate_res = generate_res[:truncate_to] + + # Get stop_str response from chat completions + payload = { + "model": MODEL_NAME, + "messages": messages, + "max_tokens": 24, + "temperature": 0.0, + "stream": False, + "stop": ["27 member"], + "chat_template_kwargs": dict(enable_thinking=False), + } + completions_resp = await client.post("/v1/chat/completions", json=payload) + completions_data = completions_resp.json() + completions_res = completions_data["choices"][0]["message"]["content"] + assert generate_res == completions_res + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + [ + "--enable-lora", + "--lora-modules", + "Alice=charent/self_cognition_Alice", + "Bob=charent/self_cognition_Bob", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + ] + ], + indirect=True, +) +async def test_generate_with_lora_adapter(client, tokenizer, messages): + # Verify adapters are listed + models_resp = await 
client.get("/v1/models") + models_resp.raise_for_status() + models = {m["id"] for m in models_resp.json().get("data", [])} + assert {"Alice", "Bob"}.issubset(models) + + # Generate using a LoRA adapter by specifying its name as the model + payload = { + "model": "Alice", + "token_ids": [1, 2, 3], + "sampling_params": {"max_tokens": 5}, + "stream": False, + } + resp = await client.post(GEN_ENDPOINT, json=payload) + resp.raise_for_status() + data = resp.json() + assert "choices" in data + + token_ids = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + enable_thinking=False, # default with Qwen3 + ) + payload = { + "model": "Alice", + "token_ids": token_ids, + "sampling_params": { + "max_tokens": 24, + "temperature": 0.0, + "detokenize": False, + }, + "stream": False, + } + generate_resp = await client.post(GEN_ENDPOINT, json=payload) + generate_data = generate_resp.json() + generate_res = tokenizer.decode( + generate_data["choices"][0]["token_ids"], skip_special_tokens=True + ) + + payload = { + "model": "Alice", + "messages": messages, + "max_tokens": 24, + "temperature": 0.0, + "stream": False, + "chat_template_kwargs": dict(enable_thinking=False), + } + completions_resp = await client.post("/v1/chat/completions", json=payload) + completions_data = completions_resp.json() + completions_res = completions_data["choices"][0]["message"]["content"] + + assert generate_res == completions_res diff --git a/tests/entrypoints/pooling/openai/test_classification.py b/tests/entrypoints/pooling/openai/test_classification.py index 671bb948780ae..25080d4189c2d 100644 --- a/tests/entrypoints/pooling/openai/test_classification.py +++ b/tests/entrypoints/pooling/openai/test_classification.py @@ -46,6 +46,16 @@ def test_single_input_classification(server: RemoteOpenAIServer, model_name: str assert hasattr(output.data[0], "probs") +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_add_special_tokens_false(server: RemoteOpenAIServer, model_name: 
str): + response = requests.post( + server.url_for("classify"), + json={"model": model_name, "input": "hello", "add_special_tokens": False}, + ) + response.raise_for_status() + ClassificationResponse.model_validate(response.json()) + + @pytest.mark.parametrize("model_name", [MODEL_NAME]) def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str): input_texts = [ diff --git a/tests/entrypoints/pooling/openai/test_vision_classification.py b/tests/entrypoints/pooling/openai/test_vision_classification.py new file mode 100644 index 0000000000000..f2616e057b175 --- /dev/null +++ b/tests/entrypoints/pooling/openai/test_vision_classification.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import pytest +import requests + +from tests.utils import RemoteOpenAIServer +from vllm.entrypoints.openai.protocol import ClassificationResponse + +VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls" +MAXIMUM_VIDEOS = 1 +TEST_VIDEO_URL = "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4" + +HF_OVERRIDES = { + "text_config": { + "architectures": ["Qwen2_5_VLForSequenceClassification"], + }, +} + + +@pytest.fixture(scope="module") +def server_vlm_classify(): + args = [ + "--runner", + "pooling", + "--max-model-len", + "5000", + "--enforce-eager", + "--limit-mm-per-prompt", + json.dumps({"video": MAXIMUM_VIDEOS}), + ] + + with RemoteOpenAIServer( + VLM_MODEL_NAME, args, override_hf_configs=HF_OVERRIDES + ) as remote_server: + yield remote_server + + +@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME]) +def test_classify_accepts_chat_text_only( + server_vlm_classify: RemoteOpenAIServer, model_name: str +) -> None: + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Please classify this text request."}, + ], + } + ] + + response = requests.post( + 
server_vlm_classify.url_for("classify"), + json={"model": model_name, "messages": messages}, + ) + response.raise_for_status() + + output = ClassificationResponse.model_validate(response.json()) + + assert output.object == "list" + assert output.model == model_name + assert len(output.data) == 1 + assert len(output.data[0].probs) == 2 + assert output.usage.prompt_tokens == 22 + + +@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME]) +def test_classify_accepts_chat_video_url( + server_vlm_classify: RemoteOpenAIServer, model_name: str +) -> None: + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Please classify this video."}, + {"type": "video_url", "video_url": {"url": TEST_VIDEO_URL}}, + ], + } + ] + + response = requests.post( + server_vlm_classify.url_for("classify"), + json={"model": model_name, "messages": messages}, + ) + response.raise_for_status() + + output = ClassificationResponse.model_validate(response.json()) + + assert output.object == "list" + assert output.model == model_name + assert len(output.data) == 1 + assert len(output.data[0].probs) == 2 + assert output.usage.prompt_tokens == 4807 diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py index c7799607912b6..0421f8bb18592 100644 --- a/tests/evals/gsm8k/gsm8k_eval.py +++ b/tests/evals/gsm8k/gsm8k_eval.py @@ -83,8 +83,12 @@ async def call_vllm_api( stop: list[str] | None = None, url: str | None = None, seed: int | None = None, -) -> str: - """Call vLLM's OpenAI-compatible completions endpoint.""" +) -> tuple[str, int]: + """Call vLLM's OpenAI-compatible completions endpoint. 
+ + Returns: + Tuple of (response_text, completion_tokens) + """ data = { "prompt": prompt, "temperature": temperature, @@ -98,10 +102,12 @@ async def call_vllm_api( async with session.post(f"{url}/v1/completions", json=data) as response: response.raise_for_status() result = await response.json() - return result["choices"][0]["text"] + text = result["choices"][0]["text"] + completion_tokens = result.get("usage", {}).get("completion_tokens", 0) + return text, completion_tokens except Exception as e: print(f"Error calling vLLM API: {e}") - return "" + return "", 0 def evaluate_gsm8k( @@ -146,10 +152,11 @@ def evaluate_gsm8k( # Run evaluation async def run_async_evaluation(): states: list[str] = [""] * num_questions + output_tokens: list[int] = [0] * num_questions - async def get_answer(session: aiohttp.ClientSession, i: int) -> str: + async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]: prompt = few_shot_examples + questions[i] - answer = await call_vllm_api( + answer, tokens = await call_vllm_api( session=session, prompt=prompt, temperature=temperature, @@ -159,7 +166,8 @@ def evaluate_gsm8k( seed=seed, ) states[i] = answer - return answer + output_tokens[i] = tokens + return answer, tokens async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=600) @@ -167,24 +175,28 @@ def evaluate_gsm8k( tasks = [get_answer(session, i) for i in range(num_questions)] await tqdm.gather(*tasks, desc="Evaluating") - return states + return states, output_tokens print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot") tic = time.perf_counter() - states = asyncio.run(run_async_evaluation()) + states, output_tokens = asyncio.run(run_async_evaluation()) latency = time.perf_counter() - tic # Compute metrics preds = [get_answer_value(state) for state in states] accuracy = np.mean(np.array(preds) == np.array(labels)) invalid_rate = np.mean(np.array(preds) == INVALID) + total_output_tokens = sum(output_tokens) + 
tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0 result = { "accuracy": accuracy, "invalid_rate": invalid_rate, "latency": latency, "questions_per_second": num_questions / latency, + "total_output_tokens": total_output_tokens, + "tokens_per_second": tokens_per_second, "num_questions": num_questions, "num_shots": num_shots, "max_tokens": max_tokens, @@ -236,6 +248,8 @@ def main() -> None: print(f"Invalid responses: {result['invalid_rate']:.3f}") print(f"Total latency: {result['latency']:.3f} s") print(f"Questions per second: {result['questions_per_second']:.3f}") + print(f"Total output tokens: {result['total_output_tokens']}") + print(f"Output tokens per second: {result['tokens_per_second']:.3f}") # Optional file saving if args.save_results: diff --git a/tests/kernels/attention/test_cascade_flash_attn.py b/tests/kernels/attention/test_cascade_flash_attn.py index 4295f852f95bb..20f573821b25f 100755 --- a/tests/kernels/attention/test_cascade_flash_attn.py +++ b/tests/kernels/attention/test_cascade_flash_attn.py @@ -170,6 +170,7 @@ def test_cascade( logits_soft_cap=soft_cap if soft_cap is not None else 0, block_table=block_tables, common_prefix_len=common_prefix_len, + max_num_splits=0, # no max fa_version=fa_version, ) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 62704bbcbbc79..2285709fa7d60 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -40,8 +40,6 @@ NUM_EXPERTS = [8, 64] TOP_KS = [1, 2, 6] vllm_config = VllmConfig() -vllm_config.scheduler_config.max_num_seqs = 128 -vllm_config.scheduler_config.max_model_len = 8192 @dataclass diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index cd34617ee0fc4..88db4b3e537c2 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -33,8 +33,6 @@ if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires 
CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() -vllm_config.scheduler_config.max_num_seqs = 128 -vllm_config.scheduler_config.max_model_len = 8192 # Test configurations DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py index 3799e60f1294a..e35ca4caa9dbc 100644 --- a/tests/kernels/moe/test_block_int8.py +++ b/tests/kernels/moe/test_block_int8.py @@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0): pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) vllm_config = VllmConfig() -vllm_config.scheduler_config.max_num_seqs = 128 -vllm_config.scheduler_config.max_model_len = 8192 DTYPES = [torch.bfloat16] diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 5512ccce47b05..c15837f145705 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -42,8 +42,6 @@ MNK_FACTORS = [ ] vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) -vllm_config.scheduler_config.max_num_seqs = 128 -vllm_config.scheduler_config.max_model_len = 8192 @dataclasses.dataclass diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 9d039b81690a1..455ecacef5ec3 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -7,6 +7,7 @@ fp8 block-quantized case. 
""" import dataclasses +from contextlib import contextmanager import pytest import torch.distributed @@ -14,6 +15,7 @@ from torch.distributed import ProcessGroup from typing_extensions import ParamSpec from vllm.config import VllmConfig, set_current_vllm_config +from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, @@ -21,7 +23,11 @@ from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.platforms import current_platform -from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported +from vllm.utils.deep_gemm import ( + get_mk_alignment_for_contiguous_layout, + is_deep_gemm_e8m0_used, + is_deep_gemm_supported, +) from vllm.utils.import_utils import has_deep_ep, has_deep_gemm from ...utils import multi_gpu_test @@ -57,6 +63,23 @@ requires_deep_gemm = pytest.mark.skipif( P = ParamSpec("P") +@contextmanager +def with_dp_metadata(M: int, world_size: int): + num_tokens_across_dp = torch.tensor([M] * world_size, device="cpu", dtype=torch.int) + + vllm_config = VllmConfig() + vllm_config.parallel_config.data_parallel_size = world_size + vllm_config.parallel_config.enable_expert_parallel = True + + with set_forward_context( + None, + vllm_config, + num_tokens=M, + num_tokens_across_dp=num_tokens_across_dp, + ): + yield + + def next_power_of_2(x): import math @@ -281,18 +304,21 @@ def deepep_deepgemm_moe_impl( quant_config=quant_config, ) - out = mk.forward( - hidden_states=test_tensors.rank_tokens, - w1=w1, - w2=w2, - topk_weights=test_tensors.topk_weights, - topk_ids=test_tensors.topk, - inplace=False, - activation="silu", - global_num_experts=num_experts, - expert_map=build_expert_map(), - apply_router_weight_on_input=False, - ) + with with_dp_metadata( + 
M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size + ): + out = mk.forward( + hidden_states=test_tensors.rank_tokens, + w1=w1, + w2=w2, + topk_weights=test_tensors.topk_weights, + topk_ids=test_tensors.topk, + inplace=False, + activation="silu", + global_num_experts=num_experts, + expert_map=build_expert_map(), + apply_router_weight_on_input=False, + ) return out @@ -413,19 +439,16 @@ NUM_EXPERTS = [32] @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif( - is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM" -) def test_ht_deepep_deepgemm_moe( mnk: tuple[int, int, int], num_experts: int, topk: int, world_dp_size: tuple[int, int], + disable_deepgemm_ue8m0, ): """ Tests for High-Throughput DeepEP + DeepGemm integration. """ - import deep_gemm m, n, k = mnk current_platform.seed_everything(7) @@ -433,7 +456,7 @@ def test_ht_deepep_deepgemm_moe( if topk > num_experts: pytest.skip(f"Skipping test: topk={topk} > E={num_experts}") - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + block_m = get_mk_alignment_for_contiguous_layout()[0] block_size = [block_m, block_m] world_size, dp_size = world_dp_size @@ -487,9 +510,6 @@ USE_FP8_DISPATCH = [False] @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif( - is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM" -) def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], num_experts: int, @@ -497,10 +517,12 @@ def test_ll_deepep_deepgemm_moe( use_fp8_dispatch: bool, block_size: list[int], world_dp_size: tuple[int, int], + disable_deepgemm_ue8m0, ): """ Tests for Low-Latency DeepEP + DeepGemm integration. 
""" + assert not is_deep_gemm_e8m0_used() m, n, k = mnk current_platform.seed_everything(7) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index b49319a7e6f54..d78b8250463a9 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -294,7 +294,7 @@ def torch_moe_impl( # blockwise quant and de-quant. assert not per_act_token_quant a = test_tensors.rank_tokens - aq, aq_scale = per_token_group_quant_fp8(a, 128) + aq, aq_scale = per_token_group_quant_fp8(a, 128, use_ue8m0=False) a = ( (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)) .view(a.shape) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index 707068b2bbdc2..218df4a2632c3 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -45,8 +45,6 @@ MNK_FACTORS = [ ] vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) -vllm_config.scheduler_config.max_num_seqs = 128 -vllm_config.scheduler_config.max_model_len = 8192 def quant_fp8_per_tensor_batches(a): @@ -79,10 +77,14 @@ class TestData: @staticmethod def make_moe_tensors_8bit( - m: int, k: int, n: int, e: int, reorder: bool + m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu" ) -> "TestData": + is_gated = activation != "relu2_no_mul" + hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 - w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) + w13 = torch.randn( + (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16 + ) w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) # Scale to fp8 @@ -192,18 +194,22 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"]) def 
test_flashinfer_cutlass_moe_fp8_no_graph( m: int, n: int, k: int, e: int, topk: int, + activation: str, monkeypatch, ): current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): - td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False) + td = TestData.make_moe_tensors_8bit( + m, k, n, e, reorder=False, activation=activation + ) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) topk_weights, topk_ids, _ = FusedMoE.select_experts( @@ -235,7 +241,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - activation="silu", + activation=activation, global_num_experts=e, expert_map=None, apply_router_weight_on_input=True, @@ -255,7 +261,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( td.layer, topk_weights, topk_ids, - activation="silu", + activation=activation, global_num_experts=e, expert_map=None, apply_router_weight_on_input=True, diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index c27cf2468ede5..0550c2d9e2125 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -81,8 +81,6 @@ FUSED_MOE_WN16_MNK_FACTORS = [ ] vllm_config = VllmConfig() -vllm_config.scheduler_config.max_num_seqs = 128 -vllm_config.scheduler_config.max_model_len = 8192 def run_moe_test( diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index a2de64974b353..dd4eb4da913bd 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -192,8 +192,6 @@ def pplx_cutlass_moe( vllm_config = VllmConfig() -vllm_config.scheduler_config.max_num_seqs = 128 -vllm_config.scheduler_config.max_model_len = 8192 def _pplx_moe( diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 0f0ed3326d159..f671b23d300ce 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ 
b/tests/kernels/moe/test_pplx_moe.py @@ -81,8 +81,6 @@ TOP_KS = [1, 2, 6] DTYPES = [torch.float8_e4m3fn, torch.bfloat16] vllm_config = VllmConfig() -vllm_config.scheduler_config.max_num_seqs = 128 -vllm_config.scheduler_config.max_model_len = 8192 def torch_prepare( diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 420dbbffaac08..d6b78dd2c2323 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random + import pytest import torch @@ -8,27 +11,30 @@ from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( persistent_masked_m_silu_mul_quant, ) from vllm.platforms import current_platform -from vllm.utils.math_utils import cdiv +from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm +from vllm.utils.math_utils import cdiv, round_up fp8_dtype = torch.float8_e4m3fn CASES = [ (1, 1, 128, fp8_dtype), - (1, 4, 128, fp8_dtype), - (2, 4, 256, fp8_dtype), - (32, 64, 256, fp8_dtype), - (17, 31, 768, fp8_dtype), - (1, 1, 128 * 1, fp8_dtype), - (1, 1, 128 * 3, fp8_dtype), - (1, 1, 128 * 4, fp8_dtype), - (8, 16, 128 * 1, fp8_dtype), - (8, 16, 128 * 2, fp8_dtype), - (8, 16, 128 * 3, fp8_dtype), + (1, 4, 128 * 1, fp8_dtype), + (2, 4, 128 * 2, fp8_dtype), + (1, 4, 128 * 3, fp8_dtype), + (8, 16, 128 * 4, fp8_dtype), + (8, 16, 128 * 5, fp8_dtype), + (8, 16, 128 * 6, fp8_dtype), + (8, 16, 128 * 7, fp8_dtype), + (8, 16, 128 * 8, fp8_dtype), + (8, 16, 128 * 9, fp8_dtype), (8, 64, 7168, fp8_dtype), (8, 128, 128 * 33, fp8_dtype), + (1, 4, 128 * 10, fp8_dtype), (8, 128, 7168, fp8_dtype), (8, 512, 7168, fp8_dtype), (8, 1024, 7168, fp8_dtype), + (17, 31, 768, fp8_dtype), + (32, 64, 256, fp8_dtype), (256, 8, 7168, fp8_dtype), (256, 32, 7168, fp8_dtype), (256, 64, 7168, 
fp8_dtype), @@ -38,14 +44,159 @@ CASES = [ ] +def as_uint8(x) -> torch.Tensor: + return ( + torch.empty(x.shape, dtype=x.dtype, device=x.device).copy_(x).view(torch.uint8) + ) + + +def silu(x: torch.Tensor) -> torch.Tensor: + one_f32 = torch.tensor([1.0], device=x.device, dtype=torch.float32) + x_f32 = x.to(torch.float32) + act_f32 = x_f32 / (one_f32 + torch.exp(-x_f32)) + assert act_f32.dtype == torch.float32 + return act_f32.to(torch.bfloat16) + + +def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool): + eps_bf16 = torch.tensor([1e-10], device=x.device, dtype=torch.bfloat16) + one_bf16 = torch.tensor([1.0], device=x.device, dtype=torch.bfloat16) + fp8_max_bf16 = torch.tensor( + [torch.finfo(fp8_dtype).max], device=x.device, dtype=torch.bfloat16 + ) + fp8_min_bf16 = torch.tensor( + [torch.finfo(fp8_dtype).min], device=x.device, dtype=torch.bfloat16 + ) + fp8_max_inv = one_bf16 / fp8_max_bf16 + assert fp8_max_inv.dtype == torch.bfloat16 + + assert x.size(-1) % group_size == 0 + num_groups = x.numel() // group_size + x_og_shape = x.shape + + x = x.to(torch.bfloat16) + x = x.view((-1, group_size)) + amax = x.abs().amax(dim=1).clamp(min=eps_bf16) + assert amax.dtype == torch.bfloat16 + s = amax * fp8_max_inv + + if ceil_ue8m0: + s = torch.exp2( + torch.ceil(torch.log2(s).to(torch.bfloat16)).to(torch.bfloat16) + ).to(torch.bfloat16) + + inv_s = one_bf16 / s + inv_s = inv_s.view((num_groups, 1)) + xq = torch.clamp(x * inv_s, min=fp8_min_bf16.item(), max=fp8_max_bf16.item()).to( + fp8_dtype + ) + + xq = xq.view(x_og_shape) + xs = s.view((-1, xq.size(-1) // group_size)) + return xq, xs + + +def silu_mul_quant( + gate: torch.Tensor, up: torch.Tensor, group_size: int, ceil_ue8m0: bool +) -> tuple[torch.Tensor, torch.Tensor]: + assert gate.size(-1) % group_size == 0 + assert up.size(-1) % group_size == 0 + + assert gate.dtype == torch.bfloat16 + assert up.dtype == torch.bfloat16 + + act_bf16 = silu(gate) + assert act_bf16.dtype == torch.bfloat16 + + # act & mul + 
a_m = act_bf16 * up + assert a_m.dtype == torch.bfloat16 + + q, s = do_quant(a_m, group_size, ceil_ue8m0) + return q, s + + +def pack_scales(x: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor: + """ + pack float32 scales into a int32 tensor + """ + assert x.dtype == torch.float32 + E, T, G = x.size() + + # Add i32_padding here so we can view it as a i32 tensor later on. + i32_padding = round_up(G, 4) - G + ref_s_i8 = torch.empty((E, T, G + i32_padding), dtype=torch.uint8, device="cuda") + for e in range(E): + nt = tokens_per_expert[e].item() + ref_s_i8[e, :nt, :G] = x[e, :nt].view(torch.int32) >> 23 + + ref_s_i32 = ref_s_i8.view(torch.int32) + + return ref_s_i32 + + +def ref_with_scale_fmt( + E: int, + T: int, + H: int, + group_size: int, + tokens_per_expert: torch.Tensor, + gate: torch.Tensor, + up: torch.Tensor, + scale_fmt: DeepGemmQuantScaleFMT, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + The precision types of the operations triggered by this function + match closely with the kernel implementation so we compare more + accurately. + """ + scale_dtype = ( + torch.int32 if scale_fmt == DeepGemmQuantScaleFMT.UE8M0 else torch.float32 + ) + ceil_ue8m0 = scale_fmt in [ + DeepGemmQuantScaleFMT.UE8M0, + DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0, + ] + + ref_q = torch.empty((E, T, H), dtype=fp8_dtype, device="cuda") + ref_s_f32 = torch.empty( + (E, T, cdiv(H, group_size)), dtype=torch.float32, device="cuda" + ) + + for e in range(E): + nt = tokens_per_expert[e].item() + if nt == 0: + continue + ref_q[e, :nt], ref_s_f32[e, :nt] = silu_mul_quant( + gate[e, :nt], up[e, :nt], group_size, ceil_ue8m0=ceil_ue8m0 + ) + + if scale_dtype == torch.float32: + return ref_q, ref_s_f32 + + assert scale_dtype == torch.int32 + return ref_q, pack_scales(ref_s_f32, tokens_per_expert) + + +def token_random(E, T, H2, tokens_per_expert): + """ + Initialize each token in a random range so we test a range of + scale values. 
+ """ + y = torch.empty((E, T, H2), dtype=torch.bfloat16, device="cuda") + for e in range(E): + for t in range(tokens_per_expert[e].item()): + exp = random.choice(range(1, 20)) + y[e, t].uniform_(-(2**exp), 2**exp) + return y + + @pytest.mark.parametrize("E,T,H,fp8_type", CASES) @torch.inference_mode() -def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): +def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dtype): group_size = 128 current_platform.seed_everything(42) - # Input tensor of shape (E, T, 2*H) - y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") tokens_per_expert = torch.randint( low=0, high=T, @@ -54,71 +205,83 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): device="cuda", ) + # Input tensor of shape (E, T, 2*H) + y = token_random(E, T, 2 * H, tokens_per_expert) + + gate = y[..., :H].to(torch.bfloat16) + up = y[..., H:].to(torch.bfloat16) + + scale_fmts = [ + DeepGemmQuantScaleFMT.FLOAT32, + DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0, + DeepGemmQuantScaleFMT.UE8M0, + ] + # Run the SiLU V2 kernel - # TODO (varun): use_e8m0 is set to false as the reference impl does - # not handle that case. 
- y_q, y_s = persistent_masked_m_silu_mul_quant( - y, tokens_per_expert, group_size=group_size, use_ue8m0=False - ) - - torch.cuda.synchronize() - fp8_info = torch.finfo(fp8_dtype) - fp8_max = fp8_info.max - fp8_min = fp8_info.min - eps = 1e-10 - - y1 = y[..., :H].float() - y2 = y[..., H:] - silu_x = y1 * torch.sigmoid(y1) - merged = silu_x * y2 - - for e in range(E): - nt = tokens_per_expert[e].item() - ref_s = torch.empty( - (T, cdiv(H, group_size)), dtype=torch.float32, device="cuda" - ) - ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda") - - for t in range(nt): - data = merged[e, t].float() - ref_q_row = torch.empty_like(data) - - # process full groups - n_full_groups = H // group_size - if n_full_groups > 0: - data_grp = data[: n_full_groups * group_size].view( - n_full_groups, group_size - ) - amax = data_grp.abs().amax(dim=1).clamp(min=eps) - scale = amax / fp8_max - scaled = data[: n_full_groups * group_size] / scale.repeat_interleave( - group_size - ) - ref_q_row[: n_full_groups * group_size] = scaled.clamp( - fp8_min, fp8_max - ).to(fp8_dtype) - ref_s[t, :n_full_groups] = scale - - # process remainder group - rem = H % group_size - if rem > 0: - data_rem = data[-rem:] - amax = data_rem.abs().amax().clamp(min=eps) - scale = amax / fp8_max - scaled = data_rem / scale - ref_q_row[-rem:] = scaled.clamp(fp8_min, fp8_max).to(fp8_dtype) - ref_s[t, -1] = scale - - ref_q[t] = ref_q_row - - y_se = y_s[e].float() - y_qe = y_q[e].float() - - torch.testing.assert_close( - y_qe[:nt].to(torch.float32), - ref_q[:nt].to(torch.float32), - atol=2, - rtol=2e-1, + for scale_fmt in scale_fmts: + y_q, y_s = persistent_masked_m_silu_mul_quant( + y, + tokens_per_expert, + group_size=group_size, + quant_scale_fmt=scale_fmt, ) - torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) + ref_y_q, ref_y_s = ref_with_scale_fmt( + E, T, H, group_size, tokens_per_expert, gate, up, scale_fmt=scale_fmt + ) + + # deepgemm scales transform + dg_scales = None + if ( 
+ has_deep_gemm() + and current_platform.has_device_capability(100) + and scale_fmt == DeepGemmQuantScaleFMT.UE8M0 + ): + from deep_gemm import transform_sf_into_required_layout + + _q, _s = ref_with_scale_fmt( + E, + T, + H, + group_size, + tokens_per_expert, + gate, + up, + scale_fmt=DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0, + ) + dg_scales = transform_sf_into_required_layout( + sf=_s, + mn=_q.size(1), + k=_q.size(2), + recipe=(1, 128, 128), + num_groups=_q.size(0), + is_sfa=True, + ) + + expected_scale_dtype = ( + torch.int32 if scale_fmt == DeepGemmQuantScaleFMT.UE8M0 else torch.float32 + ) + assert y_s.dtype == expected_scale_dtype + assert ref_y_s.dtype == expected_scale_dtype + + for e in range(E): + nt = tokens_per_expert[e].item() + + torch.testing.assert_close( + y_q[e, :nt].to(torch.float32), + ref_y_q[e, :nt].to(torch.float32), + ) + + if scale_fmt == DeepGemmQuantScaleFMT.UE8M0: + G = H // group_size + y_s_sliced = as_uint8(y_s[e]) + ref_s_sliced = as_uint8(ref_y_s[e]) + torch.testing.assert_close(y_s_sliced[:nt, :G], ref_s_sliced[:nt, :G]) + if dg_scales is not None: + dg_sliced = as_uint8(dg_scales[e]) + torch.testing.assert_close(y_s_sliced[:nt, :G], dg_sliced[:nt, :G]) + else: + torch.testing.assert_close( + y_s[e, :nt], + ref_y_s[e, :nt], + ) diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index 933cd9dbdeaa0..7a467e160b784 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() -vllm_config.scheduler_config.max_num_seqs = 128 -vllm_config.scheduler_config.max_model_len = 8192 def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 
55f092e7ea694..e9973c1fcc15e 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -29,8 +29,6 @@ if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() -vllm_config.scheduler_config.max_num_seqs = 128 -vllm_config.scheduler_config.max_model_len = 8192 # Test configurations DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index dabc10a122f7a..310091b6a554d 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0): pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) vllm_config = VllmConfig() -vllm_config.scheduler_config.max_num_seqs = 128 -vllm_config.scheduler_config.max_model_len = 8192 DTYPES = [torch.half, torch.bfloat16] M = [1, 33, 64, 222] diff --git a/tests/model_executor/test_eagle_quantization.py b/tests/model_executor/test_eagle_quantization.py new file mode 100644 index 0000000000000..1ab75933ee31e --- /dev/null +++ b/tests/model_executor/test_eagle_quantization.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import Mock, patch + +import pytest +import torch + +from vllm.config import LoadConfig, ModelConfig, SpeculativeConfig, VllmConfig +from vllm.model_executor.models.utils import get_draft_quant_config +from vllm.platforms import current_platform + +DEVICES = ( + [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] + if current_platform.is_cuda_alike() + else ["cpu"] +) + + +def test_get_draft_quant_config_with_draft_model(): + mock_draft_model_config = Mock(spec=ModelConfig) + 
mock_load_config = Mock(spec=LoadConfig) + mock_speculative_config = Mock(spec=SpeculativeConfig) + mock_speculative_config.draft_model_config = mock_draft_model_config + + mock_vllm_config = Mock(spec=VllmConfig) + mock_vllm_config.speculative_config = mock_speculative_config + mock_vllm_config.load_config = mock_load_config + + mock_quant_config = Mock() + with patch.object( + VllmConfig, "get_quantization_config", return_value=mock_quant_config + ): + result = get_draft_quant_config(mock_vllm_config) + + # Verify the function calls get_quantization_config with draft model config + VllmConfig.get_quantization_config.assert_called_once_with( + mock_draft_model_config, mock_load_config + ) + assert result == mock_quant_config + + +def test_get_draft_quant_config_without_draft_model(): + mock_speculative_config = Mock(spec=SpeculativeConfig) + mock_speculative_config.draft_model_config = None + + mock_vllm_config = Mock(spec=VllmConfig) + mock_vllm_config.speculative_config = mock_speculative_config + mock_vllm_config.load_config = Mock(spec=LoadConfig) + + result = get_draft_quant_config(mock_vllm_config) + + assert result is None + + +@torch.inference_mode() +@pytest.mark.parametrize("device", DEVICES) +def test_fc_layer_quant_config_usage(dist_init, device) -> None: + import torch + + from vllm.model_executor.layers.linear import ReplicatedLinear + + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + + torch.set_default_device(device) + + input_size = 256 + output_size = 128 + + fc_no_quant = ReplicatedLinear( + input_size=input_size, + output_size=output_size, + bias=False, + params_dtype=torch.float16, + quant_config=None, + prefix="fc", + ) + + assert fc_no_quant.quant_config is None + assert fc_no_quant.input_size == input_size + assert fc_no_quant.output_size == output_size + + mock_quant_config = Mock() + fc_with_quant = ReplicatedLinear( + input_size=input_size, + output_size=output_size, + bias=False, + params_dtype=torch.float16, + 
quant_config=mock_quant_config, + prefix="fc", + ) + + assert fc_with_quant.quant_config == mock_quant_config + + # Check forward pass + x = torch.randn(2, input_size, dtype=torch.float16) + output, _ = fc_no_quant(x) + assert output.shape == (2, output_size) + + +def test_kv_cache_scale_name_handling(): + # Mock a quant config that supports cache scales + mock_quant_config = Mock() + mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale") + + # Condition check in load_weights + name = "layers.0.self_attn.k_proj.weight" + scale_name = mock_quant_config.get_cache_scale(name) + + # Check if get_cache_scale is called and returns expected value + mock_quant_config.get_cache_scale.assert_called_once_with(name) + assert scale_name == "layers.0.self_attn.kv_scale" + + +def test_kv_cache_scale_name_no_scale(): + # Mock a quant config that returns None for get_cache_scale + mock_quant_config = Mock() + mock_quant_config.get_cache_scale = Mock(return_value=None) + + name = "layers.0.mlp.gate_proj.weight" + scale_name = mock_quant_config.get_cache_scale(name) + + # Should return None for weights that don't have cache scales + assert scale_name is None + + +def test_maybe_remap_kv_scale_name(): + from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name + + params_dict = { + "layers.0.self_attn.kv_scale": Mock(), + "layers.1.self_attn.kv_scale": Mock(), + } + + name = "layers.0.self_attn.some_scale" + remapped = maybe_remap_kv_scale_name(name, params_dict) + + assert remapped in params_dict or remapped == name or remapped is None + + +def test_load_weights_kv_scale_handling(): + kv_scale_param = Mock() + kv_scale_param.weight_loader = Mock() + + params_dict = { + "layers.0.self_attn.kv_scale": kv_scale_param, + } + + mock_quant_config = Mock() + mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale") + + # Load_weights logic for KV cache scales + name = "layers.0.self_attn.k_proj.weight" + 
loaded_weight_tensor = torch.tensor([1.0, 2.0]) + + if mock_quant_config is not None: + scale_name = mock_quant_config.get_cache_scale(name) + if scale_name: + param = params_dict[scale_name] + assert param is kv_scale_param + weight_to_load = ( + loaded_weight_tensor + if loaded_weight_tensor.dim() == 0 + else loaded_weight_tensor[0] + ) + + assert scale_name == "layers.0.self_attn.kv_scale" + assert weight_to_load == loaded_weight_tensor[0] diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 681b380e6a155..37830093cd3c5 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -348,9 +348,14 @@ def test_fp32_cache_state( # Helper functions for the APC tests -def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1): +def _get_vllm_runner_params( + model: str, + max_model_len: int, + tensor_parallel_size: int = 1, +): return { "model_name": model, + "enable_chunked_prefill": True, "enable_prefix_caching": False, "max_model_len": max_model_len, "tensor_parallel_size": tensor_parallel_size, diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py index f8e3fa7d1560f..0d41b93233d5a 100644 --- a/tests/models/language/pooling/test_extract_hidden_states.py +++ b/tests/models/language/pooling/test_extract_hidden_states.py @@ -11,7 +11,7 @@ from vllm import TokensPrompt ["Qwen/Qwen3-0.6B"], ) @torch.inference_mode -def test_embed_models(hf_runner, vllm_runner, model: str): +def test_extract_hidden_states(hf_runner, vllm_runner, model: str): n_prompt_tokens = [55, 56, 57] token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens] @@ -21,7 +21,7 @@ def test_embed_models(hf_runner, vllm_runner, model: str): enforce_eager=True, runner="pooling", enable_chunked_prefill=False, - enable_prefix_caching=False, + enable_prefix_caching=True, ) as 
vllm_model: pooling_outputs = vllm_model.llm.encode( [TokensPrompt(prompt_token_ids=t) for t in token_prompts], @@ -30,4 +30,29 @@ def test_embed_models(hf_runner, vllm_runner, model: str): for n, output in zip(n_prompt_tokens, pooling_outputs): assert len(output.prompt_token_ids) == n + assert len(output.outputs.data) == n assert output.num_cached_tokens == 0 + + # test enable_prefix_caching plus all pooling + # we need to skip reading cache at this request by + # request.skip_reading_prefix_cache + pooling_outputs = vllm_model.llm.encode( + [TokensPrompt(prompt_token_ids=t) for t in token_prompts], + pooling_task="token_embed", + ) + + for n, output in zip(n_prompt_tokens, pooling_outputs): + assert len(output.prompt_token_ids) == n + assert len(output.outputs.data) == n + assert output.num_cached_tokens == 0 + + # skip_reading_prefix_cache can still write to cache + # to accelerate following requests + pooling_outputs = vllm_model.llm.encode( + [TokensPrompt(prompt_token_ids=t) for t in token_prompts], + pooling_task="embed", + ) + + for n, output in zip(n_prompt_tokens, pooling_outputs): + assert len(output.prompt_token_ids) == n + assert output.num_cached_tokens > 0 diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 5504c417fda4c..95b64b380db0d 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -12,6 +12,7 @@ import pytest from packaging.version import Version from transformers import ( AutoModel, + AutoModelForCausalLM, AutoModelForImageTextToText, AutoModelForTextToWaveform, ) @@ -131,6 +132,7 @@ VLM_TEST_SETTINGS = { prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", + enforce_eager=False, 
max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -160,6 +162,7 @@ VLM_TEST_SETTINGS = { VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO, ), + enforce_eager=False, needs_video_metadata=True, prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 @@ -689,6 +692,23 @@ VLM_TEST_SETTINGS = { patch_hf_runner=model_utils.ovis2_5_patch_hf_runner, hf_model_kwargs={"revision": "refs/pr/5"}, ), + "paddleocr_vl": VLMTestInfo( + models=["PaddlePaddle/PaddleOCR-VL"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", + img_idx_to_prompt=lambda idx: ( + "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>" + ), + multi_image_prompt=( + "Image-1: <|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>\n" + "Image-2: <|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>\n" + "Describe these two images separately." 
+ ), + max_model_len=8192, + max_num_seqs=2, + auto_cls=AutoModelForCausalLM, + image_size_factors=[(), (0.25,)], + ), "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), diff --git a/tests/models/multimodal/generation/test_qwen2_5_vl.py b/tests/models/multimodal/generation/test_qwen2_5_vl.py index 6b009075abfa7..3ba665710af46 100644 --- a/tests/models/multimodal/generation/test_qwen2_5_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_5_vl.py @@ -34,6 +34,7 @@ VIDEO_PROMPTS = VIDEO_ASSETS.prompts( @pytest.mark.parametrize("num_frames", [16]) @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) def test_qwen2_5_vl_evs_functionality( vllm_runner, video_assets, @@ -42,10 +43,14 @@ def test_qwen2_5_vl_evs_functionality( num_frames: int, dtype: str, max_tokens: int, + use_bytecode_hook: bool, + monkeypatch, ) -> None: """Test EVS (Efficient Video Sampling) functionality with different pruning rates. """ + # Set the environment variable for this test + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") # Sample frames from video assets sampled_vids = [ @@ -86,6 +91,7 @@ def test_qwen2_5_vl_evs_functionality( @pytest.mark.parametrize("num_frames", [16]) @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) def test_qwen2_5_vl_evs_batched_videos( vllm_runner, video_assets, @@ -94,6 +100,8 @@ def test_qwen2_5_vl_evs_batched_videos( num_frames: int, dtype: str, max_tokens: int, + use_bytecode_hook: bool, + monkeypatch, ) -> None: """Test EVS functionality with batched videos. @@ -102,6 +110,8 @@ def test_qwen2_5_vl_evs_batched_videos( 2. Both pruning configurations work with multiple videos 3. 
The model doesn't crash when processing multiple videos simultaneously """ + # Set the environment variable for this test + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") # Sample frames from video assets sampled_vids = [ sample_frames_from_video(asset.np_ndarrays, num_frames) diff --git a/tests/models/registry.py b/tests/models/registry.py index 644d0619215fb..094f921e4305f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -173,6 +173,10 @@ class _HfExamplesInfo: _TEXT_GENERATION_EXAMPLE_MODELS = { # [Decoder-only] + "AfmoeForCausalLM": _HfExamplesInfo( + "arcee-ai/Trinity-Nano", + is_available_online=False, + ), "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"), "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True), "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True), diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index a18f5b6077636..ae5befd2c00b7 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test the functionality of the Transformers backend.""" +"""Test the functionality of the Transformers modeling backend.""" from typing import Any @@ -85,7 +85,7 @@ def test_models( required = Version("5.0.0.dev") if model == "allenai/OLMoE-1B-7B-0924" and installed < required: pytest.skip( - "MoE models with the Transformers backend require " + "MoE models with the Transformers modeling backend require " f"transformers>={required}, but got {installed}" ) diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py index aeef4c2fd8a70..8da048703df93 100644 --- a/tests/quantization/test_register_quantization_config.py +++ b/tests/quantization/test_register_quantization_config.py 
@@ -7,6 +7,7 @@ See https://github.com/vllm-project/vllm/issues/11926 for more details. Run `pytest tests/quantization/test_register_quantization_config.py`. """ +import logging from typing import Any import pytest @@ -100,17 +101,22 @@ class CustomQuantConfig(QuantizationConfig): return None -def test_register_quantization_config(): +def test_register_quantization_config(caplog_vllm): """Test register custom quantization config.""" # The quantization method `custom_quant` should be registered. assert get_quantization_config("custom_quant") == CustomQuantConfig # The quantization method `custom_quant` is already exists, - # should raise an error. - with pytest.raises(ValueError): + # should raise a warning when re-registering it. + with caplog_vllm.at_level(logging.WARNING): register_quantization_config("custom_quant")(CustomQuantConfig) + assert any( + "The quantization method 'custom_quant' already exists" in message + for message in caplog_vllm.messages + ), "Expected a warning when re-registering custom_quant" + @pytest.mark.parametrize( argnames="model", diff --git a/tests/rocm/aiter/test_grouped_quant.py b/tests/rocm/aiter/test_grouped_quant.py new file mode 100644 index 0000000000000..c7f0f1eda3558 --- /dev/null +++ b/tests/rocm/aiter/test_grouped_quant.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# This is a test for the AITER group_fp8_quant op. +# It tests if the AITER op is +# 1. correctly defined the relationship between +# implementation and fake function +# 2. can be used with torch.compile +# 3. can be used with CUDA graphs +# This file will be skipped if AITER is not installed +# and the platform is not ROCm. 
+ +import importlib.util + +import pytest +import torch + +# this import statement is needed to ensure the ops are registered +from vllm._aiter_ops import rocm_aiter_ops +from vllm.platforms import current_platform + +# Check if aiter package is installed +aiter_available = importlib.util.find_spec("aiter") is not None + +pytestmark = pytest.mark.skipif( + not (current_platform.is_rocm() and aiter_available), + reason="AITER ops are only available on ROCm with aiter package installed", +) + + +def test_rocm_aiter_group_fp8_quant_fake_implementation(): + """Test that the fake implementation is correctly + defined for torch.ops.vllm.rocm_aiter_group_fp8_quant.""" + # Create test tensors + M = 128 + N = 4096 + group_size = 128 + + input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda") + + # Verify the op's fake implementation using torch.library.opcheck + # This checks that the fake function returns tensors with correct shapes and dtypes + torch.library.opcheck( + torch.ops.vllm.rocm_aiter_group_fp8_quant, + (input_tensor, group_size), + test_utils=("test_faketensor",), + ) + + +def test_rocm_aiter_group_fp8_quant_torch_compile_with_cudagraph(): + """Test that rocm_aiter_ops.group_fp8_quant + with group size 128 can be used with + torch.compile in cudagraph mode.""" + # Create test tensors + M = 128 + N = 4096 + group_size = 128 + + input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda") + + # Define a function that uses the op + def group_fp8_quant_fn(x): + return rocm_aiter_ops.group_fp8_quant(x, group_size) + + # Compile with cudagraph mode + compiled_fn = torch.compile( + group_fp8_quant_fn, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False, + ) + + # Run eager mode + x_fp8_eager, scales_eager = group_fp8_quant_fn(input_tensor) + + # Run compiled version (first run will trigger compilation) + x_fp8_compiled, scales_compiled = compiled_fn(input_tensor) + + # Verify shapes match + assert 
x_fp8_compiled.shape == x_fp8_eager.shape + assert scales_compiled.shape == scales_eager.shape + + # Verify expected shapes + assert x_fp8_compiled.shape == (M, N) + expected_scale_cols = (N + group_size - 1) // group_size + assert scales_compiled.shape == (M, expected_scale_cols) + + # Verify results match + assert torch.allclose( + x_fp8_compiled.to(torch.float32), + x_fp8_eager.to(torch.float32), + rtol=1e-2, + atol=1e-2, + ) + assert torch.allclose(scales_compiled, scales_eager, rtol=1e-3, atol=1e-3) + + # Test with different input (reusing compiled graph) + input_tensor_2 = torch.randn((M, N), dtype=torch.bfloat16, device="cuda") + x_fp8_eager_2, scales_eager_2 = group_fp8_quant_fn(input_tensor_2) + x_fp8_compiled_2, scales_compiled_2 = compiled_fn(input_tensor_2) + + # Verify second run also produces correct results + assert torch.allclose( + x_fp8_compiled_2.to(torch.float32), + x_fp8_eager_2.to(torch.float32), + rtol=1e-2, + atol=1e-2, + ) + assert torch.allclose(scales_compiled_2, scales_eager_2, rtol=1e-3, atol=1e-3) + + +def test_rocm_aiter_group_fp8_quant_different_shapes(): + """Test rocm_aiter_ops.group_fp8_quant with different input shapes.""" + group_size = 128 + + test_shapes = [ + (64, 2048), + (256, 8192), + (32, 1024), + (512, 4096), + ] + + for M, N in test_shapes: + input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda") + + x_fp8, scales = rocm_aiter_ops.group_fp8_quant(input_tensor, group_size) + + # Verify shapes + assert x_fp8.shape == (M, N) + expected_scale_cols = (N + group_size - 1) // group_size + assert scales.shape == (M, expected_scale_cols) + + # Verify dtypes + from aiter import dtypes + + assert x_fp8.dtype == dtypes.fp8 + assert scales.dtype == torch.float32 diff --git a/tests/test_envs.py b/tests/test_envs.py index 841d7945f9120..6a9835a68e7e2 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -36,7 +36,7 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch): # Enable envs cache and ignore 
ongoing environment changes enable_envs_cache() - # __getattr__ is not decorated with functools.cache + # __getattr__ is decorated with functools.cache assert hasattr(envs.__getattr__, "cache_info") start_hits = envs.__getattr__.cache_info().hits diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 50a273016ab80..b1fb4e06a6906 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -86,34 +86,6 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): assert zipped["mm_processor_kwargs"] == exp_kwargs -@pytest.mark.parametrize( - "model_id", - [ - "facebook/opt-125m", - ], -) -@pytest.mark.parametrize( - "prompt", - [ - { - "prompt": "", - "multi_modal_data": {"dummy": []}, - }, - { - "prompt_token_ids": [], - "multi_modal_data": {"dummy": []}, - }, - ], -) -def test_preprocessor_text_no_mm_inputs(model_id, prompt): - model_config = ModelConfig(model=model_id) - tokenizer = init_tokenizer_from_configs(model_config) - input_preprocessor = InputPreprocessor(model_config, tokenizer) - - with pytest.raises(ValueError, match="does not support multimodal inputs"): - input_preprocessor.preprocess(prompt) - - @pytest.mark.parametrize( "model_id", [ @@ -127,6 +99,13 @@ def test_preprocessor_text_no_mm_inputs(model_id, prompt): {"prompt_token_ids": []}, ], ) +@pytest.mark.skip( + reason=( + "Applying huggingface processor on text inputs results in " + "significant performance regression for multimodal models. 
" + "See https://github.com/vllm-project/vllm/issues/26320" + ) +) def test_preprocessor_always_mm_code_path(model_id, prompt): model_config = ModelConfig(model=model_id) tokenizer = init_tokenizer_from_configs(model_config) diff --git a/tests/tool_use/test_kimi_k2_tool_parser.py b/tests/tool_use/test_kimi_k2_tool_parser.py index c358589dbc292..33dabbc7e7b91 100644 --- a/tests/tool_use/test_kimi_k2_tool_parser.py +++ b/tests/tool_use/test_kimi_k2_tool_parser.py @@ -209,3 +209,596 @@ def test_streaming_no_tool_calls(kimi_k2_tool_parser): assert result is not None assert hasattr(result, "content") assert result.content == " without any tool calls." + + +def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser): + """ + Test that text between <|tool_calls_section_begin|> and <|tool_call_begin|> + is suppressed and does not leak into reasoning_delta. + This is the main vulnerability being fixed. + """ + kimi_k2_tool_parser.reset_streaming_state() + + # Get token IDs for the markers + section_begin_token_id = kimi_k2_tool_parser.vocab.get( + "<|tool_calls_section_begin|>" + ) + tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>") + + # Simulate streaming sequence: + # Delta 1: "I'll help you with that. " + result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="I'll help you with that. ", + delta_text="I'll help you with that. ", + previous_token_ids=[], + current_token_ids=[1, 2, 3], # Regular tokens + delta_token_ids=[1, 2, 3], + request=None, + ) + assert result1 is not None + assert result1.content == "I'll help you with that. " + + # Delta 2: "<|tool_calls_section_begin|>" + prev_ids = [1, 2, 3] + curr_ids = prev_ids + [section_begin_token_id] + result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="I'll help you with that. ", + current_text="I'll help you with that. 
<|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=prev_ids, + current_token_ids=curr_ids, + delta_token_ids=[section_begin_token_id], + request=None, + ) + # Section marker should be stripped and suppressed + assert result2 is None or (result2.content is None or result2.content == "") + + # Delta 3: " spurious text or tokens " (THE LEAK SCENARIO) + prev_ids = curr_ids + curr_ids = curr_ids + [4, 5] + result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="I'll help you with that. <|tool_calls_section_begin|>", + current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ", + delta_text=" spurious text ", + previous_token_ids=prev_ids, + current_token_ids=curr_ids, + delta_token_ids=[4, 5], + request=None, + ) + # CRITICAL: This text should be suppressed, NOT returned as reasoning_delta + assert result3 is None or (result3.content is None or result3.content == "") + + # Delta 4: "<|tool_call_begin|>..." + prev_ids = curr_ids + curr_ids = curr_ids + [tool_call_begin_token_id] + _result4 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ", + current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text <|tool_call_begin|>", + delta_text="<|tool_call_begin|>", + previous_token_ids=prev_ids, + current_token_ids=curr_ids, + delta_token_ids=[tool_call_begin_token_id], + request=None, + ) + # Now we're in tool call mode, result depends on internal state + # The key is that the spurious text from Delta 3 was not leaked + + +def test_split_markers_across_deltas(kimi_k2_tool_parser): + """ + Test that markers split across delta chunks are correctly detected + via the rolling buffer mechanism. 
+ """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_token_id = kimi_k2_tool_parser.vocab.get( + "<|tool_calls_section_begin|>" + ) + + # Delta 1: "...reasoning<|tool_calls_sec" + _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Some reasoning", + current_text="Some reasoning<|tool_calls_sec", + delta_text="<|tool_calls_sec", + previous_token_ids=[1, 2], + current_token_ids=[1, 2, 3], # Partial token + delta_token_ids=[3], + request=None, + ) + # Partial token not recognized yet, might be buffered + # Should return as content or None (depends on implementation) + + # Delta 2: "tion_begin|> " (completes the marker) + _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Some reasoning<|tool_calls_sec", + current_text="Some reasoning<|tool_calls_section_begin|> ", + delta_text="tion_begin|> ", + previous_token_ids=[1, 2, 3], + current_token_ids=[1, 2, section_begin_token_id, 4], + delta_token_ids=[section_begin_token_id, 4], + request=None, + ) + # Now the complete marker should be detected via buffer + # The parser should enter tool section mode + assert kimi_k2_tool_parser.in_tool_section is True + + +def test_marker_variants(kimi_k2_tool_parser): + """Test that both singular and plural marker variants are recognized.""" + kimi_k2_tool_parser.reset_streaming_state() + + # Test singular variant: <|tool_call_section_begin|> (note: singular "call") + singular_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_section_begin|>") + + if singular_token_id is not None: # Only test if tokenizer supports it + _result = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning ", + current_text="Reasoning <|tool_call_section_begin|>", + delta_text="<|tool_call_section_begin|>", + previous_token_ids=[1, 2], + current_token_ids=[1, 2, singular_token_id], + delta_token_ids=[singular_token_id], + request=None, + ) + # Should enter tool section mode with singular variant too + assert 
kimi_k2_tool_parser.in_tool_section is True + + +def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser): + """ + Test that after exiting a tool section with <|tool_calls_section_end|>, + subsequent text is correctly returned as reasoning content. + """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") + + # Enter tool section + _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="<|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=[], + current_token_ids=[section_begin_id], + delta_token_ids=[section_begin_id], + request=None, + ) + assert kimi_k2_tool_parser.in_tool_section is True + + # Exit tool section + _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="<|tool_calls_section_begin|>", + current_text="<|tool_calls_section_begin|><|tool_calls_section_end|>", + delta_text="<|tool_calls_section_end|>", + previous_token_ids=[section_begin_id], + current_token_ids=[section_begin_id, section_end_id], + delta_token_ids=[section_end_id], + request=None, + ) + assert kimi_k2_tool_parser.in_tool_section is False + + # Subsequent reasoning text should be returned normally + result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="<|tool_calls_section_begin|><|tool_calls_section_end|>", + current_text="<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning", + delta_text=" More reasoning", + previous_token_ids=[section_begin_id, section_end_id], + current_token_ids=[section_begin_id, section_end_id, 10, 11], + delta_token_ids=[10, 11], + request=None, + ) + assert result3 is not None + assert result3.content == " More reasoning" + + +def test_empty_tool_section(kimi_k2_tool_parser): + """Test an empty tool section (begin immediately followed by end).""" 
+ kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") + + # Section begin + _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning ", + current_text="Reasoning <|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=[1], + current_token_ids=[1, section_begin_id], + delta_token_ids=[section_begin_id], + request=None, + ) + + # Immediate section end + _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning <|tool_calls_section_begin|>", + current_text="Reasoning <|tool_calls_section_begin|><|tool_calls_section_end|>", + delta_text="<|tool_calls_section_end|>", + previous_token_ids=[1, section_begin_id], + current_token_ids=[1, section_begin_id, section_end_id], + delta_token_ids=[section_end_id], + request=None, + ) + # Should exit cleanly without errors + assert kimi_k2_tool_parser.in_tool_section is False + + +def test_malformed_tool_section_recovery(kimi_k2_tool_parser): + """ + Test that the parser recovers from a malformed tool section + that never closes properly. 
+ """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + + # Enter tool section + _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="<|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=[], + current_token_ids=[section_begin_id], + delta_token_ids=[section_begin_id], + request=None, + ) + assert kimi_k2_tool_parser.in_tool_section is True + + # Simulate a lot of text without proper tool calls or section end + # This should trigger the error recovery mechanism + large_text = "x" * 10000 # Exceeds max_section_chars + + result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="<|tool_calls_section_begin|>", + current_text="<|tool_calls_section_begin|>" + large_text, + delta_text=large_text, + previous_token_ids=[section_begin_id], + current_token_ids=[section_begin_id] + list(range(100, 100 + len(large_text))), + delta_token_ids=list(range(100, 100 + len(large_text))), + request=None, + ) + + # Parser should have force-exited the tool section + assert kimi_k2_tool_parser.in_tool_section is False + # And returned the content as reasoning + assert result2 is not None + assert result2.content == large_text + + +def test_state_reset(kimi_k2_tool_parser): + """Test that reset_streaming_state() properly clears all state.""" + # Put parser in a complex state + kimi_k2_tool_parser.in_tool_section = True + kimi_k2_tool_parser.token_buffer = "some buffer" + kimi_k2_tool_parser.current_tool_id = 5 + kimi_k2_tool_parser.prev_tool_call_arr = [{"id": "test"}] + kimi_k2_tool_parser.section_char_count = 1000 + + # Reset + kimi_k2_tool_parser.reset_streaming_state() + + # Verify all state is cleared + assert kimi_k2_tool_parser.in_tool_section is False + assert kimi_k2_tool_parser.token_buffer == "" + assert kimi_k2_tool_parser.current_tool_id == -1 + assert kimi_k2_tool_parser.prev_tool_call_arr 
== [] + assert kimi_k2_tool_parser.section_char_count == 0 + assert kimi_k2_tool_parser.current_tool_name_sent is False + assert kimi_k2_tool_parser.streamed_args_for_tool == [] + + +def test_section_begin_noise_tool_begin_same_chunk(kimi_k2_tool_parser): + """ + Test that begin→noise→tool_begin within the SAME chunk suppresses + the noise text correctly (not just across chunks). + """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + tool_call_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>") + + # Single delta containing: section_begin + spurious text + tool_call_begin + combined_text = "<|tool_calls_section_begin|> noise text <|tool_call_begin|>" + + result = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning ", + current_text="Reasoning " + combined_text, + delta_text=combined_text, + previous_token_ids=[1, 2], + current_token_ids=[1, 2, section_begin_id, 3, 4, tool_call_begin_id], + delta_token_ids=[section_begin_id, 3, 4, tool_call_begin_id], + request=None, + ) + + # The noise text should NOT leak into content + # Result should either be None/empty or start tool call parsing + if result is not None and result.content is not None: + # If content is returned, it should not contain the noise + assert "noise text" not in result.content + assert result.content == "" or result.content.strip() == "" + + +def test_stream_ends_without_section_end_marker(kimi_k2_tool_parser): + """ + Test that if the stream ends (EOF) without a proper section end marker, + the parser doesn't leak text, doesn't crash, and resets state cleanly. 
+ """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + + # Enter tool section + _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="<|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=[], + current_token_ids=[section_begin_id], + delta_token_ids=[section_begin_id], + request=None, + ) + assert kimi_k2_tool_parser.in_tool_section is True + + # Some content in tool section + result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="<|tool_calls_section_begin|>", + current_text="<|tool_calls_section_begin|> partial content", + delta_text=" partial content", + previous_token_ids=[section_begin_id], + current_token_ids=[section_begin_id, 10, 11], + delta_token_ids=[10, 11], + request=None, + ) + # Content should be suppressed + assert result2.content == "" or result2.content is None + + # Stream ends (EOF) - no more deltas, no section_end marker + # Simulate this by manually checking state and resetting + # (In real usage, the request handler would call reset_streaming_state) + assert kimi_k2_tool_parser.in_tool_section is True # Still in section + + # Reset state (as would happen between requests) + kimi_k2_tool_parser.reset_streaming_state() + + # Verify clean slate + assert kimi_k2_tool_parser.in_tool_section is False + assert kimi_k2_tool_parser.token_buffer == "" + + # Next request should work normally + result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="New reasoning", + delta_text="New reasoning", + previous_token_ids=[], + current_token_ids=[20, 21], + delta_token_ids=[20, 21], + request=None, + ) + assert result3 is not None + assert result3.content == "New reasoning" + + +def test_same_chunk_begin_and_end_markers(kimi_k2_tool_parser): + """ + CRITICAL TEST: Verify that when both section_begin and section_end + markers appear 
in the SAME chunk, the parser correctly: + 1. Enters the tool section + 2. Immediately exits the tool section + 3. Does NOT get stuck in in_tool_section=True state + + This tests the bug fix where elif was changed to if to handle + both state transitions in a single delta. + """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") + + # Single chunk with both markers (e.g., empty tool section) + combined_delta = "<|tool_calls_section_begin|><|tool_calls_section_end|>" + + result = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Some reasoning ", + current_text="Some reasoning " + combined_delta, + delta_text=combined_delta, + previous_token_ids=[1, 2], + current_token_ids=[1, 2, section_begin_id, section_end_id], + delta_token_ids=[section_begin_id, section_end_id], + request=None, + ) + + # CRITICAL: Parser should NOT be stuck in tool section + assert kimi_k2_tool_parser.in_tool_section is False, ( + "Parser stuck in tool section after processing both begin/end in same chunk. " + "This indicates the elif bug was not fixed." 
+ ) + + # Result should be empty or contain only stripped content + assert result is not None + assert result.content == "" or result.content is None + + # Verify subsequent content streams correctly (not suppressed) + result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Some reasoning " + combined_delta, + current_text="Some reasoning " + combined_delta + " More reasoning", + delta_text=" More reasoning", + previous_token_ids=[1, 2, section_begin_id, section_end_id], + current_token_ids=[1, 2, section_begin_id, section_end_id, 10, 11], + delta_token_ids=[10, 11], + request=None, + ) + + # This content should NOT be suppressed (we're out of tool section) + assert result2 is not None + assert result2.content == " More reasoning" + + +def test_same_chunk_begin_content_end_markers(kimi_k2_tool_parser): + """ + Test the same-chunk scenario with actual content between markers. + Example: <|tool_calls_section_begin|> text <|tool_calls_section_end|> + all arriving in one delta. The key is that the state machine correctly + transitions in and out within the same chunk. 
+ """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") + + # Chunk with begin, some whitespace/noise, and end all together + # This simulates a tool section that opens and closes in the same chunk + combined_delta = "<|tool_calls_section_begin|> <|tool_calls_section_end|>" + + _result = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning ", + current_text="Reasoning " + combined_delta, + delta_text=combined_delta, + previous_token_ids=[1], + current_token_ids=[1, section_begin_id, 100, section_end_id], + delta_token_ids=[section_begin_id, 100, section_end_id], + request=None, + ) + + # Parser should exit cleanly (not stuck in tool section) + assert kimi_k2_tool_parser.in_tool_section is False + + # Verify the fix: next content should stream normally, not be suppressed + result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning " + combined_delta, + current_text="Reasoning " + combined_delta + " Done", + delta_text=" Done", + previous_token_ids=[1, section_begin_id, 100, section_end_id], + current_token_ids=[1, section_begin_id, 100, section_end_id, 200], + delta_token_ids=[200], + request=None, + ) + + # Content after section should be returned (not suppressed) + assert result2 is not None + assert result2.content == " Done" + + +def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser): + """ + CRITICAL TEST (P1): Verify that when both <|tool_call_end|> and + <|tool_calls_section_end|> appear in the SAME chunk, the parser: + 1. Processes the tool_call_end first (emits final arguments) + 2. THEN exits the section + 3. Does NOT drop the final tool call update + 4. Does NOT leak special tokens into reasoning + + This tests the deferred section exit fix. 
+ """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") + tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>") + tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>") + + # Simulate a streaming sequence for a SHORT tool call (all in one chunk): + # 1. Reasoning text + result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="Let me help. ", + delta_text="Let me help. ", + previous_token_ids=[], + current_token_ids=[1, 2], + delta_token_ids=[1, 2], + request=None, + ) + assert result1 is not None + assert result1.content == "Let me help. " + + # 2. Section begin + _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Let me help. ", + current_text="Let me help. <|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=[1, 2], + current_token_ids=[1, 2, section_begin_id], + delta_token_ids=[section_begin_id], + request=None, + ) + assert kimi_k2_tool_parser.in_tool_section is True + + # 3. Tool call begin + full content + tool_end + section_end ALL IN ONE CHUNK + # This is the critical scenario for short tool calls + combined = ( + '<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} ' + "<|tool_call_end|><|tool_calls_section_end|>" + ) + + # Build up the previous text gradually to simulate realistic streaming + prev_text = "Let me help. 
<|tool_calls_section_begin|>" + curr_text = prev_text + combined + + result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text=prev_text, + current_text=curr_text, + delta_text=combined, + previous_token_ids=[1, 2, section_begin_id], + current_token_ids=[ + 1, + 2, + section_begin_id, + tool_begin_id, + 10, + 11, + 12, + tool_end_id, + section_end_id, + ], + delta_token_ids=[tool_begin_id, 10, 11, 12, tool_end_id, section_end_id], + request=None, + ) + + # CRITICAL: Parser should have exited section AFTER processing tool + assert kimi_k2_tool_parser.in_tool_section is False + + # Tool call should have been emitted (not dropped) + # The result might be the tool name or None depending on state, but + # importantly, it shouldn't be returning the literal tokens as content + + if result3 is not None and result3.content is not None: + # Verify no special tokens leaked into content + assert "<|tool_call_end|>" not in result3.content + assert "<|tool_calls_section_end|>" not in result3.content + + # 4. 
Verify subsequent content streams normally + result4 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text=curr_text, + current_text=curr_text + " Done", + delta_text=" Done", + previous_token_ids=[ + 1, + 2, + section_begin_id, + tool_begin_id, + 10, + 11, + 12, + tool_end_id, + section_end_id, + ], + current_token_ids=[ + 1, + 2, + section_begin_id, + tool_begin_id, + 10, + 11, + 12, + tool_end_id, + section_end_id, + 20, + ], + delta_token_ids=[20], + request=None, + ) + + # Content after tool section should stream normally + assert result4 is not None + assert result4.content == " Done" diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index e0645ed43015e..1d80ee9875913 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import deque +import numpy as np import pytest from vllm.v1.core.sched.output import SchedulerOutput @@ -21,7 +22,7 @@ def _make_model_runner_output( return ModelRunnerOutput( req_ids=req_ids, req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)}, - sampled_token_ids=[[i] for i in range(len(req_ids))], + sampled_token_ids=[np.array([i]) for i in range(len(req_ids))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index df6a5f109874d..24611a4aaa1b8 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -55,7 +55,7 @@ pytestmark = pytest.mark.cpu_test def _auto_init_hash_fn(request): hash_fn: Callable if "hash_fn" in request.fixturenames: - hash_fn = init_none_hash(request.getfixturevalue("hash_fn")) + hash_fn = request.getfixturevalue("hash_fn") else: hash_fn = sha256 init_none_hash(hash_fn) diff --git a/tests/v1/core/test_priority_scheduler_random.py b/tests/v1/core/test_priority_scheduler_random.py 
index b4805be802723..ba0b703302e38 100644 --- a/tests/v1/core/test_priority_scheduler_random.py +++ b/tests/v1/core/test_priority_scheduler_random.py @@ -3,6 +3,7 @@ import random import uuid +import numpy as np import pytest from vllm.config import VllmConfig @@ -99,8 +100,7 @@ def _mock_execute_model( random.randint(*num_output_tokens_range) for _ in range(len(request_ids)) ] sampled_token_ids = [ - [random.randint(0, 100) for _ in range(num_tokens)] - for num_tokens in num_output_tokens + np.random.randint(0, 100, size=num_tokens) for num_tokens in num_output_tokens ] return ModelRunnerOutput( @@ -196,6 +196,8 @@ def test_priority_scheduling_blast( num_blocks: int, ): random.seed(42) + np.random.seed(42) + seen_request_prompt_length = dict[str, int]() seen_request_ids = set[str]() seen_mm_hashes = set[str]() diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index d5b829e79b8f7..0570c0854c678 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -3,6 +3,7 @@ import dataclasses from unittest.mock import Mock +import numpy as np import pytest import torch @@ -31,11 +32,11 @@ from vllm.v1.kv_cache_interface import ( KVCacheConfig, KVCacheGroupSpec, ) -from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager -from .utils import EOS_TOKEN_ID, create_requests, create_scheduler +from .utils import EOS_TOKEN_ID, create_requests, create_scheduler, mock_kv pytestmark = pytest.mark.cpu_test @@ -169,7 +170,7 @@ def test_schedule_partial_requests(): req_id_to_index=req_to_index, # Only the first request has a sampled token id because # the rest requests are still being prefilled. 
- sampled_token_ids=[[0], [], []], + sampled_token_ids=[np.array([0]), np.array([]), np.array([])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -216,7 +217,7 @@ def test_no_mm_input_chunking(): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[[] for _ in range(len(requests))], + sampled_token_ids=[np.array([]) for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -276,7 +277,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[[] for _ in range(len(requests))], + sampled_token_ids=[np.array([]) for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -300,7 +301,8 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)], + sampled_token_ids=[np.array([0]), np.array([0])] + + [np.array([]) for _ in range(len(requests) - 2)], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -347,8 +349,8 @@ def test_stop_via_update_from_output(): req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[ - [EOS_TOKEN_ID], - [10, 11], + np.array([EOS_TOKEN_ID]), + np.array([10, 11]), ], # First request hits EOS, second continues logprobs=None, prompt_logprobs_dict={}, @@ -392,7 +394,10 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[10, 42, 12], [13, 14]], # 
First request hits stop token + sampled_token_ids=[ + np.array([10, 42, 12]), + np.array([13, 14]), + ], # First request hits stop token logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -436,7 +441,10 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens + sampled_token_ids=[ + np.array([10, 11, 12]), + np.array([13]), + ], # First request exceeds max_tokens logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -475,7 +483,7 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], + sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -616,7 +624,7 @@ def test_schedule_concurrent_batches( model_runner_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -633,7 +641,7 @@ def test_schedule_concurrent_batches( model_runner_output = ModelRunnerOutput( req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -670,7 +678,7 @@ def test_preempt_during_execution(): model_runner_output0 = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -687,7 +695,7 @@ def test_preempt_during_execution(): model_runner_output1 = ModelRunnerOutput( 
req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, - sampled_token_ids=[[42]], + sampled_token_ids=[np.array([42])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -704,14 +712,18 @@ @pytest.mark.parametrize( "spec_tokens,output_tokens,expected", [ - ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match - ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch - ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences - ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence - ([[]], [[5]], (0, 0, 0, [0])), # empty sequence + ([[1, 2, 3]], [np.array([1, 2, 3, 4])], (1, 3, 3, [1, 1, 1])), # perfect match + ([[1, 2, 3]], [np.array([1, 5])], (1, 3, 1, [1, 0, 0])), # early mismatch + ( + [[1, 2], [3]], + [np.array([1, 2, 5]), np.array([3, 4])], + (2, 3, 3, [2, 1]), + ), # multiple sequences + ([[1]], [np.array([1, 2])], (1, 1, 1, [1])), # single token sequence + ([[]], [np.array([5])], (0, 0, 0, [0])), # empty sequence ( [[1, 2, 3], [4, 5, 6]], - [[1, 2, 7], [4, 8]], + [np.array([1, 2, 7]), np.array([4, 8])], (2, 6, 3, [2, 1, 0]), ), # multiple mismatches ], @@ -745,7 +757,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[0] for _ in range(len(requests))], + sampled_token_ids=[np.array([0]) for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -888,27 +900,65 @@ _step_until_done( all_finished = all_done -def test_kv_connector_basic(): +def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]): + """Cycle requests through a KV transfer cycle.""" + + # Requests should first transition to WAITING_FOR_REMOTE_KVS + output = scheduler.schedule() + assert len(scheduler.waiting) == len(req_ids) + assert len(scheduler.running) == 0 + assert
len(output.scheduled_new_reqs) == 0 + for req in scheduler.requests.values(): + assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS + + # No model execution yet + EMPTY_OUTPUT = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(output, EMPTY_OUTPUT) + + # Simulate KV transfer completion using KVConnectorOutput.finished_recving + output = scheduler.schedule() + assert len(scheduler.waiting) == len(req_ids) + assert len(scheduler.running) == 0 + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + kv_connector_output=KVConnectorOutput(finished_recving=req_ids), + ) + scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + for req_id in req_ids: + assert req_id in scheduler.finished_recving_kv_req_ids + + +@pytest.mark.parametrize("is_async", [False, True]) +def test_kv_connector_basic(is_async: bool): """ Test whether Scheduler with KVConnector schedules tokens, allocates memory, and cleans up requests as expected under normal operation. """ # Setup Scheduler. + BLOCK_SIZE = 16 + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler = create_scheduler( enable_prefix_caching=True, - use_kv_connector=True, + use_kv_connector=mock_kv( + matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async + ), + block_size=BLOCK_SIZE, ) NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks() - BLOCK_SIZE = scheduler.cache_config.block_size - - # Mock External Cache Hit. 
- NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 - scheduler.connector.get_num_new_matched_tokens = Mock(name="method") - scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS, - False, - ) ###################################################### # FIRST SET OF REQUESTS - External Hit Only @@ -928,10 +978,13 @@ def test_kv_connector_basic(): req_ids.append(request.request_id) req_to_index[request.request_id] = i + if is_async: + _step_until_kv_transfer_finished(scheduler, req_ids) + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -978,10 +1031,13 @@ def test_kv_connector_basic(): req_ids.append(request.request_id) req_to_index[request.request_id] = i + if is_async: + _step_until_kv_transfer_finished(scheduler, req_ids) + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1020,17 +1076,10 @@ def test_external_prefix_cache_metrics(): """ # Setup Scheduler. 
+ NUM_MATCHED_NEW_TOKENS = 4 scheduler = create_scheduler( enable_prefix_caching=False, - use_kv_connector=True, - ) - - # Mock connector to simulate a partial external cache hit - NUM_MATCHED_NEW_TOKENS = 4 - scheduler.connector.get_num_new_matched_tokens = Mock(name="method") - scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS, - False, + use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False), ) # --- Prepare simple requests --- @@ -1051,7 +1100,7 @@ def test_external_prefix_cache_metrics(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=[r.request_id for r in requests], req_id_to_index={r.request_id: i for i, r in enumerate(requests)}, - sampled_token_ids=[[1000]] * NUM_REQUESTS, + sampled_token_ids=[np.array([1000])] * NUM_REQUESTS, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1085,21 +1134,16 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role): # Setup Scheduler With Mock External Cache Hit. BLOCK_SIZE = 4 NUM_BLOCKS = 10 + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler = create_scheduler( enable_prefix_caching=True, - use_kv_connector=True, + use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False), block_size=BLOCK_SIZE, num_blocks=NUM_BLOCKS, # encoder connector should not affect test results use_ec_connector=use_ec_connector, ec_role=ec_role, ) - NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 - scheduler.connector.get_num_new_matched_tokens = Mock(name="method") - scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS, - False, - ) # Create two requests. The second request will not be able to # allocate slots because it will not have enough blocks. 
@@ -1122,7 +1166,7 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1174,9 +1218,10 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role): BLOCK_SIZE = 2 # NOTE: there is 1 null block, so this is 6 blocks. NUM_BLOCKS = 7 + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE scheduler = create_scheduler( enable_prefix_caching=True, - use_kv_connector=True, + use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False), block_size=BLOCK_SIZE, num_blocks=NUM_BLOCKS, # encoder connector should not affect test results @@ -1184,13 +1229,6 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role): ec_role=ec_role, ) - NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE - scheduler.connector.get_num_new_matched_tokens = Mock(name="method") - scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS, - False, - ) - # Create two requests. # Both can be scheduled at first, but the second request # will be preempted and re-scheduled. 
@@ -1213,7 +1251,7 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1306,7 +1344,7 @@ def make_output(scheduler: Scheduler): return ModelRunnerOutput( req_ids=[req.request_id for req in scheduler.running], req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)}, - sampled_token_ids=[[1000]] * len(scheduler.running), + sampled_token_ids=[np.array([1000])] * len(scheduler.running), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1723,7 +1761,7 @@ def test_priority_scheduling_preemption(): req_id_to_index={ req.request_id: i for i, req in enumerate(low_priority_requests) }, - sampled_token_ids=[[100] for _ in low_priority_requests], + sampled_token_ids=[np.array([100]) for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1792,7 +1830,7 @@ def test_priority_scheduling_no_preemption_when_space_available(): req_id_to_index={ req.request_id: i for i, req in enumerate(low_priority_requests) }, - sampled_token_ids=[[100] for _ in low_priority_requests], + sampled_token_ids=[np.array([100]) for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -2038,7 +2076,7 @@ def test_priority_scheduling_heap_property(): model_output = ModelRunnerOutput( req_ids=[req.req_id], req_id_to_index={req.req_id: 0}, - sampled_token_ids=[[100]], + sampled_token_ids=[np.array([100])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -2124,7 +2162,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv( model_output = ModelRunnerOutput( req_ids=[request_low.request_id], req_id_to_index={request_low.request_id: 0}, - sampled_token_ids=[[100]], + sampled_token_ids=[np.array([100])], # 
spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2155,7 +2193,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv( model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[100] for _ in requests], + sampled_token_ids=[np.array([100]) for _ in requests], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2181,7 +2219,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv( model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[], [100]], + sampled_token_ids=[np.array([]), np.array([100])], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2230,6 +2268,8 @@ def test_chunked_prefill_disabled_for_encoder_decoder( scheduler_config = SchedulerConfig( enable_chunked_prefill=enable_chunked_prefill, is_encoder_decoder=is_encoder_decoder, + # Must <= max_num_batched_tokens if chunked prefill is disabled + max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS, ) # `is_encoder_decoder` should only be used during construction @@ -2254,7 +2294,6 @@ def _validate_chunked_prefill_settings_for_encoder_decoder( ) -> None: """Validate chunked prefill settings in the scheduler config for encoder-decoder models.""" - assert scheduler_config.chunked_prefill_enabled is expect_enabled assert scheduler_config.enable_chunked_prefill is expect_enabled if is_encoder_decoder: # Encoder-decoder models should automatically disable chunked multimodal @@ -2597,7 +2636,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector): model_output = ModelRunnerOutput( req_ids=[request1.request_id], req_id_to_index={request1.request_id: 0}, - sampled_token_ids=[[100]], + sampled_token_ids=[np.array([100])], # spec_token_ids=None, logprobs=None, 
prompt_logprobs_dict={}, @@ -2803,7 +2842,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -2916,7 +2955,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption( model_output = ModelRunnerOutput( req_ids=[request_low.request_id], req_id_to_index={request_low.request_id: 0}, - sampled_token_ids=[[100]], + sampled_token_ids=[np.array([100])], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2967,7 +3006,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption( model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[100] for _ in requests], + sampled_token_ids=[np.array([100]) for _ in requests], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -3002,7 +3041,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption( model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[100], [100, 200]], + sampled_token_ids=[np.array([100]), np.array([100, 200])], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -3188,7 +3227,7 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connecto model_output = ModelRunnerOutput( req_ids=[request1.request_id, request2.request_id], req_id_to_index={request1.request_id: 0, request2.request_id: 1}, - sampled_token_ids=[[100], [121]], + sampled_token_ids=[np.array([100]), np.array([121])], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 
3692e633322e2..65511c17473b2 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -3,6 +3,7 @@ import torch +from tests.v1.kv_connector.unit.utils import MockKVConfig from vllm.config import ( CacheConfig, ECTransferConfig, @@ -33,6 +34,10 @@ from vllm.v1.structured_output import StructuredOutputManager EOS_TOKEN_ID = 50256 +def mock_kv(matched_tokens: int, is_async: bool): + return MockKVConfig(matched_tokens=matched_tokens, is_async=is_async) + + def create_scheduler( model: str = "facebook/opt-125m", max_num_seqs: int = 16, @@ -40,7 +45,7 @@ def create_scheduler( enable_prefix_caching: bool | None = None, long_prefill_token_threshold: int = 0, disable_chunked_mm_input: bool = False, - use_kv_connector: bool = False, + use_kv_connector: None | bool | MockKVConfig = None, num_blocks: int = 10000, block_size: int = 16, max_model_len: int | None = None, @@ -94,15 +99,22 @@ def create_scheduler( cache_dtype="auto", **kwargs_cache, ) - kv_transfer_config = ( - KVTransferConfig( + kv_transfer_config = None + if isinstance(use_kv_connector, MockKVConfig): + kv_transfer_config = KVTransferConfig( + kv_connector="MockKVConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "matched_tokens": use_kv_connector.matched_tokens, + "is_async": use_kv_connector.is_async, + }, + ) + elif use_kv_connector: + kv_transfer_config = KVTransferConfig( kv_connector="SharedStorageConnector", kv_role="kv_both", kv_connector_extra_config={"shared_storage_path": "local_storage"}, ) - if use_kv_connector - else None - ) speculative_config: SpeculativeConfig | None = None if num_speculative_tokens is not None: diff --git a/tests/v1/determinism/conftest.py b/tests/v1/determinism/conftest.py new file mode 100644 index 0000000000000..3c2136e005849 --- /dev/null +++ b/tests/v1/determinism/conftest.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + + 
+@pytest.fixture(autouse=True) +def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch): + """Automatically enable batch invariant kernel overrides for all tests.""" + monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1") + yield diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py similarity index 92% rename from tests/v1/generation/test_batch_invariance.py rename to tests/v1/determinism/test_batch_invariance.py index 8fd038bca5d0f..f018ee551dbfe 100644 --- a/tests/v1/generation/test_batch_invariance.py +++ b/tests/v1/determinism/test_batch_invariance.py @@ -6,66 +6,9 @@ import random import pytest import torch +from utils import _extract_step_logprobs, _random_prompt, skip_unsupported from vllm import LLM, SamplingParams -from vllm.platforms import current_platform - -skip_unsupported = pytest.mark.skipif( - not (current_platform.is_cuda() and current_platform.has_device_capability(90)), - reason="Requires CUDA and >= Hopper (SM90)", -) - - -@pytest.fixture(autouse=True) -def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch): - """Automatically enable batch invariant kernel overrides for all tests.""" - monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1") - yield - - -def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: - # Generate more realistic prompts that will actually produce varied tokens - # Use a mix of common English text patterns - - prompt_templates = [ - # Question-answer style - "Question: What is the capital of France?\nAnswer: The capital of France is", - "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which", - "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is", - # Story/narrative style - "Once upon a time in a distant galaxy, there lived", - "The old man walked slowly down the street, remembering", - "In the year 2157, humanity finally discovered", - # Technical/code style - "To implement a binary search tree in 
Python, first we need to", - "The algorithm works by iterating through the array and", - "Here's how to optimize database queries using indexing:", - # Factual/informative style - "The Renaissance was a period in European history that", - "Climate change is caused by several factors including", - "The human brain contains approximately 86 billion neurons which", - # Conversational style - "I've been thinking about getting a new laptop because", - "Yesterday I went to the store and bought", - "My favorite thing about summer is definitely", - ] - - # Pick a random template - base_prompt = random.choice(prompt_templates) - - if max_words < min_words: - max_words = min_words - target_words = random.randint(min_words, max_words) - - if target_words > 50: - # For longer prompts, repeat context - padding_text = ( - " This is an interesting topic that deserves more explanation. " - * (target_words // 50) - ) - base_prompt = base_prompt + padding_text - - return base_prompt @skip_unsupported @@ -204,22 +147,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( llm_bsN.shutdown() -def _extract_step_logprobs(request_output): - if getattr(request_output, "outputs", None): - inner = request_output.outputs[0] - if hasattr(inner, "logprobs") and inner.logprobs is not None: - t = torch.tensor( - [ - inner.logprobs[i][tid].logprob - for i, tid in enumerate(inner.token_ids) - ], - dtype=torch.float32, - ) - return t, inner.token_ids - - return None, None - - @skip_unsupported @pytest.mark.parametrize( "backend", diff --git a/tests/v1/determinism/test_online_batch_invariance.py b/tests/v1/determinism/test_online_batch_invariance.py new file mode 100644 index 0000000000000..23f47863dd23f --- /dev/null +++ b/tests/v1/determinism/test_online_batch_invariance.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +HTTP-based batch invariance test: send requests to a running +vLLM server 
and compare BS=1 vs BS=N results (tokens and per-step logprobs). + +Environment variables: + - VLLM_TEST_MODEL: served model name (e.g., Qwen/Qwen3-1.7B / DeepSeek-R1) + - VLLM_TP_SIZE: tensor parallelism size (e.g., 4) + +""" + +import os +import random +import sys +from typing import Any + +import openai +from utils import _random_prompt, skip_unsupported + +from tests.utils import RemoteOpenAIServer + + +def _request_completion( + client: openai.OpenAI, + model: str, + prompt: Any, + sp: dict[str, Any], + max_retries: int = 3, + retry_backoff: float = 0.5, +) -> dict[str, Any] | None: + payload: dict[str, Any] = {"model": model, "prompt": prompt} + payload.update(sp) + + for attempt in range(max_retries + 1): + try: + completion = client.completions.create(**payload) + # Convert to plain dict so downstream logic can keep using + # dict-style access just like with raw HTTP JSON. + return completion.model_dump() + except Exception as e: # pragma: no cover + if attempt < max_retries: + import time as _t + + _t.sleep(retry_backoff * (2**attempt)) + continue + sys.stderr.write(f"Error: {e}\n") + return None + return None + + +def _extract_tokens_and_logprobs( + choice: dict[str, Any], +) -> tuple[list[Any], list[float] | None]: + tokens: list[Any] = [] + token_logprobs: list[float] | None = None + lp = choice.get("logprobs") + if lp and isinstance(lp, dict): + tokens = lp.get("token_ids") or lp.get("tokens") or [] + token_logprobs = lp.get("token_logprobs", None) + return tokens, token_logprobs + + +def _compare_bs1_vs_bsn_single_process( + prompts: list[str], + sp_kwargs: dict[str, Any], + client: openai.OpenAI, + model_name: str, +) -> None: + # BS=1 + bs1_tokens_per_prompt: list[list[Any]] = [] + bs1_logprobs_per_prompt: list[list[float] | None] = [] + for p in prompts: + resp = _request_completion(client, model_name, p, sp_kwargs) + if resp is None or not resp.get("choices"): + raise AssertionError("BS=1 empty/failed response") + choice = resp["choices"][0] + 
toks, lps = _extract_tokens_and_logprobs(choice) + if lps is None: + raise AssertionError( + "logprobs not returned; ensure server supports 'logprobs'" + ) + bs1_tokens_per_prompt.append(list(toks)) + bs1_logprobs_per_prompt.append(list(lps)) + + # BS=N + bsN_tokens_per_prompt: list[list[Any]] = [None] * len(prompts) # type: ignore[list-item] + bsN_logprobs_per_prompt: list[list[float] | None] = [None] * len(prompts) + resp = _request_completion(client, model_name, prompts, sp_kwargs) + if resp is None or not resp.get("choices"): + raise AssertionError("BS=N empty/failed batched response") + choices = resp.get("choices", []) + if len(choices) != len(prompts): + raise AssertionError( + f"BS=N choices length {len(choices)} != num prompts {len(prompts)}" + ) + for idx, choice in enumerate(choices): + toks, lps = _extract_tokens_and_logprobs(choice) + if lps is None: + raise AssertionError(f"BS=N missing logprobs for prompt {idx}") + bsN_tokens_per_prompt[idx] = list(toks) + bsN_logprobs_per_prompt[idx] = list(lps) + + # compare + for i, (tokens_bs1, tokens_bsN, logprobs_bs1, logprobs_bsN) in enumerate( + zip( + bs1_tokens_per_prompt, + bsN_tokens_per_prompt, + bs1_logprobs_per_prompt, + bsN_logprobs_per_prompt, + ) + ): + if tokens_bs1 != tokens_bsN: + raise AssertionError( + f"Prompt {i} (sampling): Different tokens sampled. " + f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}" + ) + if logprobs_bs1 is None or logprobs_bsN is None: + raise AssertionError(f"Prompt {i}: Missing logprobs in one of the runs") + if len(logprobs_bs1) != len(logprobs_bsN): + raise AssertionError( + f"Prompt {i}: Different number of steps: " + f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)." + ) + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)): + if a != b: + diff = abs(a - b) + raise AssertionError( + f"Prompt {i} Step {t}: Bitwise mismatch " + f"(abs diff={diff:.6e}). 
" + f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}" + ) + + +@skip_unsupported +def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(): + random.seed(int(os.getenv("VLLM_TEST_SEED", "12345"))) + model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + prompts_all = [_random_prompt(10, 50) for _ in range(32)] + + sp_kwargs: dict[str, Any] = { + "temperature": 0.6, + "top_p": 1.0, + "max_tokens": 8, + "seed": 42, + "logprobs": 5, + } + + tp_size = os.getenv("VLLM_TP_SIZE", "1") + server_args: list[str] = [] + if tp_size: + server_args += ["-tp", tp_size] + + with RemoteOpenAIServer(model_name, server_args) as server: + client = server.get_client() + _compare_bs1_vs_bsn_single_process( + prompts=prompts_all, + sp_kwargs=sp_kwargs, + client=client, + model_name=model_name, + ) diff --git a/tests/v1/generation/test_rms_norm_batch_invariant.py b/tests/v1/determinism/test_rms_norm_batch_invariant.py similarity index 97% rename from tests/v1/generation/test_rms_norm_batch_invariant.py rename to tests/v1/determinism/test_rms_norm_batch_invariant.py index f79eba58d6ef2..390872519528c 100644 --- a/tests/v1/generation/test_rms_norm_batch_invariant.py +++ b/tests/v1/determinism/test_rms_norm_batch_invariant.py @@ -9,15 +9,10 @@ with the standard CUDA-based implementation to ensure numerical accuracy. 
import pytest import torch +from utils import skip_unsupported from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.platforms import current_platform - -skip_unsupported = pytest.mark.skipif( - not (current_platform.is_cuda() and current_platform.has_device_capability(90)), - reason="Requires CUDA and >= Hopper (SM90)", -) @skip_unsupported diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py new file mode 100644 index 0000000000000..5141837faea04 --- /dev/null +++ b/tests/v1/determinism/utils.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random + +import pytest +import torch + +from vllm.platforms import current_platform + +skip_unsupported = pytest.mark.skipif( + not (current_platform.is_cuda() and current_platform.has_device_capability(90)), + reason="Requires CUDA and >= Hopper (SM90)", +) + + +def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: + # Generate more realistic prompts that will actually produce varied tokens + # Use a mix of common English text patterns + + prompt_templates = [ + # Question-answer style + "Question: What is the capital of France?\nAnswer: The capital of France is", + "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which", + "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is", + # Story/narrative style + "Once upon a time in a distant galaxy, there lived", + "The old man walked slowly down the street, remembering", + "In the year 2157, humanity finally discovered", + # Technical/code style + "To implement a binary search tree in Python, first we need to", + "The algorithm works by iterating through the array and", + "Here's how to optimize database queries using indexing:", + # Factual/informative style + "The Renaissance was a period in European history 
that", + "Climate change is caused by several factors including", + "The human brain contains approximately 86 billion neurons which", + # Conversational style + "I've been thinking about getting a new laptop because", + "Yesterday I went to the store and bought", + "My favorite thing about summer is definitely", + ] + + # Pick a random template + base_prompt = random.choice(prompt_templates) + + if max_words < min_words: + max_words = min_words + target_words = random.randint(min_words, max_words) + + if target_words > 50: + # For longer prompts, repeat context + padding_text = ( + " This is an interesting topic that deserves more explanation. " + * (target_words // 50) + ) + base_prompt = base_prompt + padding_text + + return base_prompt + + +def _extract_step_logprobs(request_output): + if getattr(request_output, "outputs", None): + inner = request_output.outputs[0] + if hasattr(inner, "logprobs") and inner.logprobs is not None: + t = torch.tensor( + [ + inner.logprobs[i][tid].logprob + for i, tid in enumerate(inner.token_ids) + ], + dtype=torch.float32, + ) + return t, inner.token_ids + + return None, None diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 444afd5196dd8..f732b05f09f9d 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from itertools import repeat from typing import Any import pytest @@ -8,126 +9,275 @@ import torch._dynamo.config as dynamo_config from vllm import SamplingParams from vllm.logprobs import Logprob from vllm.sampling_params import StructuredOutputsParams +from vllm.v1.metrics.reader import Metric from ...conftest import VllmRunner from ...models.utils import check_outputs_equal MODEL = "Qwen/Qwen3-0.6B" +MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct" -@dynamo_config.patch(cache_size_limit=16) -def 
test_preempt_and_async_scheduling_e2e( - sample_json_schema, monkeypatch: pytest.MonkeyPatch +first_prompt = ( + "The following numbers of the sequence " + + ", ".join(str(i) for i in range(10)) + + " are:" +) +example_prompts = [first_prompt, "In one word, the capital of France is "] + [ + f"Tell me about the number {i}: " for i in range(32) +] + +default_params = dict( + temperature=0.0, # greedy + max_tokens=23, + min_tokens=18, +) + + +def test_without_spec_decoding( + sample_json_schema, + monkeypatch: pytest.MonkeyPatch, ): """Test consistency of combos of async scheduling, preemption, - uni/multiproc executor, and various sampling parameters - including structured outputs.""" - - first_prompt = ( - "The following numbers of the sequence " - + ", ".join(str(i) for i in range(10)) - + " are:" - ) - example_prompts = [first_prompt, "In one word, the capital of France is "] + [ - f"Tell me about the number {i}: " for i in range(32) - ] - - sampling_param_tests: list[dict[str, Any]] = [ + uni/multiproc executor, prefill chunking.""" + struct_outputs = StructuredOutputsParams(json=sample_json_schema) + test_sampling_params: list[dict[str, Any]] = [ dict(), # dict(min_tokens=20), dict(presence_penalty=-1.0), dict(bad_words=["the", " the"]), dict(logprobs=2), dict(logprobs=2, presence_penalty=-1.0), - dict(structured_outputs=StructuredOutputsParams(json=sample_json_schema)), + dict(structured_outputs=struct_outputs), dict( - structured_outputs=StructuredOutputsParams(json=sample_json_schema), + structured_outputs=struct_outputs, logprobs=2, presence_penalty=-1.0, ), ] - default_params = dict( - temperature=0.0, # greedy - max_tokens=20, - ) + # test_preemption, executor, async_scheduling, + # spec_config, test_prefill_chunking + test_configs = [ + (False, "mp", False, None, False), + (True, "mp", False, None, True), + (False, "mp", True, None, False), + (False, "uni", True, None, False), + (True, "mp", True, None, False), + (True, "uni", True, None, False), + 
(False, "mp", True, None, True), + (True, "mp", True, None, True), + (True, "uni", True, None, True), + ] + + run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) + + +def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): + """Test consistency and acceptance rates with some different combos of + preemption, executor, async scheduling, prefill chunking, + spec decoding model length. + """ + + spec_config = { + "method": "eagle3", + "num_speculative_tokens": 2, + "model": "nm-testing/Llama3_2_1B_speculator.eagle3", + } + spec_config_short = spec_config | {"max_model_len": 50} + + # test_preemption, executor, async_scheduling, + # spec_config, test_prefill_chunking + test_configs = [ + (False, "mp", False, None, False), + (False, "mp", False, spec_config, False), + (True, "mp", False, spec_config, True), + (True, "uni", False, spec_config_short, True), + (False, "mp", True, spec_config, False), + (True, "mp", True, spec_config, False), + (False, "mp", True, spec_config_short, True), + (True, "uni", True, spec_config, False), + (True, "uni", True, spec_config_short, False), + (True, "mp", True, spec_config, True), + (True, "uni", True, spec_config_short, True), + ] + + run_tests(monkeypatch, MTP_MODEL, test_configs, [{}]) + + +@dynamo_config.patch(cache_size_limit=16) +def run_tests( + monkeypatch: pytest.MonkeyPatch, + model: str, + test_configs: list[tuple], + test_sampling_params: list[dict[str, Any]], +): + """Test consistency of combos of async scheduling, preemption, + uni/multiproc executor with spec decoding.""" with monkeypatch.context() as m: + # avoid precision errors m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") # m.setenv("VLLM_BATCH_INVARIANT", "1") - - outputs: list[tuple[str, list]] = [] - for test_preemption in [False, True]: - for executor in ["mp", "uni"]: - for async_scheduling in [False, True]: - cache_arg: dict[str, Any] = ( - dict(num_gpu_blocks_override=32) - if test_preemption - else dict(gpu_memory_utilization=0.7) 
- ) - test_config = ( - f"executor={executor}, preemption={test_preemption}," - f" async_sched={async_scheduling}" - ) - print("-" * 80) - print(f"---- TESTING: {test_config}") - print("-" * 80) - with VllmRunner( - MODEL, - max_model_len=512, - enforce_eager=True, - async_scheduling=async_scheduling, - distributed_executor_backend=executor, - dtype="float32", # avoid precision errors - **cache_arg, - ) as vllm_model: - results = [] - for override_params in sampling_param_tests: - print(f"----------- RUNNING PARAMS: {override_params}") - results.append( - vllm_model.generate( - example_prompts, - sampling_params=SamplingParams( - **default_params, **override_params - ), - return_logprobs=True, - ) - ) - - if not outputs: - # First check that the different parameter configs - # actually result in different output. - for (other_test_outs, other_test_logprobs), params in zip( - results[1:], sampling_param_tests[1:] - ): - with pytest.raises(AssertionError): - check_outputs_equal( - outputs_0_lst=results[0][0], - outputs_1_lst=other_test_outs, - name_0=f"baseline params={params}", - name_1=f"other params={params}", - ) - assert _all_logprobs_match( - results[0][1], other_test_logprobs - ) - - outputs.append((test_config, results)) - - baseline_config, baseline_tests = outputs[0] - - for test_config, test_outputs in outputs[1:]: - for (base_outs, base_logprobs), (test_outs, test_logprobs), params in zip( - baseline_tests, test_outputs, sampling_param_tests - ): - check_outputs_equal( - outputs_0_lst=base_outs, - outputs_1_lst=test_outs, - name_0=f"baseline=[{baseline_config}], params={params}", - name_1=f"config=[{test_config}], params={params}", + outputs: list[tuple[str, list, list]] = [] + for n, ( + test_preemption, + executor, + async_scheduling, + spec_config, + test_prefill_chunking, + ) in enumerate(test_configs, 1): + test_str = f"{n}/{len(test_configs)}" + test_results = run_test( + model, + test_str, + test_sampling_params, + test_preemption, + executor, + 
async_scheduling, + spec_config, + test_prefill_chunking=test_prefill_chunking, ) - assert _all_logprobs_match(base_logprobs, test_logprobs) + outputs.append(test_results) - print(f"PASSED: config=[{test_config}], params={params}") + baseline_config, baseline_tests, _ = outputs[0] + _, _, baseline_acceptances = next( + (o for o in outputs if o[2] is not None), (None, None, None) + ) + + print(f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}") + + failure = None + for test_config, test_outputs, test_acceptance_rates in outputs[1:]: + for (base_outs, base_logprobs), base_acceptance_rate, ( + test_outs, + test_logprobs, + ), test_acceptance_rate, params in zip( + baseline_tests, + baseline_acceptances or repeat(None), + test_outputs, + test_acceptance_rates or repeat(None), + test_sampling_params, + ): + try: + check_outputs_equal( + outputs_0_lst=base_outs, + outputs_1_lst=test_outs, + name_0=f"baseline=[{baseline_config}], params={params}", + name_1=f"config=[{test_config}], params={params}", + ) + assert _all_logprobs_match(base_logprobs, test_logprobs) + + if ( + base_acceptance_rate is not None + and test_acceptance_rate is not None + ): + if "spec_mml=None" in test_config: + assert ( + pytest.approx(test_acceptance_rate, rel=5e-2) + == base_acceptance_rate + ) + else: + # Currently the reported acceptance rate is expected to be + # lower when we sometimes skip drafting altogether. 
+ assert test_acceptance_rate > 0.05 + print( + f"PASSED: config=[{test_config}], params={params}" + f" accept_rate={test_acceptance_rate}" + ) + except AssertionError as e: + print( + f"FAILED: config=[{test_config}], params={params}" + f" accept_rate={test_acceptance_rate}" + ) + if failure is None: + failure = e + + if failure is not None: + raise failure + + +def run_test( + model: str, + test_str: str, + sampling_param_tests: list[dict[str, Any]], + test_preemption: bool, + executor: str, + async_scheduling: bool, + spec_config: dict[str, Any] | None, + test_prefill_chunking: bool, +): + spec_decoding = spec_config is not None + cache_arg: dict[str, Any] = ( + # Force preemptions + dict(num_gpu_blocks_override=32) + if test_preemption + else dict(gpu_memory_utilization=0.9) + ) + spec_mml = (spec_config or {}).get("max_model_len") + test_config = ( + f"executor={executor}, preemption={test_preemption}, " + f"async_sched={async_scheduling}, " + f"chunk_prefill={test_prefill_chunking}, " + f"spec_decoding={spec_decoding}, spec_mml={spec_mml}" + ) + print("-" * 80) + print(f"---- TESTING {test_str}: {test_config}") + print("-" * 80) + with VllmRunner( + model, + max_model_len=512, + enable_chunked_prefill=test_prefill_chunking, + # Force prefill chunking + max_num_batched_tokens=48 if test_prefill_chunking else None, + # enforce_eager=True, + async_scheduling=async_scheduling, + distributed_executor_backend=executor, + dtype="float32", # avoid precision errors + speculative_config=spec_config, + disable_log_stats=False, + **cache_arg, + ) as vllm_model: + results = [] + acceptance_rates: list[float] | None = [] if spec_decoding else None + for override_params in sampling_param_tests: + metrics_before = vllm_model.llm.get_metrics() + print(f"----------- RUNNING PARAMS: {override_params}") + results.append( + vllm_model.generate( + example_prompts, + sampling_params=SamplingParams(**default_params, **override_params), + return_logprobs=True, + ) + ) + metrics_after 
= vllm_model.llm.get_metrics() + if acceptance_rates is not None: + acceptance_rate = _get_acceptance_rate(metrics_before, metrics_after) + acceptance_rates.append(acceptance_rate) + print(f"ACCEPTANCE RATE {acceptance_rate}") + + if test_preemption: + preemptions = _get_count( + metrics_before, metrics_after, "vllm:num_preemptions" + ) + assert preemptions > 0, "preemption test had no preemptions" + + if len(results) > 1: + # First check that the different parameter configs + # actually result in different output. + for (other_test_outs, other_test_logprobs), params in zip( + results[1:], sampling_param_tests[1:] + ): + with pytest.raises(AssertionError): + check_outputs_equal( + outputs_0_lst=results[0][0], + outputs_1_lst=other_test_outs, + name_0=f"baseline params={params}", + name_1=f"other params={params}", + ) + assert _all_logprobs_match(results[0][1], other_test_logprobs) + + return test_config, results, acceptance_rates def _all_logprobs_match(req_a, req_b) -> bool: @@ -149,3 +299,15 @@ def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> boo and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6) for a, b in ((lps_a[x], lps_b[x]) for x in lps_a) ) + + +def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float: + draft = _get_count(before, after, "vllm:spec_decode_num_draft_tokens") + accept = _get_count(before, after, "vllm:spec_decode_num_accepted_tokens") + return accept / draft if draft > 0 else 0.0 + + +def _get_count(before: list[Metric], after: list[Metric], name: str) -> int: + before_val = next(m.value for m in before if m.name == name) + after_val = next(m.value for m in after if m.name == name) + return after_val - before_val diff --git a/tests/v1/e2e/test_context_length.py b/tests/v1/e2e/test_context_length.py new file mode 100644 index 0000000000000..0ac40bec35fe2 --- /dev/null +++ b/tests/v1/e2e/test_context_length.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for vLLM `vllm/v1/engine/processor.Processor._validate_model_input()` +handling of maximum context length for decoder models. + +This test ensures: +- A prompt that is one token shorter than the model's maximum context length + can be processed successfully when requesting one additional token. +- A prompt that reaches the model's maximum context length throws a + `ValueError` when requesting at least one additional token. +""" + +import pytest + +from tests.conftest import VllmRunner +from tests.utils import create_new_process_for_each_test + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("model, max_model_len", [("JackFram/llama-160m", 2048)]) +@pytest.mark.parametrize( + "prompt_len, max_tokens", + [ + (2047, 1), # prompt_len = max_model_len - 1 -> allowed + (2048, 1), # prompt_len = max_model_len -> not allowed + ], +) +def test_decoder_max_context_length_validation( + model: str, + max_model_len: int, + vllm_runner: type[VllmRunner], + prompt_len: int, + max_tokens: int, +) -> None: + """Check vLLM decoder model input validation for edge cases where + the prompt length is (almost) equal to the max model length.""" + + prompt_ids = [[43] * prompt_len] + + with vllm_runner( + model_name=model, + tokenizer_name=model, + max_model_len=max_model_len, + max_num_seqs=1, + tensor_parallel_size=1, + ) as vllm_model: + if prompt_len + max_tokens <= max_model_len: + # Should succeed as constraints are met + vllm_model.generate_greedy(prompt_ids, max_tokens) + else: + # Should raise the ValueError defined in + # vllm/v1/engine/processor.Processor_validate_model_input() + expected_msg = ( + f"The decoder prompt (length {prompt_len}) plus the number of " + f"requested output tokens (at least 1) is longer than " + f"the maximum model length of {max_model_len}. 
" + "Make sure that `max_model_len` is no smaller than the number of " + "text tokens (prompt + requested output tokens)." + ) + with pytest.raises(ValueError) as excinfo: + vllm_model.generate_greedy(prompt_ids, max_tokens) + assert expected_msg in str(excinfo.value) diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index f2c6d1c1fd1a4..2778b0c5e5670 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -4,13 +4,11 @@ import random import pytest -import torch from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationMode -from vllm.distributed import cleanup_dist_env_and_memory -from ...utils import fork_new_process_for_each_test +from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts # global seed SEED = 42 @@ -45,28 +43,12 @@ def test_prompts(): return prompts -def cleanup(llm: LLM, compilation_config: CompilationConfig): - # hacky: below lines are required to free up memory for the next test - # when setting VLLM_ENABLE_V1_MULTIPROCESSING=0, del llm is not sufficient - # TODO(sarckk): when enforce_eager=False, memory is not freed: - # find out why and re-enable test for enforce_eager=False case - llm_engine = llm.llm_engine.engine_core.engine_core - model_runner = llm_engine.model_executor.driver_worker.worker.model_runner - del model_runner.model - del model_runner.kv_caches - del compilation_config.static_forward_context - compilation_config.static_forward_context = {} - - del llm - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() - - @fork_new_process_for_each_test -@pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill") +@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True]) +@pytest.mark.parametrize("enforce_eager", [True, False]) def test_kv_sharing_fast_prefill( monkeypatch: pytest.MonkeyPatch, + 
kv_sharing_fast_prefill: bool, enforce_eager: bool, test_prompts: list[str], ): @@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill( if not enforce_eager else CompilationMode.NONE, ) + batch_size = 10 with monkeypatch.context() as m: # Make scheduling deterministic for reproducibility m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - llm = LLM( - model="google/gemma-3n-E2B-it", - enforce_eager=enforce_eager, - compilation_config=compilation_config, - seed=SEED, - ) - ref_responses = llm.generate(test_prompts, sampling_params) - - cleanup(llm, compilation_config) + prompts, answer, indices = prep_prompts(batch_size) llm = LLM( model="google/gemma-3n-E2B-it", enforce_eager=enforce_eager, compilation_config=compilation_config, seed=SEED, - kv_sharing_fast_prefill=True, + kv_sharing_fast_prefill=kv_sharing_fast_prefill, + ) + responses = llm.generate(prompts, sampling_params) + check_answers( + indices, + answer, + [response.outputs[0].text for response in responses], + accept_rate=1.0, ) - optimized_responses = llm.generate(test_prompts, sampling_params) - - cleanup(llm, compilation_config) - - misses = 0 - - for ref_response, optimized_response in zip(ref_responses, optimized_responses): - if ref_response.outputs[0].text != optimized_response.outputs[0].text: - misses += 1 - - assert misses == 0 diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 4a6b84ae4817c..03396270a31cb 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -75,6 +75,14 @@ def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" +@pytest.fixture(autouse=True) +def reset_torch_dynamo(): + """Reset torch dynamo cache before each test""" + yield + # Cleanup after test + torch._dynamo.reset() + + @pytest.mark.parametrize( "speculative_config", [ @@ -272,7 +280,7 @@ def test_speculators_model_integration( @pytest.mark.parametrize( - ["model_setup", "mm_enabled", "chunked_prefill_enabled"], + ["model_setup", "mm_enabled", 
"enable_chunked_prefill"], [ (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False), pytest.param( @@ -358,7 +366,7 @@ def test_eagle_correctness( sampling_config: SamplingParams, model_setup: tuple[str, str, str, int], mm_enabled: bool, - chunked_prefill_enabled: bool, + enable_chunked_prefill: bool, attn_backend: str, ): if attn_backend == "TREE_ATTN": @@ -396,9 +404,7 @@ def test_eagle_correctness( method, model_name, spec_model_name, tp_size = model_setup max_model_len = 2048 - max_num_batched_tokens = max_model_len - if chunked_prefill_enabled: - max_num_batched_tokens = 128 + max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len ref_llm = LLM( model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size @@ -420,7 +426,7 @@ def test_eagle_correctness( }, max_model_len=max_model_len, max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=chunked_prefill_enabled, + enable_chunked_prefill=enable_chunked_prefill, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 4e852dca95eb0..3ba8ab26f5522 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -571,7 +571,7 @@ def test_encoder_instance_zero_kv_cache( ) # Check 5: Verify chunked prefill is disabled - assert not vllm_config.scheduler_config.chunked_prefill_enabled, ( + assert not vllm_config.scheduler_config.enable_chunked_prefill, ( "Encoder instance should disable chunked prefill (no KV cache)" ) diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index d77a119ec60f8..8e1198b315bd1 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -49,10 +49,15 @@ def _ref_convert_id_to_token( @pytest.mark.parametrize( "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] ) 
+@pytest.mark.parametrize("stream_interval", [1, 5, 10]) def test_incremental_detokenization( - request_output_kind: RequestOutputKind, dummy_test_vectors + request_output_kind: RequestOutputKind, + stream_interval: int, + dummy_test_vectors, ): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) + output_processor = OutputProcessor( + dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval + ) engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens) # Make N requests. @@ -104,9 +109,18 @@ def test_incremental_detokenization( if request_id not in gen_strings: gen_strings[request_id] = new_text gen_tokens[request_id] = new_tokens + if request_output_kind == RequestOutputKind.DELTA: + assert len(new_tokens) == 1, f"{len(new_tokens)=}" else: gen_strings[request_id] += new_text gen_tokens[request_id].extend(new_tokens) + if ( + request_output_kind == RequestOutputKind.DELTA + and not request_output.finished + ): + assert len(new_tokens) >= stream_interval, ( + f"{len(new_tokens)=}, {stream_interval=}" + ) # Confirmed tracked values matches what we expected. 
for idx, (ref_gen_str, ref_gen_toks) in enumerate( diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 4cd26e7b41d3a..a7d769c8542a9 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -677,9 +677,14 @@ def test_structured_output_with_reasoning_matrices( reasoning, content = run_reasoning_extraction(reasoner, [generated_text]) print(f"Prompt: {prompt!r}\nReasoning: {reasoning!r}\nContent: {content!r}") - assert content is not None and reasoning is not None - output_json = json.loads(content) - jsonschema.validate(instance=output_json, schema=reasoning_schema) + if "Qwen3" in model_name: + assert content is not None + + assert reasoning is not None + + if content is not None: + output_json = json.loads(content) + jsonschema.validate(instance=output_json, schema=reasoning_schema) @pytest.mark.skip_global_cleanup diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index a9817313cf022..ebc8575e5b390 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -49,6 +49,8 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1 PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1} DECODER_TP_SIZE=${DECODER_TP_SIZE:-1} GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2} +PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-16} +DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-16} # Find the git repository root directory GIT_ROOT=$(git rev-parse --show-toplevel) @@ -136,6 +138,7 @@ run_tests_for_model() { vllm serve $model_name \ --port $PORT \ --enforce-eager \ + --block-size ${PREFILL_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $PREFILLER_TP_SIZE \ --kv-transfer-config '$KV_CONFIG'" @@ -177,6 +180,7 @@ run_tests_for_model() { 
vllm serve $model_name \ --port $PORT \ --enforce-eager \ + --block-size ${DECODE_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --kv-transfer-config '$KV_CONFIG'" diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 8e421717fea30..b264e5108c16d 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -11,6 +11,7 @@ import uuid from collections import defaultdict from unittest.mock import patch +import numpy as np import pytest import ray import torch @@ -407,6 +408,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): # `self.kv_cache_layout` is only forced to HND when vllm engine # is started. We mock HND here. kv_cache_layout="HND", + block_size=self.block_size, ), remote_tp_size=remote_tp_size, ) @@ -652,6 +654,7 @@ class TestNixlHandshake: block_lens=worker.block_len_per_layer, attn_backend_name=worker.backend_name, kv_cache_layout=mismatched_layout, + block_size=worker.block_size, ) with pytest.raises(RuntimeError): @@ -706,6 +709,7 @@ class TestNixlHandshake: block_lens=[i * 2 for i in worker.block_len_per_layer], attn_backend_name=worker.backend_name, kv_cache_layout="HND", + block_size=worker.block_size, ) # We don't check layout for homogeneous TP and MLA for now, as the @@ -823,7 +827,7 @@ def test_kv_connector_stats_aggregation(): output = ModelRunnerOutput( req_ids=[f"req_{i}"], req_id_to_index={f"req_{i}": 0}, - sampled_token_ids=[[123]], # dummy token + sampled_token_ids=[np.array([123])], # dummy token logprobs=None, prompt_logprobs_dict={}, pooler_output=[None], @@ -904,7 +908,7 @@ def test_multi_kv_connector_stats_aggregation(): output = ModelRunnerOutput( req_ids=[f"req_{i}"], req_id_to_index={f"req_{i}": 0}, - sampled_token_ids=[[123]], + sampled_token_ids=[np.array([123])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[None], @@ -962,7 +966,7 @@ def 
test_scheduler_kv_connector_stats_aggregation(): model_output = ModelRunnerOutput( req_ids=["req_0"], req_id_to_index={"req_0": 0}, - sampled_token_ids=[[123]], + sampled_token_ids=[np.array([123])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[None], diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index f0031643aa9d4..c248104d5b5ea 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -3,9 +3,11 @@ import tempfile from collections import defaultdict from collections.abc import Callable -from itertools import count +from dataclasses import dataclass +from itertools import chain, count from typing import Any +import numpy as np import torch from vllm import SamplingParams @@ -18,13 +20,18 @@ from vllm.config import ( VllmConfig, ) from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa SharedStorageConnector, ) from vllm.utils.hashing import sha256 from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash -from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheConfig, @@ -222,7 +229,7 @@ def create_model_runner_output( # Make sampled tokens. 
sampled_token = EOS_TOKEN_ID if use_eos else token_id - sampled_token_ids = [[sampled_token] for _ in req_ids] + sampled_token_ids = [np.array([sampled_token]) for _ in req_ids] kv_connector_output = ( None @@ -307,6 +314,82 @@ class TestSharedStorageConnector(SharedStorageConnector): return attr +@dataclass(frozen=True) +class MockKVConfig: + matched_tokens: int = 0 + is_async: bool = False + + +class MockKVConnectorMetadata(KVConnectorMetadata): + def __init__(self): + # Scheduler tests check metadata.requests + self.requests: list = [] + + +class MockKVConnector(KVConnectorBase_V1): + """Mock KV connector for scheduler tests, supporting both sync and async mode.""" + + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: KVCacheConfig | None = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + extra_config = self._kv_transfer_config.kv_connector_extra_config + self.config = MockKVConfig( + matched_tokens=extra_config["matched_tokens"], + is_async=extra_config["is_async"], + ) + + def get_num_new_matched_tokens( + self, + request: Request, + num_computed_tokens: int, + ) -> tuple[int | None, bool]: + return (self.config.matched_tokens, self.config.is_async) + + def update_state_after_alloc( + self, + request: Request, + blocks: KVCacheBlocks, + num_external_tokens: int, + ): + pass + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + metadata = MockKVConnectorMetadata() + cached_reqs = scheduler_output.scheduled_cached_reqs + for req_id in chain( + (req.req_id for req in scheduler_output.scheduled_new_reqs), + ( + req_id + for req_id in cached_reqs.req_ids + if req_id in cached_reqs.resumed_req_ids + ), + ): + metadata.requests.append({"req_id": req_id}) + return metadata + + def start_load_kv(self, kv_caches, finished_req_ids): + pass + + def wait_for_layer_load(self, layer_name): + pass + + def save_kv_layer(self, layer_name, kv_layer, attn_metadata, 
**kwargs): + pass + + def wait_for_save(self): + pass + + KVConnectorFactory.register_connector( "TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__ ) + +KVConnectorFactory.register_connector( + "MockKVConnector", __name__, MockKVConnector.__name__ +) diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 354fff22dc2ac..42584938bc06f 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -47,6 +47,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]: max_num_batched_tokens=16, max_num_seqs=16, max_model_len=128, + enable_chunked_prefill=True, enforce_eager=True, # TODO: enable this once we support it for # prompt logprobs. diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 89d0ec769ac09..805b8c86b0804 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -3,6 +3,7 @@ from unittest import mock +import numpy as np import pytest import torch @@ -112,7 +113,9 @@ def test_prepare_next_token_ids(): sampled_token_ids_tensor = torch.tensor( sampled_token_ids, dtype=torch.int32, device=device ) - sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids] + sampled_token_ids_cpu = [ + np.array([i for i in seq if i != -1]) for seq in sampled_token_ids + ] expected_next_token_ids_cpu = [1, 4, 30, 40] expected_next_token_ids_tensor = torch.tensor( @@ -321,6 +324,7 @@ def test_prepare_inputs_padded(): @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("pp_size", [1, 2]) @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False]) +@pytest.mark.parametrize("use_distinct_lm_head", [True, False]) @mock.patch("vllm.v1.spec_decode.eagle.get_pp_group") @mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config") @mock.patch("vllm.v1.spec_decode.eagle.get_model") @@ -332,6 +336,7 @@ def 
test_load_model( attn_backend, pp_size, use_distinct_embed_tokens, + use_distinct_lm_head, monkeypatch, ): monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) @@ -347,12 +352,13 @@ def test_load_model( # Setup draft model mock mock_model = mock.MagicMock() + mock_model.model = mock.MagicMock() + mock_model.has_own_embed_tokens = use_distinct_embed_tokens if use_distinct_embed_tokens: - # Some models can have a different hidden size than the target model, - # so we test that their embed_tokens doesn't get overwritten - mock_model.model.embed_tokens.weight.shape = (131072, 2048) - else: - mock_model.model.embed_tokens.weight.shape = (131072, 4096) + mock_model.model.embed_tokens = mock.MagicMock() + mock_model.has_own_lm_head = use_distinct_lm_head + if use_distinct_lm_head: + mock_model.lm_head = mock.MagicMock() mock_get_model.return_value = mock_model @@ -388,15 +394,13 @@ def test_load_model( target_model = mock.create_autospec(_TargetModelStub, instance=True) target_model.model = mock.MagicMock() - target_model.model.embed_tokens.weight.shape = (131072, 4096) + target_model.lm_head = mock.MagicMock() + target_model.model.embed_tokens = mock.MagicMock() from vllm.model_executor.models import SupportsMultiModal assert not isinstance(target_model, SupportsMultiModal) - if method == "eagle": - target_model.lm_head = mock.MagicMock() - # Create proposer using the helper function proposer = _create_proposer(method, num_speculative_tokens=8) @@ -406,18 +410,18 @@ def test_load_model( # Verify common interactions mock_get_model.assert_called_once() - # Verify that EAGLE models gain the lm head from the target model - if method == "eagle": - assert proposer.model.lm_head == target_model.lm_head + # Verify that the lm head is set correctly + if use_distinct_lm_head: + assert proposer.model.lm_head is not target_model.lm_head + else: + assert proposer.model.lm_head is target_model.lm_head # Verify that the embed tokens are set correctly # If pp_size is > 1, the 
embed tokens should be distinct if pp_size > 1 or use_distinct_embed_tokens: - assert proposer.model.model.embed_tokens != target_model.model.embed_tokens + assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens else: - # When pp_size is 1 and the draft and target models have - # embed_tokens of the same shape, they should be shared. - assert proposer.model.model.embed_tokens == target_model.model.embed_tokens + assert proposer.model.model.embed_tokens is target_model.model.embed_tokens @pytest.mark.parametrize("method", ["eagle", "eagle3"]) diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index 6d59b58e739eb..c5c0491abaf7c 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -67,6 +67,10 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_gro mock_model = mock.MagicMock() mock_model.model.embed_tokens.weight.shape = (131072, 4096) mock_get_model.return_value = mock_model + # MTP does not have its own embed_tokens or lm_head + # so it should share them with the target model + mock_model.has_own_embed_tokens = False + mock_model.has_own_lm_head = False target_attn_layers = {"target_attn_1": mock.MagicMock()} all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()} diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 692c39282c372..563bc1d957f41 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -77,7 +77,7 @@ def test_ngram_proposer(): # No match. token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -88,7 +88,7 @@ def test_ngram_proposer(): # No match for 4-gram. 
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -99,7 +99,7 @@ def test_ngram_proposer(): # No match for 4-gram but match for 3-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -111,7 +111,7 @@ def test_ngram_proposer(): # In this case, the proposer should return the 4-gram match. token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -122,7 +122,7 @@ def test_ngram_proposer(): # Match for 2-gram and 3-gram, but not 4-gram. token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -133,7 +133,7 @@ def test_ngram_proposer(): # Multiple 3-gram matched, but always pick the first one. 
token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -144,7 +144,7 @@ def test_ngram_proposer(): # check empty input token_ids_cpu = np.array([[]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -157,7 +157,7 @@ def test_ngram_proposer(): # second request has 3 tokens and no match. Padded with -1 for max len 5 token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[[0], [1]], + sampled_token_ids=[np.array([0]), np.array([1])], req_ids=["0", "1"], num_tokens_no_spec=np.array([5, 3]), token_ids_cpu=token_ids_cpu, @@ -181,7 +181,7 @@ def test_ngram_proposer(): input_2[:3] = [4, 5, 6] token_ids_cpu = np.array([input_1, input_2]) result = ngram_proposer.propose( - sampled_token_ids=[[0], [1]], + sampled_token_ids=[np.array([0]), np.array([1])], req_ids=["0", "1"], num_tokens_no_spec=np.array([len(input_1), 3]), token_ids_cpu=token_ids_cpu, diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index b02d9a657407b..b95c8df3469b3 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization(): req_index = 0 block_table.append_row(kvcache_manager_blocks, req_index) # Get expected kernel blocks from the implementation for verification. 
- expected_kernel_blocks = block_table._map_to_kernel_blocks( - np.array(kvcache_manager_blocks) + expected_kernel_blocks = block_table.map_to_kernel_blocks( + np.array(kvcache_manager_blocks), + block_table.blocks_per_kv_block, + block_table._kernel_block_arange, ) # Verify block table state assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks) diff --git a/tools/install_nixl_from_source_ubuntu.py b/tools/install_nixl_from_source_ubuntu.py index 4a20b6b7bb8fb..a786abba95ad9 100644 --- a/tools/install_nixl_from_source_ubuntu.py +++ b/tools/install_nixl_from_source_ubuntu.py @@ -175,6 +175,7 @@ def build_and_install_prerequisites(args): build_env["LD_LIBRARY_PATH"] = ( f"{ucx_lib_path}:{ucx_plugin_path}:{existing_ld_path}".strip(":") ) + build_env["LDFLAGS"] = "-Wl,-rpath,$ORIGIN" print(f"--> Using LD_LIBRARY_PATH: {build_env['LD_LIBRARY_PATH']}", flush=True) temp_wheel_dir = os.path.join(ROOT_DIR, "temp_wheelhouse") diff --git a/tools/vllm-tpu/build.sh b/tools/vllm-tpu/build.sh index fbc91e379df33..45ef8dfcb1db6 100755 --- a/tools/vllm-tpu/build.sh +++ b/tools/vllm-tpu/build.sh @@ -7,6 +7,13 @@ TOOLS_DIR=$(cd "$(dirname "$SCRIPT_PATH_PARAM")" && pwd) # Absolute path to the REPO_ROOT=$(cd "$TOOLS_DIR/../../" && pwd) # Absolute path to the repo root VLLM_DIR="$REPO_ROOT/" # Path to the vllm sources +CHANGE_FILE_LIST=( + "vllm/entrypoints/cli/main.py" + "vllm/entrypoints/cli/run_batch.py" + "vllm/utils/__init__.py" + "vllm/platforms/__init__.py" +) + # Ensure we are not running from within the vllm directory if SCRIPT_PATH_PARAM is relative like "." if [ "$TOOLS_DIR" = "$VLLM_DIR" ]; then echo "Error: This script should not be run from the vllm directory directly if using relative paths." @@ -30,6 +37,20 @@ if ! grep -q "name = \"vllm-tpu\"" "$PYPROJECT_FILE"; then echo "Patching pyproject.toml project name to vllm-tpu..." 
cp "$PYPROJECT_FILE" "${PYPROJECT_FILE}.bak" sed -i '0,/^name = "vllm"/s//name = "vllm-tpu"/' "$PYPROJECT_FILE" + + echo "Patching ${CHANGE_FILE_LIST[@]} vllm to vllm-tpu..." + # patching + # importlib.metadata.version('vllm') -> importlib.metadata.version('vllm-tpu') + # importlib.metadata.version("vllm") -> importlib.metadata.version("vllm-tpu") + # importlib.metadata.metadata('vllm') -> importlib.metadata.metadata('vllm-tpu') + # importlib.metadata.metadata("vllm") -> importlib.metadata.metadata("vllm-tpu") + # version('vllm') -> version('vllm-tpu') + # version("vllm") -> version("vllm-tpu") + sed -i \ + -e "s/importlib.metadata.version(\(['\"]\)vllm\1)/importlib.metadata.version(\1vllm-tpu\1)/" \ + -e "s/importlib.metadata.metadata(\(['\"]\)vllm\1)/importlib.metadata.metadata(\1vllm-tpu\1)/" \ + -e "s/version(\(['\"]\)vllm\1)/version(\1vllm-tpu\1)/" \ + "${CHANGE_FILE_LIST[@]}" PATCHED=true else PATCHED=false @@ -45,6 +66,13 @@ cleanup() { echo "Restoring original pyproject.toml..." cp "${PYPROJECT_FILE}.bak" "$PYPROJECT_FILE" rm -f "${PYPROJECT_FILE}.bak" + + echo "Restoring vllm code..." + sed -i \ + -e "s/importlib.metadata.version(\(['\"]\)vllm-tpu\1)/importlib.metadata.version(\1vllm\1)/" \ + -e "s/importlib.metadata.metadata(\(['\"]\)vllm-tpu\1)/importlib.metadata.metadata(\1vllm\1)/" \ + -e "s/version(\(['\"]\)vllm-tpu\1)/version(\1vllm\1)/" \ + "${CHANGE_FILE_LIST[@]}" fi } trap cleanup EXIT HUP INT QUIT PIPE TERM # Register cleanup function to run on script exit and various signals diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 5508e59bcd2f5..e53e4ae6e5296 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -30,7 +30,7 @@ def if_aiter_supported(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs): - # checks the platform, device arch and aiter library existance. + # checks the platform, device arch and aiter library existence. 
if current_platform.is_rocm() and IS_AITER_FOUND: from vllm.platforms.rocm import on_gfx9 @@ -43,6 +43,36 @@ def if_aiter_supported(func: Callable) -> Callable: return wrapper +def _rocm_aiter_group_fp8_quant_impl( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size" + from aiter import QuantType, dtypes, get_hip_quant + + aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) + return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8) + + +def _rocm_aiter_group_fp8_quant_fake( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter import dtypes + + M, N = x.shape + x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device) + out_bs = torch.empty( + ( + M, + (N + group_size - 1) // group_size, + ), + dtype=torch.float32, + device=x.device, + ) + return x_fp8, out_bs + + def _rocm_aiter_fused_moe_impl( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -512,6 +542,14 @@ class rocm_aiter_ops: ) # register all the custom ops here + direct_register_custom_op( + op_name="rocm_aiter_group_fp8_quant", + op_func=_rocm_aiter_group_fp8_quant_impl, + mutates_args=[], + fake_impl=_rocm_aiter_group_fp8_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + direct_register_custom_op( op_name="rocm_aiter_asm_moe_tkw1", op_func=_rocm_aiter_asm_moe_tkw1_impl, @@ -887,14 +925,12 @@ class rocm_aiter_ops: return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) @staticmethod - def per_1x128_fp8_quant( + def group_fp8_quant( input_2d: torch.Tensor, + group_size: int = 128, ) -> tuple[torch.Tensor, ...]: - """Only applies quantization method for fp8 data type only.""" - from aiter import QuantType, dtypes, get_hip_quant - - aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) - return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8) + assert group_size == 128, "Group size must be 128" + 
return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size) @staticmethod def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool: diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7d70c01cefbb6..096266c9764e8 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1174,13 +1174,50 @@ def gptq_marlin_repack( return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) -# gptq_marlin +if hasattr(torch.ops._C, "gptq_marlin_repack"): + + @register_fake("_C::gptq_marlin_repack") + def _gptq_marlin_repack_fake( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: torch.SymInt, + size_n: torch.SymInt, + num_bits: int, + ) -> torch.Tensor: + pack_factor = 32 // num_bits + marlin_tile_size = 16 + return torch.empty( + (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor), + dtype=b_q_weight.dtype, + device=b_q_weight.device, + ) + + +# awq_marlin def awq_marlin_repack( b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int ) -> torch.Tensor: return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) +if hasattr(torch.ops._C, "awq_marlin_repack"): + + @register_fake("_C::awq_marlin_repack") + def _awq_marlin_repack_fake( + b_q_weight: torch.Tensor, + size_k: torch.SymInt, + size_n: torch.SymInt, + num_bits: int, + ) -> torch.Tensor: + pack_factor = 32 // num_bits + marlin_tile_size = 16 + return torch.empty( + (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor), + dtype=b_q_weight.dtype, + device=b_q_weight.device, + ) + + def gptq_marlin_moe_repack( b_q_weight: torch.Tensor, perm: torch.Tensor, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 697beed918693..9275d70fd86a4 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -142,6 +142,17 @@ class AttentionBackend(ABC): def is_sparse(cls) -> bool: return False + @classmethod + def supports_attn_type(cls, attn_type: str) -> 
bool: + """Check if backend supports a given attention type. + + By default, only supports decoder attention. + Backends should override this to support other attention types. + """ + from vllm.attention import AttentionType + + return attn_type == AttentionType.DECODER + @classmethod def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: return True @@ -171,6 +182,7 @@ class AttentionBackend(ABC): has_sink: bool, use_sparse: bool, device_capability: "DeviceCapability", + attn_type: str, ) -> list[str]: invalid_reasons = [] if not cls.supports_head_size(head_size): @@ -195,6 +207,8 @@ class AttentionBackend(ABC): invalid_reasons.append("non-sparse not supported") if not cls.supports_compute_capability(device_capability): invalid_reasons.append("compute capability not supported") + if not cls.supports_attn_type(attn_type): + invalid_reasons.append(f"attention type {attn_type} not supported") combination_reason = cls.supports_combination( head_size, dtype, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 487bba76babf1..37f9a4b383ce9 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -291,6 +291,7 @@ class Attention(nn.Module, AttentionLayerBase): block_size, use_mla=False, has_sink=self.has_sink, + attn_type=attn_type, ) else: self.attn_backend = attn_backend diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index f144e8435b6cf..48fcc6fa736bb 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import ClassVar import torch @@ -12,11 +11,16 @@ from vllm.config.vllm import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.v1.attention.backends.utils import ( AttentionCGSupport, + 
AttentionMetadataBuilder, CommonAttentionMetadata, make_local_attention_virtual_batches, subclass_attention_backend, ) -from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + ChunkedLocalAttentionSpec, + KVCacheSpec, +) from ..layer import Attention @@ -30,9 +34,18 @@ def create_chunked_local_attention_backend( prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" underlying_builder = underlying_attn_backend.get_builder_cls() + assert issubclass(underlying_builder, AttentionMetadataBuilder) class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore - _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER + @classmethod + def get_cudagraph_support( + cls: type["AttentionMetadataBuilder"], + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + # Explicit override in case the underlying builder specialized this getter. + # @override omitted only because of mypy limitation due to type variable. 
+ return AttentionCGSupport.NEVER def build( self, diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py index 4929bbf5efc73..5e99c99010034 100644 --- a/vllm/attention/layers/encoder_only_attention.py +++ b/vllm/attention/layers/encoder_only_attention.py @@ -74,7 +74,11 @@ class EncoderOnlyAttention(Attention): block_size = 16 underlying_attn_backend = get_attn_backend( - head_size, dtype, kv_cache_dtype, block_size + head_size, + dtype, + kv_cache_dtype, + block_size, + attn_type=AttentionType.ENCODER_ONLY, ) attn_backend = create_encoder_only_attention_backend(underlying_attn_backend) diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py index bbcd560ad56e3..5d2ba154ae018 100644 --- a/vllm/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -97,7 +97,6 @@ def triton_reshape_and_cache_flash( k_scale: torch.Tensor, # float32 v_scale: torch.Tensor, # float32 ): - num_tokens = key.shape[0] num_heads = key.shape[1] head_size = key.shape[2] block_size = key_cache.shape[1] @@ -155,7 +154,10 @@ def triton_reshape_and_cache_flash( # TODO(ngl): maybe replace with static launch grid to avoid overhead if # using cudagraphs - grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"])) + grid = lambda meta: ( + slot_mapping.shape[0], + triton.cdiv(n, meta["TILE_SIZE"]), + ) reshape_and_cache_kernel_flash[grid]( key_ptr=key, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 262cdf0e575b0..1a092db9ce378 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -76,6 +76,7 @@ def get_attn_backend( use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, + attn_type: str | None = None, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" @@ -94,6 +95,7 @@ def get_attn_backend( 
use_mla=use_mla, has_sink=has_sink, use_sparse=use_sparse, + attn_type=attn_type, ) @@ -106,6 +108,7 @@ def _cached_get_attn_backend( use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, + attn_type: str | None = None, ) -> type[AttentionBackend]: # Check whether a particular choice of backend was # previously forced. @@ -159,6 +162,7 @@ def _cached_get_attn_backend( use_mla, has_sink, use_sparse, + attn_type, ) else: attention_cls = current_platform.get_attn_backend_cls( @@ -170,6 +174,7 @@ def _cached_get_attn_backend( use_mla, has_sink, use_sparse, + attn_type, ) if not attention_cls: raise ValueError( diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 0e9b0fbe2c028..dddb050ec180e 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -49,6 +49,7 @@ from vllm.benchmarks.lib.ready_checker import wait_for_endpoint from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils.gc_utils import freeze_gc_heap +from vllm.utils.network_utils import join_host_port MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -1333,8 +1334,9 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: api_url = f"{args.base_url}{args.endpoint}" base_url = f"{args.base_url}" else: - api_url = f"http://{args.host}:{args.port}{args.endpoint}" - base_url = f"http://{args.host}:{args.port}" + host_port = join_host_port(args.host, args.port) + api_url = f"http://{host_port}{args.endpoint}" + base_url = f"http://{host_port}" # Headers headers = None diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index be69075f94f09..60ef6eef21663 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -4,6 +4,7 @@ import ast import dataclasses import hashlib +import operator import os import pprint import time @@ -307,12 +308,24 @@ def split_graph( ) -> tuple[fx.GraphModule, list[SplitItem]]: # 
split graph by ops subgraph_id = 0 - node_to_subgraph_id = {} - split_op_graphs = [] + node_to_subgraph_id: dict[fx.Node, int] = {} + split_op_graphs: list[int] = [] for node in graph.graph.nodes: if node.op in ("output", "placeholder"): continue + # Check if this is a getitem operation on a node from an earlier subgraph. + # If so, assign it to the same subgraph as its input to avoid passing entire + # tuple as input to submodules, which is against standalone_compile and + # AoTAutograd input requirement. + if node.op == "call_function" and node.target == operator.getitem: + # Assign this getitem to the same subgraph as its input + input_node = node.args[0] + if input_node.op != "placeholder": + assert input_node in node_to_subgraph_id + node_to_subgraph_id[node] = node_to_subgraph_id[input_node] + continue + if should_split(node, splitting_ops): subgraph_id += 1 node_to_subgraph_id[node] = subgraph_id diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index b0cdb08884a3b..11cf0f85c1787 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -299,7 +299,7 @@ class InductorAdaptor(CompilerInterface): self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir if disable_cache: return - # redirect the cache directory to a sub-directory + # redirect the cache directory to a subdirectory # set flags so that Inductor and Triton store their cache # in the cache_dir, then users only need to copy the cache_dir # to another machine to reuse the cache. 
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 0946fa69171b4..11a18c0e6bb78 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -17,7 +17,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator import vllm.envs as envs from vllm.compilation.counter import compilation_counter -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper from vllm.config import ( CompilationMode, VllmConfig, @@ -159,7 +159,7 @@ def support_torch_compile( `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic dim to be decorated with `mark_unbacked`. This is useful if we would like to - enforce that dynamo do not specialize on 0/1 values in the case of dummy input + enforce that dynamo does not specialize on 0/1 values in the case of dummy input such as for vision model compilation """ @@ -246,14 +246,14 @@ def _support_torch_compile( """ A decorator to add support for compiling the forward method of a class. 
""" - if TorchCompileWrapperWithCustomDispatcher in cls.__bases__: + if TorchCompileWithNoGuardsWrapper in cls.__bases__: # support decorating multiple times return cls # take care of method resolution order # make sure super().__init__ is called on the base class - # other than TorchCompileWrapperWithCustomDispatcher - cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,) + # other than TorchCompileWithNoGuardsWrapper + cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,) old_init = cls.__init__ @@ -290,12 +290,43 @@ def _support_torch_compile( return compilation_counter.num_models_seen += 1 - TorchCompileWrapperWithCustomDispatcher.__init__( - self, compilation_mode=vllm_config.compilation_config.mode - ) + self.compiled = False + TorchCompileWithNoGuardsWrapper.__init__(self) cls.__init__ = __init__ + def _mark_dynamic_inputs(mod, *args, **kwargs): + sig = inspect.signature(mod.__class__.forward) + bound_args = sig.bind(mod, *args, **kwargs) + bound_args.apply_defaults() + for k, dims in dynamic_arg_dims.items(): + arg = bound_args.arguments.get(k) + if arg is not None: + dims = [dims] if isinstance(dims, int) else dims + if isinstance(arg, torch.Tensor): + # In case dims is specified with negative indexing + dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] + torch._dynamo.mark_dynamic(arg, dims) + elif isinstance(arg, IntermediateTensors): + for tensor in arg.tensors.values(): + # In case dims is specified with negative indexing + dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims] + torch._dynamo.mark_dynamic(tensor, dims) + else: + raise ValueError( + "Unsupported dynamic dimensions" + f" {dims} for argument {k} with type {type(arg)}." 
+ ) + if mark_unbacked_dims: + for k, dims in mark_unbacked_dims.items(): + arg = bound_args.arguments.get(k) + if arg is not None: + dims = [dims] if isinstance(dims, int) else dims + if isinstance(arg, torch.Tensor): + # In case dims is specified with negative indexing + dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] + torch._dynamo.decorators.mark_unbacked(arg, dims) + def __call__(self, *args, **kwargs): # torch.compiler.is_compiling() means we are inside the compilation # e.g. TPU has the compilation logic in model runner, so we don't @@ -303,6 +334,7 @@ def _support_torch_compile( if self.do_not_compile or torch.compiler.is_compiling(): return self.forward(*args, **kwargs) + # if aot_compiled_fn is set, just call it. if getattr(self, "aot_compiled_fn", None) is not None: return self.aot_compiled_fn(self, *args, **kwargs) @@ -362,120 +394,84 @@ def _support_torch_compile( ) return self.aot_compiled_fn(self, *args, **kwargs) + if self.compiled: + assert not envs.VLLM_USE_AOT_COMPILE + return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) + + # This is the path for the first compilation. 
+ # the first compilation needs to have dynamic shapes marked - if len(self.compiled_codes) < 1: - sig = inspect.signature(self.__class__.forward) - bound_args = sig.bind(self, *args, **kwargs) - bound_args.apply_defaults() - for k, dims in dynamic_arg_dims.items(): - arg = bound_args.arguments.get(k) - if arg is not None: - dims = [dims] if isinstance(dims, int) else dims - if isinstance(arg, torch.Tensor): - # In case dims is specified with negative indexing - dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] - torch._dynamo.mark_dynamic(arg, dims) - elif isinstance(arg, IntermediateTensors): - for tensor in arg.tensors.values(): - # In case dims is specified with negative indexing - dims = [ - tensor.ndim + dim if dim < 0 else dim for dim in dims - ] - torch._dynamo.mark_dynamic(tensor, dims) - else: - raise ValueError( - "Unsupported dynamic dimensions" - f" {dims} for argument {k} with type {type(arg)}." - ) - if mark_unbacked_dims: - for k, dims in mark_unbacked_dims.items(): - arg = bound_args.arguments.get(k) - if arg is not None: - dims = [dims] if isinstance(dims, int) else dims - if isinstance(arg, torch.Tensor): - # In case dims is specified with negative indexing - dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] - torch._dynamo.decorators.mark_unbacked(arg, dims) - # here, it is the starting point of the `torch.compile` process - start_monitoring_torch_compile(self.vllm_config) - logger.debug("Start compiling function %s", self.original_code_object) + _mark_dynamic_inputs(self, *args, **kwargs) - # if we don't use custom dispatcher, we can directly call the - # compiled function and let torch.compile handle the dispatching, - # with the overhead of guard evaluation and recompilation. - if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: - # it seems Dynamo reuse the compilation across instances, - # while we need to make sure the compiled code is not reused. - # we need to control all the compilation of the model. 
- torch._dynamo.eval_frame.remove_from_cache(self.original_code_object) + # here, it is the starting point of the `torch.compile` process + start_monitoring_torch_compile(self.vllm_config) + original_code_object = self.original_code_object() + logger.debug("Start compiling function %s", original_code_object) - # collect all relevant files traced by Dynamo, - # so that the compilation cache can trigger re-compilation - # properly when any of these files change. + # we do not want tp delete the original code object entries since + # we depend on them now to look up cached compiled functions. + # torch._dynamo.eval_frame.remove_from_cache(original_code_object) - # 1. the file containing the top-level forward function - self.vllm_config.compilation_config.traced_files.add( - self.original_code_object.co_filename - ) + # collect all relevant files traced by Dynamo, + # so that the compilation cache can trigger re-compilation + # properly when any of these files change. - # 2. every time Dynamo sees a function call, it will inline - # the function by calling InliningInstructionTranslator.inline_call_ - # we hijack this function to know all the functions called - # during Dynamo tracing, and their corresponding files - inline_call = InliningInstructionTranslator.inline_call_ + # 1. the file containing the top-level forward function + self.vllm_config.compilation_config.traced_files.add( + original_code_object.co_filename + ) - def patched_inline_call(self_): - code = self_.f_code - self.vllm_config.compilation_config.traced_files.add(code.co_filename) - return inline_call(self_) + # 2. every time Dynamo sees a function call, it will inline + # the function by calling InliningInstructionTranslator.inline_call_ + # we hijack this function to know all the functions called + # during Dynamo tracing, and their corresponding files + inline_call = InliningInstructionTranslator.inline_call_ - # Disable the C++ compilation of symbolic shape guards. 
C++-fication - # of symbolic shape guards can improve guard overhead. But, since - # vllm skip guards anyways, setting this flag to False can improve - # compile time. - dynamo_config_patches = {} - try: - _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards - dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False - except AttributeError: - # Note: this config is not available in torch 2.6, we can skip - # if the config doesn't exist - logger.debug("enable_cpp_symbolic_shape_guards config not available") + def patched_inline_call(self_): + code = self_.f_code + self.vllm_config.compilation_config.traced_files.add(code.co_filename) + return inline_call(self_) - with ( - patch.object( - InliningInstructionTranslator, "inline_call_", patched_inline_call - ), - torch._dynamo.config.patch(**dynamo_config_patches), - maybe_use_cudagraph_partition_wrapper(self.vllm_config), - _torch27_patch_tensor_subclasses(), - ): - if envs.VLLM_USE_AOT_COMPILE: - self.aot_compiled_fn = self.aot_compile(*args, **kwargs) - output = self.aot_compiled_fn(self, *args, **kwargs) - assert aot_compilation_path is not None - assert cache_dir is not None - try: - os.makedirs(cache_dir, exist_ok=True) - self.aot_compiled_fn.save_compiled_function( - aot_compilation_path - ) - except Exception as e: - logger.warning( - "Cannot save aot compilation to path %s, error: %s", - aot_compilation_path, - str(e), - ) - else: - output = self.compiled_callable(*args, **kwargs) - return output + # Disable the C++ compilation of symbolic shape guards. C++-fication + # of symbolic shape guards can improve guard overhead. But, since + # vllm skip guards anyways, setting this flag to False can improve + # compile time. 
+ dynamo_config_patches = {} + try: + _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards + dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False + except AttributeError: + # Note: this config is not available in torch 2.6, we can skip + # if the config doesn't exist + logger.debug("enable_cpp_symbolic_shape_guards config not available") - # usually, capturing the model once is enough, and then we can - # dispatch to the compiled code directly, without going through - # the Dynamo guard mechanism. - with self.dispatch_to_code(0): - model_output = self.forward(*args, **kwargs) - return model_output + with ( + patch.object( + InliningInstructionTranslator, "inline_call_", patched_inline_call + ), + torch._dynamo.config.patch(**dynamo_config_patches), + maybe_use_cudagraph_partition_wrapper(self.vllm_config), + _torch27_patch_tensor_subclasses(), + ): + if envs.VLLM_USE_AOT_COMPILE: + self.aot_compiled_fn = self.aot_compile(*args, **kwargs) + output = self.aot_compiled_fn(self, *args, **kwargs) + assert aot_compilation_path is not None + assert cache_dir is not None + try: + os.makedirs(cache_dir, exist_ok=True) + self.aot_compiled_fn.save_compiled_function(aot_compilation_path) + except Exception as e: + logger.warning( + "Cannot save aot compilation to path %s, error: %s", + aot_compilation_path, + str(e), + ) + else: + output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) + + self.compiled = True + return output cls.__call__ = __call__ return cls @@ -487,7 +483,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig): Context manager to set/unset customized cudagraph partition wrappers. If we're using Inductor-based graph partitioning, we currently have the - whole `fx.Graph` before Inductor lowering and and the piecewise + whole `fx.Graph` before Inductor lowering and the piecewise splitting happens after all graph passes and fusions. 
Here, we add a custom hook for Inductor to wrap each partition with our static graph wrapper class to maintain more control over static graph diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 0c2210d72ce07..0e8bb2fc97351 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -18,6 +18,7 @@ if current_platform.is_cuda_alike(): from .fusion import RMSNormQuantFusionPass from .fusion_attn import AttnFusionPass from .qk_norm_rope_fusion import QKNormRoPEFusionPass + from .sequence_parallelism import SequenceParallelismPass if current_platform.is_cuda(): from .collective_fusion import AllReduceFusionPass, AsyncTPPass @@ -25,7 +26,6 @@ if current_platform.is_cuda(): from .fix_functionalization import FixFunctionalizationPass from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .noop_elimination import NoOpEliminationPass -from .sequence_parallelism import SequenceParallelismPass logger = init_logger(__name__) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 31624a8fdcc0f..bb4dcf12d865d 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools + import torch import torch._inductor.pattern_matcher as pm import torch.fx as fx @@ -10,98 +12,28 @@ from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm +from 
.noop_elimination import NoOpEliminationPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) -class _RMSNormAndQuantOpHelper: - """Base helper for RMSNorm and RMSNorm + Quantization functionalization.""" +def get_first_out_wrapper(fn): + @functools.wraps(fn) + def wrapper(*args): + return fn(*args)[0] - def __init__( - self, - epsilon: float, - dtype: torch.dtype, - device: str, - quant_op: torch._ops.OpOverload | None = None, - **kwargs, - ): - self.epsilon = epsilon - self.dtype = dtype - self.device = device - self.quant_op = quant_op - - def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor): - return torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=result_buffer, - input=input_tensor, - weight=weight_tensor, - epsilon=self.epsilon, - ) - - def _functional_fused_add_rmsnorm( - self, input_tensor, residual_tensor, weight_tensor - ): - return torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=input_tensor, - residual=residual_tensor, - weight=weight_tensor, - epsilon=self.epsilon, - ) - - def _functional_rmsnorm_then_quant( - self, - rmsnorm_result_buffer, - quant_result_buffer, - input_tensor, - weight_tensor, - scale_tensor, - ): - if self.quant_op is None: - raise RuntimeError( - "_RMSNormAndQuantOpHelper was not initialized with a quant_op." - ) - rmsnorm_out_tuple = self._functional_rmsnorm( - rmsnorm_result_buffer, input_tensor, weight_tensor - ) - quant_out_tuple = torch.ops.higher_order.auto_functionalized( - self.quant_op, - result=quant_result_buffer, - input=rmsnorm_out_tuple[1], - scale=scale_tensor, - ) - return quant_out_tuple - - def _functional_fused_add_rmsnorm_then_quant( - self, - quant_result_buffer, - input_tensor, - residual_tensor, - weight_tensor, - scale_tensor, - ): - if self.quant_op is None: - raise RuntimeError( - "_RMSNormAndQuantOpHelper was not initialized with a quant_op." 
- ) - fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm( - input_tensor, residual_tensor, weight_tensor - ) - quant_out_tuple = torch.ops.higher_order.auto_functionalized( - self.quant_op, - result=quant_result_buffer, - input=fused_add_rmsnorm_out_tuple[1], - scale=scale_tensor, - ) - return quant_out_tuple, fused_add_rmsnorm_out_tuple[2] + return wrapper -class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): +class _SequenceParallelPatternHelper: """Helper for sequence parallelism patterns.""" def __init__( @@ -109,10 +41,10 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): epsilon: float, dtype: torch.dtype, device: str, - quant_op: torch._ops.OpOverload | None = None, - **kwargs, ): - super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs) + self.epsilon = epsilon + self.dtype = dtype + self.device = device self.tp_group = get_tp_group() self.tp_size = get_tensor_model_parallel_world_size() @@ -131,36 +63,34 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + def get_inputs(self): input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) - return [input, permute, arg3_1] + return [input, arg3_1] def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, - permute: torch.Tensor, arg3_1: torch.Tensor, ): all_reduce = self._all_reduce(input) - rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1) + rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1) - return rmsnorm[1], all_reduce + return rmsnorm, all_reduce def replacement( input: torch.Tensor, - permute: torch.Tensor, arg3_1: 
torch.Tensor, ): reduce_scatter = self._reduce_scatter(input) - rmsnorm_result = torch.empty_like(reduce_scatter) - rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1) - - all_gather = self._all_gather(rmsnorm[1]) - + rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1) + all_gather = self._all_gather(rmsnorm) return all_gather, reduce_scatter pm.register_replacement( @@ -169,6 +99,10 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -188,67 +122,34 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - all_reduce, residual, rms_norm_weights - ) - return rmsnorm[1], rmsnorm[2] + rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual) + return rmsnorm[0], rmsnorm[1] def replacement( residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + # pattern matcher replaces from top-to-bottom, + # so residual is still the full size here. + # once the seqpar pattern with the previous rmsnorm is replaced reduce_scatter = self._reduce_scatter(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - reduce_scatter, residual, rms_norm_weights - ) - all_gather = self._all_gather(rmsnorm[1]) - return all_gather, rmsnorm[2] + residual = residual[0 : reduce_scatter.size(0), ...] 
+ rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual) + all_gather = self._all_gather(rmsnorm[0]) + # shape of residual changes but that's fine, + # next node is already slicing it, now becomes a noop + return all_gather, rmsnorm[1] pm.register_replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) - - -class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): - def get_inputs(self): - mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) - - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) - - return [ - residual, - mm_1, - rms_norm_weights, - ] - - def register(self, pm_pass: PatternMatcherPass): - def pattern( - residual: torch.Tensor, - mm_1: torch.Tensor, - rms_norm_weights: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = self._all_reduce(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - all_reduce, residual, rms_norm_weights - ) - return rmsnorm[1] - - def replacement( - residual: torch.Tensor, - mm_1: torch.Tensor, - rms_norm_weights: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - reduce_scatter = self._reduce_scatter(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - reduce_scatter, residual, rms_norm_weights - ) - normalized = self._all_gather(rmsnorm[1]) - return normalized - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + get_first_out_wrapper(pattern), + get_first_out_wrapper(replacement), + self.get_inputs(), + pm.fwd_only, + pm_pass, ) @@ -257,52 +158,41 @@ FP8_DTYPE = current_platform.fp8_dtype() class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): def __init__( - self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload + self, + epsilon: float, + dtype: torch.dtype, + device: str, ): - super().__init__(epsilon, dtype, device, quant_op=op) + 
super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def get_inputs(self): input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [input, rmsnorm_result, quant_result, weight, scale] + return [input, weight, scale] def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): all_reduce = self._all_reduce(input) - static_fp8 = self._functional_rmsnorm_then_quant( - rmsnorm_result, quant_result, all_reduce, weight, scale - ) - return static_fp8[1], all_reduce + rms = self.rmsnorm_matcher(all_reduce, weight) + quant, _ = self.quant_matcher(rms, scale) + return quant, all_reduce def replacement( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): reduce_scatter = self._reduce_scatter(input) - - rmsnorm_result = torch.empty_like( - reduce_scatter, dtype=rmsnorm_result.dtype - ) - quant_result = torch.empty_like( - rmsnorm_result, # Output of RMSNorm - dtype=quant_result.dtype, - ) - static_fp8 = self._functional_rmsnorm_then_quant( - rmsnorm_result, quant_result, reduce_scatter, weight, scale - ) - all_gather = self._all_gather(static_fp8[1]) + rms = self.rmsnorm_matcher(reduce_scatter, weight) + quant, _ = self.quant_matcher(rms, scale) + all_gather = self._all_gather(quant) return all_gather, reduce_scatter @@ -312,118 +202,64 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): class 
MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - def __init__( - self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload - ): - super().__init__(epsilon, dtype, device, quant_op=op) + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) - result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) - return [ - result, - residual, - mm_1, - rms_norm_weights, - scale, - ] + return [residual, mm_1, rms_norm_weights, scale] def register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) - static_fp8, rmsnorm_residual_out = ( - self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 - result, all_reduce, residual, rms_norm_weights, scale - ) + rms, residual_out = self.rmsnorm_matcher( + all_reduce, rms_norm_weights, residual ) - return static_fp8[1], rmsnorm_residual_out + quant, _ = self.quant_matcher(rms, scale) + return quant, residual_out def replacement( - result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + # pattern matcher replaces from top-to-bottom, + # so residual is still the full size here. 
+ # add a temporary slice which will become a noop + # once the seqpar pattern with the previous rmsnorm is replaced reduce_scatter = self._reduce_scatter(mm_1) - quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype) - static_fp8, rmsnorm_residual_out = ( - self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 - quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale - ) + residual = residual[0 : reduce_scatter.size(0), ...] + rms, residual_out = self.rmsnorm_matcher( + reduce_scatter, rms_norm_weights, residual ) - all_gather = self._all_gather(static_fp8[1]) - return all_gather, rmsnorm_residual_out + quant, _ = self.quant_matcher(rms, scale) + all_gather = self._all_gather(quant) + # shape of residual changes but that's fine, + # next node is already slicing it, now becomes a noop + return all_gather, residual_out pm.register_replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) - -class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - def __init__( - self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload - ): - super().__init__(epsilon, dtype, device, quant_op=op) - - def get_inputs(self): - mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) - - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) - result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) - scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) - - return [ - result, - residual, - mm_1, - rms_norm_weights, - scale, - ] - - def register(self, pm_pass: PatternMatcherPass): - def pattern( - result: torch.Tensor, - residual: torch.Tensor, - mm_1: torch.Tensor, - rms_norm_weights: torch.Tensor, - scale: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = self._all_reduce(mm_1) - static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( - 
result, all_reduce, residual, rms_norm_weights, scale - ) - return static_fp8[1] - - def replacement( - result: torch.Tensor, - residual: torch.Tensor, - mm_1: torch.Tensor, - rms_norm_weights: torch.Tensor, - scale: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - reduce_scatter = self._reduce_scatter(mm_1) - quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype) - static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( - quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale - ) - normalized = self._all_gather(static_fp8[1]) - return normalized - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + get_first_out_wrapper(pattern), + get_first_out_wrapper(replacement), + self.get_inputs(), + pm.fwd_only, + pm_pass, ) @@ -445,27 +281,45 @@ class SequenceParallelismPass(VllmPatternMatcherPass): GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can significantly reduce communication overhead and improve overall model performance. + + + This pass splits up the residual tensor across TP ranks and hence divides its size. + Because the pattern matcher starts at the end of the graph, the replacement + contains a slice that temporarily conforms the input residual to the correct size. + After all patterns have been matched, we use a NoOpEliminationPass to clean up + what have now become no-op slices. + + Note that an older version of the pass did not need this as it operated only on + custom rms_norm and fused_rms_norm_add custom ops which did not complain about + mismatched shapes during replacement. So this approach has the same assumption that + correctness is only maintained if all rms_norm operations are split across ranks. + + Correctness-wise, this is approach strictly better than before - before, + the graph was incorrect semantically and shape-wise during the pass. + With this approach there's only semantic incorrectness during the pass. 
+ Both approaches restore a correct graph once all patterns are matched. """ @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) + # Used to cleanup redundant views created temporarily + # to circumvent residual shape change issues + self.noop_cleanup = NoOpEliminationPass(config) + self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}" + self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="sequence_parallelism_pass" ) for epsilon in [1e-5, 1e-6]: # RMSNorm + Static FP8 quantization patterns - fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default FirstAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, fp8_quant_op + epsilon, self.model_dtype, self.device ).register(self.patterns) MiddleAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, fp8_quant_op - ).register(self.patterns) - LastAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, fp8_quant_op + epsilon, self.model_dtype, self.device ).register(self.patterns) # Normal RMSNorm patterns @@ -477,9 +331,6 @@ class SequenceParallelismPass(VllmPatternMatcherPass): epsilon, self.model_dtype, self.device ).register(self.patterns) - LastAllReduceRMSNormPattern( - epsilon, self.model_dtype, self.device - ).register(self.patterns) self.dump_patterns(config, self.patterns) def is_applicable(self, shape: int | None) -> bool: @@ -508,3 +359,5 @@ class SequenceParallelismPass(VllmPatternMatcherPass): def __call__(self, graph: fx.Graph): self.matched_count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", self.matched_count) + # Clean up reshape nodes + self.noop_cleanup(graph) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 4d26619bd128c..493e57f97f0f4 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -4,11 +4,11 @@ import os import sys from abc import abstractmethod -from collections.abc import Callable from 
contextlib import contextmanager from types import CodeType import torch +import torch._C._dynamo.guards import vllm.envs as envs from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config @@ -17,88 +17,153 @@ from vllm.logger import init_logger logger = init_logger(__name__) -class TorchCompileWrapperWithCustomDispatcher: +def _noop_add_global_state_guard(self, *args, **kwargs): + """No-op to skip the GLOBAL_STATE guard entirely""" + pass + + +def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs): + """No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely""" + pass + + +@contextmanager +def _compilation_context(): + """Context manager for compilation settings and patches. + + This manager: + 1. Sets higher dynamo cache limits for compilation. (Needed for + qwen2_5_vl see test_qwen2_5_vl_evs_functionality). + Generally a recompilation can happen whenever we use a new + backend instance in torch.compile. + 2. Patches out add_global_state_guard to skip GLOBAL_STATE guards + 3. Patches out add_torch_function_mode_stack_guard to skip + TORCH_FUNCTION_MODE_STACK guards. + 4. Restores everything when compilation completes """ - A wrapper class for torch.compile, with a custom dispatch logic. - Subclasses should: - 1. Implement the forward method - 2. Implement the dispatch logic in the __call__ method - It can use `self.compiled_codes` to access the compiled bytecode, - and `with self.dispatch_to_code(index):` to dispatch to - the compiled code. - 3. Implement the `__init__` method to determine how to call - `torch.compile` over the forward method. 
+ # Save original values + original_global_state_guard = ( + torch._C._dynamo.guards.GuardManager.add_global_state_guard + ) + original_torch_function_mode_stack_guard = ( + torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard + ) + original_cache_size = torch._dynamo.config.cache_size_limit + original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit + + try: + # Set higher cache limits for compilation + torch._dynamo.config.cache_size_limit = 2048 + torch._dynamo.config.accumulated_cache_size_limit = 8192 + + # Patch guard manager + torch._C._dynamo.guards.GuardManager.add_global_state_guard = ( + _noop_add_global_state_guard + ) + torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = ( + _noop_add_torch_function_mode_stack_guard + ) + yield + finally: + # Restore original values + torch._C._dynamo.guards.GuardManager.add_global_state_guard = ( + original_global_state_guard + ) + torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = ( + original_torch_function_mode_stack_guard + ) + torch._dynamo.config.cache_size_limit = original_cache_size + torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache + + +class TorchCompileWithNoGuardsWrapper: + """ + A wrapper class for torch.compile, it ensures that all guards are dropped + when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE. + When guards are dropped, the first time __call__ is invoked, a single + compilation is triggered. Dynamo should never be traced again after that + since we drop all guards. 
""" - def __init__( - self, - compiled_callable: Callable | None = None, - compilation_mode: CompilationMode = CompilationMode.NONE, - ): + def __init__(self): + self.compiled = False + vllm_config = get_current_vllm_config() self.vllm_config = vllm_config - if compiled_callable is None: - # default compilation settings - # compiling the forward method + mode = vllm_config.compilation_config.mode + if mode is None: + raise RuntimeError("Compilation mode cannot be NO_COMPILATION") - backend = vllm_config.compilation_config.init_backend(vllm_config) - options = None - if isinstance(backend, str) and backend == "inductor": - options = ( - get_current_vllm_config().compilation_config.inductor_compile_config - ) - if envs.VLLM_USE_AOT_COMPILE: - options = options or {} - # This effectively drop all the guards. - # We need this because bytecode hook is not used any more to - # drop guards in the AOT compile mode. - options["guard_filter_fn"] = lambda guards: [False for _ in guards] - if hasattr(torch._dynamo.config, "enable_aot_compile"): - torch._dynamo.config.enable_aot_compile = True - else: - msg = "torch._dynamo.config.enable_aot_compile is not " - msg += "available. AOT compile is disabled and please " - msg += "upgrade PyTorch version to use AOT compile." - logger.warning(msg) + backend = vllm_config.compilation_config.init_backend(vllm_config) + options = {} - compiled_callable = torch.compile( - self.forward, fullgraph=True, backend=backend, options=options - ) + if isinstance(backend, str) and backend == "inductor": + options = vllm_config.compilation_config.inductor_compile_config - self.compiled_callable = compiled_callable - self.original_code_object = self.__class__.forward.__code__ - self.compiled_codes: list[CodeType] = [] - torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) + if mode != CompilationMode.STOCK_TORCH_COMPILE: + # Drop all the guards. 
+ options["guard_filter_fn"] = lambda x: [False for _ in x] - # read the env var to determine whether to use the custom dispatcher - # subclasses can use this to switch between the custom dispatcher - # and the default Dynamo guard mechanism. - self.use_custom_dispatcher: bool = ( - compilation_mode >= CompilationMode.DYNAMO_TRACE_ONCE + if envs.VLLM_USE_AOT_COMPILE: + if hasattr(torch._dynamo.config, "enable_aot_compile"): + torch._dynamo.config.enable_aot_compile = True + else: + msg = "torch._dynamo.config.enable_aot_compile is not " + msg += "available. AOT compile is disabled and please " + msg += "upgrade PyTorch version to use AOT compile." + logger.warning(msg) + + self._compiled_callable = torch.compile( + self.forward, + fullgraph=True, + dynamic=False, + backend=backend, + options=options, ) + if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE: + torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) + self._compiled_bytecode = None + def aot_compile(self, *args, **kwargs): - if not hasattr(self.compiled_callable, "aot_compile"): + if not hasattr(self._compiled_callable, "aot_compile"): raise RuntimeError( "aot_compile is not supported by the current configuration. " + "Please make sure torch.compile is enabled with the latest " + f"version of PyTorch (current using torch: {torch.__version__})" ) - return self.compiled_callable.aot_compile((args, kwargs)) + return self._compiled_callable.aot_compile((args, kwargs)) def __call__(self, *args, **kwargs): - """Implement the dispatch logic here, beyond the torch.compile mode. - NOTE: this function can have additional arguments beyond the forward - method, for directly dispatching to the compiled code. 
- """ - return self.compiled_callable(*args, **kwargs) + if envs.VLLM_USE_BYTECODE_HOOK: + if ( + self.vllm_config.compilation_config.mode + == CompilationMode.STOCK_TORCH_COMPILE + ): + return self._compiled_callable(*args, **kwargs) + + if not self._compiled_bytecode: + # Make sure a compilation is triggered by clearing dynamo + # cache. + torch._dynamo.eval_frame.remove_from_cache(self.original_code_object()) + return self._compiled_callable(*args, **kwargs) + else: + with self._dispatch_to_compiled_code(): + return self.forward(*args, **kwargs) + else: + with _compilation_context(): + return self._compiled_callable(*args, **kwargs) @abstractmethod def forward(self, *args, **kwargs): ... + def original_code_object(self) -> CodeType: + """Return the original code object of the forward method.""" + return self.__class__.forward.__code__ + def bytecode_hook(self, old_code: CodeType, new_code: CodeType): """Hook to save the compiled bytecode for direct execution.""" - if old_code is not self.original_code_object: + if old_code is not self.original_code_object(): return # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25 frame = sys._getframe() @@ -114,7 +179,7 @@ class TorchCompileWrapperWithCustomDispatcher: if frame.f_locals["self"] is not self: return - self.compiled_codes.append(new_code) + self._compiled_bytecode = new_code path = self.vllm_config.compile_debug_dump_path() if path: @@ -153,16 +218,21 @@ class TorchCompileWrapperWithCustomDispatcher: raise RuntimeError(msg) @contextmanager - def dispatch_to_code(self, index: int): - """Context manager to dispatch to the compiled code. + def _dispatch_to_compiled_code(self): + # noqa: E501 + """ + Context manager to dispatch to internally compiled code for torch<2.8. Why does this work? Because Dynamo guarantees that the compiled bytecode has exactly the same arguments, cell variables, and free variables as the original code. 
Therefore we can directly switch the code object in the function and call it. - See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 - for more details. - """ - self.__class__.forward.__code__ = self.compiled_codes[index] - yield - self.__class__.forward.__code__ = self.original_code_object + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. + """ # noqa: E501 line too long + original = self.original_code_object() + assert self._compiled_bytecode is not None + self.__class__.forward.__code__ = self._compiled_bytecode + try: + yield + finally: + self.__class__.forward.__code__ = original diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index b0d1bc2bab306..088d0b1af757a 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -18,6 +18,7 @@ from vllm.config.utils import config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.math_utils import round_up from vllm.utils.torch_utils import is_torch_equal_or_newer if TYPE_CHECKING: @@ -320,9 +321,10 @@ class CompilationConfig: If None, defaults to attention ops for piecewise cudagraphs. If empty list [], no ops are excluded (suitable for full cudagraphs).""" - compile_mm_encoder: bool = True + compile_mm_encoder: bool = False """Whether or not to compile the multimodal encoder. - Currently, this only works for `Qwen2_5_vl`.""" + Currently, this only works for `Qwen2_5_vl` on selected platforms. 
+ Disabled by default until more models are supported/tested to work.""" # Inductor capture use_inductor: bool | None = None @@ -772,19 +774,8 @@ class CompilationConfig: if self.cudagraph_capture_sizes: assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size - # pre-compute the mapping from batch size to padded graph size - self.bs_to_padded_graph_size = [ - 0 for i in range(self.max_cudagraph_capture_size + 1) - ] - for end, start in zip( - self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1], - [0] + self.cudagraph_capture_sizes, - ): - for bs in range(start, end): - if bs == start: - self.bs_to_padded_graph_size[bs] = start - else: - self.bs_to_padded_graph_size[bs] = end + # May get recomputed in the model runner if adjustment is needed for spec-decode + self.compute_bs_to_padded_graph_size() def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called only when mode is @@ -921,3 +912,64 @@ class CompilationConfig: enable_str, op, ) + + def adjust_cudagraph_sizes_for_spec_decode( + self, uniform_decode_query_len: int, tensor_parallel_size: int + ): + multiple_of = uniform_decode_query_len + if tensor_parallel_size > 1: + multiple_of = max(uniform_decode_query_len, tensor_parallel_size) + if ( + multiple_of % uniform_decode_query_len != 0 + or multiple_of % tensor_parallel_size != 0 + ): + raise ValueError( + f"Can't determine cudagraph shapes that are both a " + f"multiple of {uniform_decode_query_len} " + f"(num_speculative_tokens + 1) required by spec-decode " + f"and {tensor_parallel_size} (tensor_parallel_size) " + f"required by sequence parallelism please adjust " + f"num_speculative_tokens or disable sequence parallelism" + ) + + if not self.cudagraph_capture_sizes or multiple_of <= 1: + return + + assert self.max_cudagraph_capture_size is not None + rounded_sizes = sorted( + set( + round_up(size, multiple_of) + for size in self.cudagraph_capture_sizes + if round_up(size, multiple_of) <= 
self.max_cudagraph_capture_size + ) + ) + + if len(rounded_sizes) == 0: + logger.warning( + "No valid cudagraph sizes after rounding to multiple of " + " num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens" + " or max_cudagraph_capture_size (or cudagraph_capture_sizes)", + multiple_of, + ) + return + + self.max_cudagraph_capture_size = rounded_sizes[-1] + self.cudagraph_capture_sizes = rounded_sizes + + # Recompute after adjusting the cudagraph sizes + self.compute_bs_to_padded_graph_size() + + def compute_bs_to_padded_graph_size(self): + # pre-compute the mapping from batch size to padded graph size + self.bs_to_padded_graph_size = [ + 0 for i in range(self.max_cudagraph_capture_size + 1) + ] + for end, start in zip( + self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1], + [0] + self.cudagraph_capture_sizes, + ): + for bs in range(start, end): + if bs == start: + self.bs_to_padded_graph_size[bs] = start + else: + self.bs_to_padded_graph_size[bs] = end diff --git a/vllm/config/model.py b/vllm/config/model.py index c47b619118ff2..b3a28af6de389 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -732,7 +732,7 @@ class ModelConfig: return self def _get_transformers_backend_cls(self) -> str: - """Determine which Transformers backend class will be used if + """Determine which Transformers modeling backend class will be used if `model_impl` is set to `transformers` or `auto`.""" cls = "Transformers" # If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal @@ -746,8 +746,8 @@ class ModelConfig: # User specified value take precedence if self.runner != "auto": runner = self.runner - # Only consider Transformers backend pooling classes if we're wrapping an - # architecture that defaults to pooling. Otherwise, we return the LM class + # Only consider Transformers modeling backend pooling classes if we're wrapping + # an architecture that defaults to pooling. Otherwise, we return the LM class # and use adapters. 
if runner == "pooling" and task in {"embed", "classify"}: if task == "embed": @@ -759,7 +759,7 @@ class ModelConfig: return cls def using_transformers_backend(self) -> bool: - """Check if the model is using the Transformers backend class.""" + """Check if the model is using the Transformers modeling backend class.""" used_cls = self._model_info.architecture transformers_backend_cls = self._get_transformers_backend_cls() return used_cls == transformers_backend_cls @@ -1183,6 +1183,14 @@ class ModelConfig: f"but got {decode_context_parallel_size}" ) + num_q_per_kv = total_num_attention_heads // total_num_kv_heads + assert num_q_per_kv % decode_context_parallel_size == 0, ( + f"Total number of q per kv attn heads ({num_q_per_kv})" + " must be divisible by dcp world size when enable " + "decode context parallel for GQA " + f"({parallel_config.decode_context_parallel_size})." + ) + def get_sliding_window(self) -> int | None: """Get the sliding window size from the HF text config if present.""" return getattr(self.hf_text_config, "sliding_window", None) @@ -1342,7 +1350,8 @@ class ModelConfig: # Ernie VL's remote code uses list[int]... # The values are always the same so we just take the first one. return num_experts[0] - return num_experts + # Coerce to 0 if explicitly set to None + return num_experts or 0 def get_layers_start_end_indices( self, parallel_config: ParallelConfig diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 61bcd15e06a84..9a6326d62e82e 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -210,6 +210,18 @@ class ParallelConfig: class is dynamically inherited by the worker class. 
This is used to inject new attributes and methods to the worker class for use in collective_rpc calls.""" + master_addr: str = "127.0.0.1" + """distributed master address for multi-node distributed + inference when distributed_executor_backend is mp.""" + master_port: int = 29501 + """distributed master port for multi-node distributed + inference when distributed_executor_backend is mp.""" + node_rank: int = 0 + """distributed node rank for multi-node distributed + inference when distributed_executor_backend is mp.""" + nnodes: int = 1 + """num of nodes for multi-node distributed + inference when distributed_executor_backend is mp.""" world_size: int = Field(init=False) """world_size is TPxPP, it affects the number of workers we create.""" @@ -387,6 +399,23 @@ class ParallelConfig: and self.data_parallel_size > 1 ) + @property + def node_rank_within_dp(self) -> int: + return self.node_rank % self.nnodes_within_dp + + @property + def nnodes_within_dp(self) -> int: + if self.nnodes == 1: + return 1 + data_parallel_node_size = ( + self.data_parallel_size // self.data_parallel_size_local + ) + return self.nnodes // data_parallel_node_size + + @property + def local_world_size(self) -> int: + return self.world_size // self.nnodes_within_dp + @staticmethod def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool: tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu") @@ -528,6 +557,8 @@ class ParallelConfig: ray_found = ray_utils.ray_is_available() if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD: backend = "uni" + elif current_platform.is_cuda() and self.nnodes > 1: + backend = "mp" elif ( current_platform.is_cuda() and cuda_device_count_stateless() < self.world_size @@ -565,6 +596,10 @@ class ParallelConfig: "max_parallel_loading_workers is currently " "not supported and will be ignored." 
) + if self.distributed_executor_backend != "mp" and self.nnodes > 1: + raise ValueError( + "nnodes > 1 can only be set when distributed exectuor backend is mp." + ) @property def use_ray(self) -> bool: @@ -607,6 +642,11 @@ class ParallelConfig: "Disabled the custom all-reduce kernel because it is not " "supported on current platform." ) + if self.nnodes > 1: + self.disable_custom_all_reduce = True + logger.debug( + "Disabled the custom all-reduce since we are running on multi-node." + ) if self.ray_workers_use_nsight and not self.use_ray: raise ValueError( "Unable to use nsight profiling unless workers run with Ray." diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 47aa343527b39..8194295ffedb6 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -4,19 +4,14 @@ import hashlib from collections.abc import Callable from dataclasses import InitVar -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast -from pydantic import Field, field_validator, model_validator +from pydantic import Field, field_validator from pydantic.dataclasses import dataclass -from typing_extensions import Self +from typing_extensions import Self, deprecated from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils import ( - DEFAULT_MAX_NUM_BATCHED_TOKENS, - MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, - POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, -) from vllm.utils.import_utils import resolve_obj_by_qualname if TYPE_CHECKING: @@ -33,25 +28,25 @@ SchedulerPolicy = Literal["fcfs", "priority"] class SchedulerConfig: """Scheduler configuration.""" + DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048 + DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128 + runner_type: RunnerType = "generate" """The runner type to launch for the model.""" - max_num_batched_tokens: int = Field(default=None, ge=1) + max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1) 
"""Maximum number of tokens to be processed in a single iteration. - This config has no static default. If left unspecified by the user, it will - be set in `EngineArgs.create_engine_config` based on the usage context.""" + The default value here is mainly for convenience when testing. + In real usage, this should be set in `EngineArgs.create_engine_config`. + """ - max_num_seqs: int = Field(default=None, ge=1) + max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1) """Maximum number of sequences to be processed in a single iteration. - This config has no static default. If left unspecified by the user, it will - be set in `EngineArgs.create_engine_config` based on the usage context.""" - - max_model_len: int = Field(default=None, ge=1) - """Maximum length of a sequence (including prompt and generated text). This - is primarily set in `ModelConfig` and that value should be manually - duplicated here.""" + The default value here is mainly for convenience when testing. + In real usage, this should be set in `EngineArgs.create_engine_config`. + """ max_num_partial_prefills: int = Field(default=1, ge=1) """For chunked prefill, the maximum number of sequences that can be @@ -76,13 +71,23 @@ class SchedulerConfig: NOTE: This will be replaced by speculative config in the future; it is present to enable correctness tests until then.""" - enable_chunked_prefill: bool = Field(default=None) + enable_chunked_prefill: bool = True """If True, prefill requests can be chunked based - on the remaining max_num_batched_tokens.""" + on the remaining `max_num_batched_tokens`. + + The default value here is mainly for convenience when testing. + In real usage, this should be set in `EngineArgs.create_engine_config`. + """ is_multimodal_model: bool = False """True if the model is multimodal.""" + max_model_len: InitVar[int] = 8192 + """Maximum length of a sequence (including prompt and generated text). 
+ + Note: This is stored in the ModelConfig, and is used only here to + provide fallbacks and validate other attributes.""" + is_encoder_decoder: InitVar[bool] = False """True if the model is an encoder-decoder model. @@ -111,9 +116,6 @@ class SchedulerConfig: - "priority" means requests are handled based on given priority (lower value means earlier handling) and time of arrival deciding any ties).""" - chunked_prefill_enabled: bool = Field(init=False) - """True if chunked prefill is enabled.""" - disable_chunked_mm_input: bool = False """If set to true and chunked prefill is enabled, we do not want to partially schedule a multimodal item. Only used in V1 @@ -142,6 +144,12 @@ class SchedulerConfig: speculative decoding and pipeline parallelism. """ + stream_interval: int = Field(default=1, ge=1) + """The interval (or buffer size) for streaming in terms of token length. + A smaller value (1) makes streaming smoother by sending each token immediately, + while a larger value (e.g., 10) reduces host overhead and may increase throughput + by batching multiple tokens before sending.""" + def get_scheduler_cls(self) -> type["SchedulerInterface"]: if self.scheduler_cls is None: if self.async_scheduling: @@ -182,15 +190,7 @@ class SchedulerConfig: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - @field_validator( - "max_num_batched_tokens", - "max_num_seqs", - "max_model_len", - "enable_chunked_prefill", - "scheduler_cls", - "async_scheduling", - mode="wrap", - ) + @field_validator("scheduler_cls", "async_scheduling", mode="wrap") @classmethod def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: """Skip validation if the value is `None` when initialisation is delayed.""" @@ -198,17 +198,10 @@ class SchedulerConfig: return value return handler(value) - def __post_init__(self, is_encoder_decoder: bool) -> None: - if self.max_model_len is None: - self.max_model_len = 8192 - - if self.max_num_seqs is None: - 
self.max_num_seqs = 128 - + def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None: if is_encoder_decoder: # Chunked prefill should be disabled for encoder-decoder models. self.disable_chunked_mm_input = True - self.chunked_prefill_enabled = False self.enable_chunked_prefill = False self.long_prefill_token_threshold = 0 logger.info( @@ -216,37 +209,6 @@ class SchedulerConfig: " prefix caching; disabling both." ) - if self.max_num_batched_tokens is None: - if self.enable_chunked_prefill: - self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS - else: - # If max_model_len is too short, use - # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value - # for higher throughput. - self.max_num_batched_tokens = max( - self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS - ) - - if self.runner_type == "pooling": - # Choose specific value for higher throughput - self.max_num_batched_tokens = max( - self.max_num_batched_tokens, - POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, - ) - if self.is_multimodal_model: - # The value needs to be at least the number of multimodal tokens - self.max_num_batched_tokens = max( - self.max_num_batched_tokens, - MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, - ) - - # When using default settings, - # Ensure max_num_batched_tokens does not exceed model limit. - # Some models (e.g., Whisper) have embeddings tied to max length. 
- self.max_num_batched_tokens = min( - self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens - ) - self.max_num_encoder_input_tokens = self.max_num_batched_tokens self.encoder_cache_size = self.max_num_batched_tokens @@ -256,10 +218,9 @@ class SchedulerConfig: self.max_num_batched_tokens, ) - self.chunked_prefill_enabled = self.enable_chunked_prefill if self.max_num_partial_prefills > 1: if self.long_prefill_token_threshold == 0: - self.long_prefill_token_threshold = int(self.max_model_len * 0.04) + self.long_prefill_token_threshold = int(max_model_len * 0.04) logger.info( "Concurrent partial prefills enabled with " @@ -270,15 +231,29 @@ class SchedulerConfig: self.long_prefill_token_threshold, ) - @model_validator(mode="after") - def _verify_args(self) -> Self: + self.verify_max_model_len(max_model_len) + + @property + @deprecated( + "`SchedulerConfig.chunked_prefill_enabled` has been renamed to " + "`SchedulerConfig.enable_chunked_prefill`. " + "The old name will be removed in v0.12." + ) + def chunked_prefill_enabled(self) -> bool: + return self.enable_chunked_prefill + + @chunked_prefill_enabled.setter + def chunked_prefill_enabled(self, value: bool): + self.enable_chunked_prefill = value + + def verify_max_model_len(self, max_model_len: int) -> Self: if ( - self.max_num_batched_tokens < self.max_model_len - and not self.chunked_prefill_enabled + self.max_num_batched_tokens < max_model_len + and not self.enable_chunked_prefill ): raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " - f"smaller than max_model_len ({self.max_model_len}). " + f"smaller than max_model_len ({max_model_len}). " "This effectively limits the maximum sequence length to " "max_num_batched_tokens and makes vLLM reject longer " "sequences. Please increase max_num_batched_tokens or " @@ -292,26 +267,26 @@ class SchedulerConfig: f"({self.max_num_seqs})." 
) - if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: + if self.max_num_batched_tokens > self.max_num_seqs * max_model_len: logger.warning( "max_num_batched_tokens (%d) exceeds max_num_seqs " "* max_model_len (%d). This may lead to unexpected behavior.", self.max_num_batched_tokens, - self.max_num_seqs * self.max_model_len, + self.max_num_seqs * max_model_len, ) if self.max_num_partial_prefills > 1: - if not self.chunked_prefill_enabled: + if not self.enable_chunked_prefill: raise ValueError( "Chunked prefill must be enabled to set " "max_num_partial_prefills > 1." ) - if self.long_prefill_token_threshold > self.max_model_len: + if self.long_prefill_token_threshold > max_model_len: raise ValueError( "long_prefill_token_threshold " f"({self.long_prefill_token_threshold}) cannot be greater " - f"than the max_model_len ({self.max_model_len})." + f"than the max_model_len ({max_model_len})." ) if self.max_long_partial_prefills > self.max_num_partial_prefills: diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 31cdeabe501d2..13a8632413d91 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -3,7 +3,7 @@ import ast import hashlib -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, get_args from pydantic import Field, SkipValidation, model_validator from pydantic.dataclasses import dataclass @@ -29,31 +29,25 @@ else: logger = init_logger(__name__) -SpeculativeMethod = Literal[ - "ngram", - "eagle", - "eagle3", - "medusa", - "mlp_speculator", - "draft_model", - "deepseek_mtp", - "ernie_mtp", - "qwen3_next_mtp", - "mimo_mtp", - "longcat_flash_mtp", - "pangu_ultra_moe_mtp", - "mtp", - "suffix", -] -MTP_MODEL_TYPES = ( +MTPModelTypes = Literal[ "deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp", "qwen3_next_mtp", "longcat_flash_mtp", + "mtp", "pangu_ultra_moe_mtp", -) +] +EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] +SpeculativeMethod = 
Literal[ + "ngram", + "medusa", + "mlp_speculator", + "draft_model", + "suffix", + EagleModelTypes, +] @config @@ -244,7 +238,7 @@ class SpeculativeConfig: # can not be detected, it will be considered as the "draft_model" by # default. - if self.method in MTP_MODEL_TYPES: + if self.method in get_args(MTPModelTypes) and self.method != "mtp": logger.warning( "method `%s` is deprecated and replaced with mtp.", self.method ) @@ -361,7 +355,9 @@ class SpeculativeConfig: self.method = "medusa" elif self.draft_model_config.hf_config.model_type == "mlp_speculator": self.method = "mlp_speculator" - elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES: + elif self.draft_model_config.hf_config.model_type in get_args( + MTPModelTypes + ): self.method = "mtp" if self.num_speculative_tokens > 1: logger.warning( diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index f581267f73f7d..672b004c4aa56 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -14,13 +14,14 @@ from dataclasses import replace from datetime import datetime from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, get_args import torch from pydantic import ConfigDict, Field, model_validator from pydantic.dataclasses import dataclass import vllm.envs as envs +from vllm.config.speculative import EagleModelTypes from vllm.logger import enable_trace_function_call, init_logger from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.utils import random_uuid @@ -374,10 +375,22 @@ class VllmConfig: "Async scheduling is not yet compatible with " "pipeline_parallel_size > 1." ) + # Currently, async scheduling only support eagle speculative + # decoding. if self.speculative_config is not None: - raise ValueError( - "Async scheduling is not yet compatible with speculative decoding." 
- ) + if self.speculative_config.method not in get_args(EagleModelTypes): + raise ValueError( + "Currently, async scheduling is only supported " + "with EAGLE/MTP kind of speculative decoding" + ) + if self.speculative_config.disable_padded_drafter_batch: + raise ValueError( + "async scheduling for EAGLE/MTP kind of speculative " + "decoding is enabled, but disable_padded_drafter_batch=True " + "disable_padded_drafter_batch=True is not supported for " + "this situation now. please set " + "disable_padded_drafter_batch=Fasle" + ) if not executor_supports_async_sched: raise ValueError( "Currently, async scheduling only supports `mp`, `uni`, or " @@ -411,7 +424,7 @@ class VllmConfig: if ( self.model_config is not None - and self.scheduler_config.chunked_prefill_enabled + and self.scheduler_config.enable_chunked_prefill and self.model_config.dtype == torch.float32 and current_platform.get_device_capability() == (7, 5) ): @@ -445,8 +458,6 @@ class VllmConfig: # and requires it to be enabled. if self.compilation_config.pass_config.enable_async_tp: self.compilation_config.pass_config.enable_sequence_parallelism = True - if self.compilation_config.pass_config.enable_sequence_parallelism: - self.compilation_config.custom_ops.append("+rms_norm") if current_platform.support_static_graph_mode(): # if cudagraph_mode is not explicitly set by users, set default @@ -483,21 +494,6 @@ class VllmConfig: "Overriding cudagraph_mode to PIECEWISE." ) self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - elif ( - current_platform.is_cuda() - and current_platform.is_device_capability(100) - and self.model_config.max_model_len > 131072 - and not self.model_config.use_mla - ): - # Refer to vllm/utils/flashinfer.py::use_trtllm_attention() - logger.warning_once( - "NVIDIA Blackwell TRTLLM attention cannot support " - "max_model_len >= 131072 (found " - f"{self.model_config.max_model_len}), causing dynamic " - "dispatching that breaks full cudagraphs. 
" - "Overriding cudagraph_mode to PIECEWISE." - ) - self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE # disable cudagraph when enforce eager execution if self.model_config is not None and self.model_config.enforce_eager: @@ -584,7 +580,7 @@ class VllmConfig: ): for reason in disable_chunked_prefill_reasons: logger.info(reason) - self.scheduler_config.chunked_prefill_enabled = False + self.scheduler_config.enable_chunked_prefill = False self.scheduler_config.long_prefill_token_threshold = 0 if self.cache_config is not None: @@ -635,6 +631,32 @@ class VllmConfig: if self.compilation_config.mode == CompilationMode.VLLM_COMPILE: self.compilation_config.set_splitting_ops_for_v1() + if self.compilation_config.pass_config.enable_sequence_parallelism: + # With pipeline parallelism or dynamo partitioning, + # native rms norm tracing errors due to incorrect residual shape. + # Use custom rms norm to unblock. In the future, + # the pass will operate on higher-level IR to avoid the issue. 
+ # TODO: https://github.com/vllm-project/vllm/issues/27894 + is_fullgraph = ( + self.compilation_config.use_inductor_graph_partition + or len(self.compilation_config.splitting_ops) == 0 + ) + if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph: + if "-rms_norm" not in self.compilation_config.custom_ops: + self.compilation_config.custom_ops.append("+rms_norm") + else: + regime = ( + "Dynamo partition" + if not is_fullgraph + else "pipeline parallelism" + ) + logger.warning_once( + "Sequence parallelism not supported with" + "native rms_norm when using %s, " + "this will likely lead to an error.", + regime, + ) + # final check of cudagraph mode after all possible updates if current_platform.is_cuda_alike(): if ( @@ -929,7 +951,6 @@ class VllmConfig: model_config = self.model_config max_model_len = model_config.get_and_verify_max_len(max_model_len) self.model_config.max_model_len = max_model_len - self.scheduler_config.max_model_len = max_model_len def try_verify_and_update_config(self): if self.model_config is None: @@ -1026,7 +1047,7 @@ class VllmConfig: f"seed={self.model_config.seed}, " f"served_model_name={self.model_config.served_model_name}, " f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " - f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa + f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa f"pooler_config={self.model_config.pooler_config!r}, " f"compilation_config={self.compilation_config!r}" ) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 5046cac2e90a7..052df19e34d72 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from multiprocessing import shared_memory from pickle import PickleBuffer from threading import Event -from typing import 
TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from unittest.mock import patch import torch @@ -602,13 +602,87 @@ class MessageQueue: return obj return self.dequeue() + @staticmethod + def create_from_process_group_single_reader( + pg: ProcessGroup, + max_chunk_bytes, + max_chunks, + reader_rank: int = 0, + blocking: bool = False, + ) -> tuple["MessageQueue", list[Handle]]: + """ + Creates a MessageQueue for a process group with a single reader. + + This method is designed for scenarios where only one process (the reader) + will consume messages, and all other processes are writers. It sets up + the shared memory buffer and communication handles accordingly, and + gathers the handles from all processes to the reader. + + Args: + pg (ProcessGroup): The torch distributed process group. + max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer. + max_chunks (int): Maximum number of chunks in the buffer. + reader_rank (int, optional): The global rank that will act as the reader. + Defaults to 0. + blocking (bool, optional): If True, blocks until all processes are ready. + Defaults to False. + + Returns: + tuple[MessageQueue, list[Handle]]: + The MessageQueue instance for the calling process, + and a list of handles (only non-empty for the reader process). 
+ """ + local_size = torch.cuda.device_count() + rank = dist.get_rank() + same_node = rank // local_size == reader_rank // local_size + buffer_io = MessageQueue( + n_reader=1, + n_local_reader=1 if same_node else 0, + max_chunk_bytes=max_chunk_bytes, + max_chunks=max_chunks, + ) + handle = buffer_io.export_handle() + handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None + dist.gather_object(handle, handles, dst=reader_rank, group=pg) + if blocking: + buffer_io.wait_until_ready() + return buffer_io, cast(list[Handle], handles or []) + @staticmethod def create_from_process_group( pg: ProcessGroup | StatelessProcessGroup, max_chunk_bytes, max_chunks, - writer_rank=0, + writer_rank: int = 0, + external_writer_handle=None, + blocking: bool = True, ) -> "MessageQueue": + """ + Creates a MessageQueue for a distributed process group with one writer and + multiple readers. + + This method is designed for scenarios where one process (the writer) sends + messages, and all other processes (the readers) receive messages. It sets up + the shared memory buffer and socket communication handles accordingly, and + broadcasts the handle from the writer to all readers. + + Args: + pg (ProcessGroup | StatelessProcessGroup): The torch distributed process + group. + max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer. + max_chunks (int): Maximum number of chunks in the buffer. + writer_rank (int, optional): The global rank that will act as the writer. + Defaults to 0. + external_writer_handle (Handle, optional): Used when there is a handle + from an external Message Queue. If provided, use this handle to init + PG writer message queue instead of creating a new one. Defaults to None. + blocking (bool, optional): If True, blocks until all processes are ready. + Defaults to True. + + Returns: + MessageQueue: The MessageQueue instance for the calling process. 
+ + """ if isinstance(pg, ProcessGroup): group_rank = dist.get_rank(pg) group_world_size = dist.get_world_size(pg) @@ -617,23 +691,26 @@ class MessageQueue: group_rank = pg.rank group_world_size = pg.world_size global_ranks = list(range(pg.world_size)) - from vllm.distributed.parallel_state import in_the_same_node_as status = in_the_same_node_as(pg, source_rank=writer_rank) - same_node_ranks = [i for i, s in enumerate(status) if s] - n_reader = group_world_size - 1 - n_local_reader = len(same_node_ranks) - 1 - local_reader_ranks = [i for i in same_node_ranks if i != writer_rank] - buffer_io: MessageQueue if group_rank == writer_rank: - buffer_io = MessageQueue( - n_reader=n_reader, - n_local_reader=n_local_reader, - local_reader_ranks=local_reader_ranks, - max_chunk_bytes=max_chunk_bytes, - max_chunks=max_chunks, - ) + if external_writer_handle is not None: + buffer_io = MessageQueue.create_from_handle( + external_writer_handle, group_rank + ) + else: + same_node_ranks = [i for i, s in enumerate(status) if s] + n_reader = group_world_size - 1 + n_local_reader = len(same_node_ranks) - 1 + local_reader_ranks = [i for i in same_node_ranks if i != writer_rank] + buffer_io = MessageQueue( + n_reader=n_reader, + n_local_reader=n_local_reader, + local_reader_ranks=local_reader_ranks, + max_chunk_bytes=max_chunk_bytes, + max_chunks=max_chunks, + ) handle = buffer_io.export_handle() if isinstance(pg, ProcessGroup): dist.broadcast_object_list( @@ -651,5 +728,6 @@ class MessageQueue: else: handle = pg.broadcast_obj(None, writer_rank) buffer_io = MessageQueue.create_from_handle(handle, group_rank) - buffer_io.wait_until_ready() + if blocking: + buffer_io.wait_until_ready() return buffer_io diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py index 2ec33afb87839..4af2caa16b0d6 100644 --- a/vllm/distributed/device_communicators/shm_object_storage.py +++ 
b/vllm/distributed/device_communicators/shm_object_storage.py @@ -342,8 +342,8 @@ class MsgpackSerde(ObjectSerde): from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder self.encoder = MsgpackEncoder() - self.tensor_decoder = MsgpackDecoder(torch.Tensor) - self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem) + self.tensor_decoder = MsgpackDecoder(torch.Tensor, share_mem=False) + self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem, share_mem=False) self._mm_kwargs_item_cls = MultiModalKwargsItem def serialize(self, value: Any) -> tuple[bytes | list[bytes], int, bytes, int]: @@ -368,7 +368,7 @@ class MsgpackSerde(ObjectSerde): # pickle.loads do not read past the end of a pickled object # within a large buffer, so we can skip storing the metadata size type_name, nbytes, len_arr = pickle.loads(data_view) - serialized_data = bytearray(data_view[-nbytes:]) + serialized_data = data_view[-nbytes:] if type_name == torch.Tensor.__name__: obj = [] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 42433c717cf26..a70c98b637131 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -49,6 +49,7 @@ from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -107,11 +108,14 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata): block_lens: list[int] attn_backend_name: str kv_cache_layout: str + block_size: int @dataclass class ReqMeta: local_block_ids: list[int] + # To be used when logical block size does not match the kernel block size + local_physical_block_ids: list[int] 
remote_block_ids: list[int] remote_host: str remote_port: int @@ -139,6 +143,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): assert load_remote_cache ^ save_to_host _req = ReqMeta( local_block_ids=local_block_ids, + local_physical_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], remote_host=kv_transfer_params["remote_host"], @@ -705,6 +710,9 @@ class NixlConnectorWorker: self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first ) + block_size: int + remote_block_size: dict[EngineId, int] + def tp_ratio( self, remote_tp_size: int, @@ -721,6 +729,19 @@ class NixlConnectorWorker: ) return self.tp_size // remote_tp_size + def block_size_ratio( + self, + remote_block_size: int, + ) -> float: + """ + Calculate the block size ratio between local and remote TP. + """ + assert self.block_size % remote_block_size == 0, ( + f"Local block size {self.block_size} is not divisible " + f"by remote block size {remote_block_size} or vice versa." + ) + return self.block_size // remote_block_size + def tp_ratio_from_engine_id( self, remote_engine_id: EngineId, @@ -728,6 +749,13 @@ class NixlConnectorWorker: remote_tp_size = self.remote_tp_size[remote_engine_id] return self.tp_ratio(remote_tp_size) + def block_size_ratio_from_engine_id( + self, + remote_engine_id: EngineId, + ) -> float: + remote_block_size = self.remote_block_size[remote_engine_id] + return self.block_size_ratio(remote_block_size) + def is_kv_replicated(self, engine_id: EngineId) -> bool: """ Whether the KV cache is replicated across TP workers due to the @@ -862,6 +890,7 @@ class NixlConnectorWorker: # nixl_prepped_dlist_handle. self.src_xfer_side_handle: int = 0 + self.src_xfer_side_handles: dict[int, int] = {} # Map of engine_id -> nixl_prepped_dlist_handle (int)]. 
self.dst_xfer_side_handles: dict[EngineId, int] = {} @@ -921,6 +950,7 @@ class NixlConnectorWorker: logger.debug("Detected kv cache layout %s", self.kv_cache_layout) self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} + self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} # With heterogeneous TP, P must wait for all assigned D TP workers to # finish reading before safely freeing the blocks. self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) @@ -932,9 +962,12 @@ class NixlConnectorWorker: remote_tp_size=self._tp_size, # shared state is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + block_size=self.block_size, + remote_block_size=self._block_size, attn_backend=backend, ) self._use_pallas = self.kv_topo._use_pallas + self._physical_blocks_per_logical_kv_block = 1 def _nixl_handshake( self, @@ -982,9 +1015,13 @@ class NixlConnectorWorker: ) # Register Remote agent. + assert metadata.block_size <= self.block_size, ( + "nP > nD is not supported yet." + ) remote_agent_name = self.add_remote_agent( metadata, p_remote_rank, remote_tp_size ) + setup_agent_time = time.perf_counter() logger.debug( "NIXL handshake: add agent took: %s", @@ -1133,6 +1170,22 @@ class NixlConnectorWorker: if base_addr in seen_base_addresses: continue + # TODO (NickLucche): Get kernel_block_size in a cleaner way + # NHD default "view" for non-MLA cache + kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3] + + if self.block_size != kernel_block_size: + logger.info_once( + "User-specified logical block size (%s) does not match" + " physical kernel block size (%s). Using the latter. 
", + self.block_size, + kernel_block_size, + ) + self._physical_blocks_per_logical_kv_block = ( + self.block_size // kernel_block_size + ) + self.block_size = kernel_block_size + seen_base_addresses.append(base_addr) curr_tensor_size_bytes = cache.numel() * cache.element_size() @@ -1196,43 +1249,10 @@ class NixlConnectorWorker: self.num_regions *= 2 # Register local/src descr for NIXL xfer. - blocks_data = [] - for i, base_addr in enumerate(seen_base_addresses): - kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) - # NOTE With heter-TP, more blocks are prepared than what are - # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We - # could create fewer, but then _get_block_descs_ids needs to - # select agent_meta.num_blocks instead of self.num_blocks for - # local descr, and that makes handling regular flow less clean. - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len_per_layer[i] - addr = base_addr + block_offset - # (addr, len, device id) - blocks_data.append((addr, kv_block_len, self.device_id)) + self.seen_base_addresses = seen_base_addresses + self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size) - if self.kv_topo.is_kv_layout_blocks_first: - # Separate and interleave K/V regions to maintain the same - # descs ordering. This is needed for selecting contiguous heads - # when split across TP ranks. - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len_per_layer[i] - addr = base_addr + block_offset - # Register addresses for V cache (K registered first). - v_addr = addr + kv_block_len - blocks_data.append((v_addr, kv_block_len, self.device_id)) - logger.debug( - "Created %s blocks for src engine %s and rank %s on device id %s", - len(blocks_data), - self.engine_id, - self.tp_rank, - self.device_id, - ) - - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) - # NIXL_INIT_AGENT to be used for preparations of local descs. 
- self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs - ) + self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle # TODO(mgoin): Hybrid memory allocator is currently disabled for # models with local attention (Llama 4). Can remove this once enabled. @@ -1268,8 +1288,62 @@ class NixlConnectorWorker: kv_cache_layout=self.kv_cache_layout if not self.use_host_buffer else self.host_buffer_kv_cache_layout, + block_size=self.block_size, ) + def register_local_xfer_handler( + self, + block_size: int, + ) -> int: + """ + Function used for register local xfer handler with local block_size or + Remote block_size. + + When local block_size is same as remote block_size, we use local block_size + to register local_xfer_handler during init. + + When remote block size is less than local block size, we need to use + register another local_xfer_handler using remote block len to ensure + data copy correctness. + """ + block_size_ratio = self.block_size // block_size + blocks_data = [] + for i, base_addr in enumerate(self.seen_base_addresses): + # The new block_len is using prefill block_len; + # and num_blocks is multiple with N + kv_block_len = ( + self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio + ) + block_len_per_layer = self.block_len_per_layer[i] // block_size_ratio + num_blocks = self.num_blocks * block_size_ratio + for block_id in range(num_blocks): + block_offset = block_id * block_len_per_layer + addr = base_addr + block_offset + # (addr, len, device id) + blocks_data.append((addr, kv_block_len, self.device_id)) + + if self.kv_topo.is_kv_layout_blocks_first: + # Separate and interleave K/V regions to maintain the same + # descs ordering. This is needed for selecting contiguous heads + # when split across TP ranks. + for block_id in range(num_blocks): + block_offset = block_id * block_len_per_layer + addr = base_addr + block_offset + # Register addresses for V cache (K registered first). 
+ v_addr = addr + kv_block_len + blocks_data.append((v_addr, kv_block_len, self.device_id)) + logger.debug( + "Created %s blocks for src engine %s and rank %s on device id %s", + len(blocks_data), + self.engine_id, + self.tp_rank, + self.device_id, + ) + + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) + # NIXL_INIT_AGENT to be used for preparations of local descs. + return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + def add_remote_agent( self, nixl_agent_meta: NixlAgentMetadata, @@ -1328,6 +1402,8 @@ class NixlConnectorWorker: ### Register remote agent metadata if engine_id not in self._tp_size: self._tp_size[engine_id] = remote_tp_size + if engine_id not in self._block_size: + self._block_size[engine_id] = nixl_agent_meta.block_size remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata @@ -1338,6 +1414,13 @@ class NixlConnectorWorker: # Create dst descs and xfer side handles. TP workers have same #blocks # so we only register once per engine_id. + # Example: + # block_size_ratio > 1: + # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| + # local origin:| 0| 1| 8| 12| + # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id) + if engine_id not in self.dst_num_blocks: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks @@ -1360,8 +1443,14 @@ class NixlConnectorWorker: # Register all remote blocks, but only the corresponding kv heads. 
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) + remote_kv_block_len = kv_block_len // block_size_ratio + if block_size_ratio > 1: + # using remote kv_block_len as transfer unit + kv_block_len = remote_kv_block_len rank_offset = ( - self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0 + self.tp_rank % tp_ratio * remote_kv_block_len + if not replicates_kv_cache + else 0 ) for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] @@ -1396,6 +1485,13 @@ class NixlConnectorWorker: remote_agent_name, descs ) + if block_size_ratio > 1: + # when prefill with smaller block_size, we need to init a + # new handler with same block_len to match + self.src_xfer_side_handles[nixl_agent_meta.block_size] = ( + self.register_local_xfer_handler(nixl_agent_meta.block_size) + ) + return remote_agent_name def _validate_remote_agent_handshake( @@ -1412,6 +1508,9 @@ class NixlConnectorWorker: assert nixl_agent_meta.attn_backend_name == self.backend_name tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( + remote_engine_id + ) assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" assert not self._use_pallas or tp_ratio == 1, ( "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." @@ -1442,33 +1541,26 @@ class NixlConnectorWorker: remote_block_len = nixl_agent_meta.block_lens[0] if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id): # With replicated KV cache, only the number of blocks can differ. 
- assert self.block_len_per_layer == nixl_agent_meta.block_lens, ( - "KV cache sizes must match between P and D when replicated" - ) - remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) + for i in range(len(self.block_len_per_layer)): + assert ( + self.block_len_per_layer[i] // block_size_ratio + == nixl_agent_meta.block_lens[i] + ), "KV cache sizes must match between P and D when replicated" else: # When MLA is not used, this is a list of the same block length for block_len in nixl_agent_meta.block_lens: assert block_len == remote_block_len, ( "All remote layers must have the same block size" ) - remote_block_size = remote_block_len // ( - self.slot_size_per_layer[0] * tp_ratio - ) - if self.kv_topo.is_kv_layout_blocks_first: - # With flashinfer, KV are sent in the same message. - remote_block_size //= 2 - assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( + assert ( + remote_block_len + == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio + ), ( "Remote P worker KV layer cache must be of shape [2, N, " "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." ) - assert self.block_size == remote_block_size, ( - "Remote P worker with different page/block size is not supported " - f"{self.block_size=}, {remote_block_size=}" - ) - # TP workers have same #blocks. assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks @@ -1479,7 +1571,7 @@ class NixlConnectorWorker: assert self.use_host_buffer assert self.copy_blocks is not None - local_block_ids = meta.local_block_ids + local_block_ids = meta.local_physical_block_ids self.copy_blocks( self.host_xfer_buffers, self.device_kv_caches, @@ -1492,7 +1584,7 @@ class NixlConnectorWorker: "synced recved kv of request[%s] to device kv buffer," "local_block_ids: %s. 
", req_id, - ",".join(map(str, meta.local_block_ids)), + ",".join(map(str, local_block_ids)), ) def save_kv_to_host(self, metadata: NixlConnectorMetadata): @@ -1501,19 +1593,22 @@ class NixlConnectorWorker: assert self.copy_blocks is not None for req_id, meta in metadata.reqs_to_save.items(): + meta.local_physical_block_ids = self._logical_to_kernel_block_ids( + meta.local_block_ids + ) if logger.isEnabledFor(logging.DEBUG): logger.debug( "save_load_kv for request[%s] to host xfer buffer." "local_block_ids: %s. ", req_id, - ",".join(map(str, meta.local_block_ids)), + ",".join(map(str, meta.local_physical_block_ids)), ) # blocking self.copy_blocks( self.device_kv_caches, self.host_xfer_buffers, - meta.local_block_ids, - meta.local_block_ids, + meta.local_physical_block_ids, + meta.local_physical_block_ids, "d2h", ) @@ -1552,6 +1647,56 @@ class NixlConnectorWorker: ) cache.index_copy_(0, indices, permuted_blocks) + def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]): + def _process_local_gt_remote(blocks_to_update, block_size_ratio): + n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] + remote_block_size = block_size // block_size_ratio + n_blocks = block_size_ratio + # actual permute is to convert + # for local blocksize > remote blocksize + # ex: local blocksize = 16 tokens, remote blocksize = 4 tokens + # local block[0] = remote block[0, 1, 2, 3] + # remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|... + # local is |h0-b0..................|h1-b0..................|... + # permute is to: + # 1. view => view remote as n_blocks * remote_shape(H,remoteN,D) + # 2. permute => (H, nblocks, remoteN, D) + # 3. 
flatten => (H, localN, D) + permuted_blocks = ( + blocks_to_update.reshape( + -1, n_blocks, n_kv_heads, remote_block_size, head_size + ) + .permute(0, 2, 1, 3, 4) + .flatten(2, 3) + ) + return permuted_blocks + + if len(self.device_kv_caches) == 0: + return + split_k_and_v = not ( + self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first + ) + sample_cache = list(self.device_kv_caches.values())[0][0] + for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): + assert block_size_ratio > 1, "Only nP < nD supported currently." + block_ids_list = [[item for sublist in block_ids_list for item in sublist]] + + for block_ids in block_ids_list: + indices = torch.tensor(block_ids, device=sample_cache.device) + + for _, cache_or_caches in self.device_kv_caches.items(): + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] + for cache in cache_list: + blocks_to_update = cache.index_select(0, indices) + # because kv_cache is always using original layout NHD as + # virtual shape while stride can be either HND / NHD at + # initialization. + # we need to firstly get physical view of the tensor + permuted_blocks = _process_local_gt_remote( + blocks_to_update.permute(0, 2, 1, 3), block_size_ratio + ).permute(0, 2, 1, 3) + cache.index_copy_(0, indices, permuted_blocks) + def get_finished(self) -> tuple[set[str], set[str]]: """ Get requests that are done sending or recving on this specific worker. 
@@ -1575,6 +1720,7 @@ class NixlConnectorWorker: ) block_ids_to_permute = [] + block_ids_for_blocksize_post_process = defaultdict(list) for req_id in done_recving: # clean up metadata for completed requests meta = self._recving_metadata.pop(req_id, None) @@ -1582,7 +1728,21 @@ class NixlConnectorWorker: if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) if self.enable_permute_local_kv: - block_ids_to_permute += meta.local_block_ids + block_ids_to_permute += meta.local_physical_block_ids + + # post processing for heteroblocksize + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( + meta.remote_engine_id + ) + if ( + not self.use_mla + and block_size_ratio > 1 + and self.kv_cache_layout == "HND" + ): + block_ids_for_blocksize_post_process[block_size_ratio].append( + meta.local_block_ids + ) + self.blocksize_post_process(block_ids_for_blocksize_post_process) if len(block_ids_to_permute) > 0: self.permute_device_kv(block_ids_to_permute) @@ -1669,7 +1829,7 @@ class NixlConnectorWorker: req_id, xfer_state, ) - # mark all blocks for this request as invalid + # mark all (logical)blocks for this request as invalid if meta := self._recving_metadata.pop(req_id, None): self._invalid_block_ids.update(meta.local_block_ids) self._recving_metadata.pop(req_id, None) @@ -1686,13 +1846,19 @@ class NixlConnectorWorker: We check for these trnxs to complete in each step(). """ for req_id, meta in metadata.reqs_to_recv.items(): + meta.local_physical_block_ids = self._logical_to_kernel_block_ids( + meta.local_block_ids + ) + meta.remote_block_ids = self._logical_to_kernel_block_ids( + meta.remote_block_ids + ) remote_engine_id = meta.remote_engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. 
", req_id, remote_engine_id, - len(meta.local_block_ids), + len(meta.local_physical_block_ids), len(meta.remote_block_ids), ) # always store metadata for failure recovery @@ -1740,7 +1906,7 @@ class NixlConnectorWorker: self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, - local_block_ids=meta.local_block_ids, + local_block_ids=meta.local_physical_block_ids, remote_block_ids=meta.remote_block_ids, ) @@ -1751,6 +1917,24 @@ class NixlConnectorWorker: dst_engine_id: str, request_id: str, ): + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) + if block_size_ratio > 1: + local_block_ids = self.get_mapped_blocks( + np.asarray(local_block_ids), block_size_ratio + ) + if len(local_block_ids) > len(remote_block_ids): + # NOTE: + # get_mapped_blocks will always expand block_ids for n times. + # ex: + # prefill block_ids with block_size as 4: + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + # Local decode block_ids with block_size as 16: [1, 2, 3] + # expland ecode block_ids with get_mapped_blocks from [1, 2, 3] to + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + # Then we clip local to align with prefill + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + local_block_ids = local_block_ids[: len(remote_block_ids)] # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the @@ -1793,7 +1977,10 @@ class NixlConnectorWorker: remote_block_ids = remote_block_ids[-num_local_blocks:] # Get side handles. 
- local_xfer_side_handle = self.src_xfer_side_handle + remote_block_size = self.kv_topo.remote_block_size[dst_engine_id] + local_xfer_side_handle = self.src_xfer_side_handles.get( + remote_block_size, self.src_xfer_side_handle + ) remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from @@ -1803,13 +1990,17 @@ class NixlConnectorWorker: # Get descs ids. local_block_descs_ids: np.ndarray remote_block_descs_ids: np.ndarray + if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, remote_block_ids + dst_engine_id, + remote_block_ids, ) local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, local_block_ids + self.engine_id, + local_block_ids, + block_size_ratio=block_size_ratio, ) else: # TODO(mgoin): remove this once we have hybrid memory allocator @@ -1830,10 +2021,15 @@ class NixlConnectorWorker: # Get descs ids for the layer. 
layer_local_desc_ids = self._get_block_descs_ids( - self.engine_id, layer_local_block_ids, layer_idx + dst_engine_id, + layer_local_block_ids, + layer_idx, ) layer_remote_desc_ids = self._get_block_descs_ids( - dst_engine_id, layer_remote_block_ids, layer_idx + self.engine_id, + layer_remote_block_ids, + layer_idx, + block_size_ratio=block_size_ratio, ) local_descs_list.append(layer_local_desc_ids) @@ -1867,7 +2063,7 @@ class NixlConnectorWorker: "Marking blocks as invalid.", request_id, ) - # mark all blocks for this request as invalid + # mark all (logical) blocks for this request as invalid if meta := self._recving_metadata.get(request_id): self._invalid_block_ids.update(meta.local_block_ids) self.xfer_stats.record_failed_transfer() @@ -1875,8 +2071,31 @@ class NixlConnectorWorker: self.nixl_wrapper.release_xfer_handle(handle) self._failed_recv_reqs.add(request_id) + def get_mapped_blocks(self, block_ids, block_size_ratio): + """ + Calculates the new set of block IDs by mapping every element + in the (potentially sparse) input array. + Example: block_ids=[0, 2], block_size_ratio=2 + get_mapped_blocks 0 1 [2 3] 4 5 + # remote is |h0-b0|h1-b0||h0-b1|h1-b1||h0-b1|h1-b1|| + # local is |h0-b0......||h1-b0......||h2-b0........ + local_block_ids 0 [1] 2 + """ + if block_ids.size == 0: + return np.array([], dtype=np.int64) + + start_ids = block_ids * block_size_ratio + offsets = np.arange(block_size_ratio) + mapped_2d = start_ids[:, None] + offsets[None, :] + + return mapped_2d.flatten().astype(np.int64) + def _get_block_descs_ids( - self, engine_id: str, block_ids: list[int], layer_idx: int | None = None + self, + engine_id: str, + block_ids: list[int], + layer_idx: int | None = None, + block_size_ratio: float | None = None, ) -> np.ndarray: """ Get the descs ids for a set of block ids. 
@@ -1899,6 +2118,8 @@ class NixlConnectorWorker: region_ids = np.arange(layer_idx, layer_idx + 1) num_blocks = self.dst_num_blocks[engine_id] + if block_size_ratio is not None: + num_blocks = int(num_blocks * block_size_ratio) # Compute the desc ids for each block. region_ids = region_ids[:, None] @@ -1906,6 +2127,23 @@ class NixlConnectorWorker: descs_ids = region_ids * num_blocks + block_ids return descs_ids.flatten() + def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]: + """ + Convert logical block ids to kernel physical block ids. + This is required when the logical block size (the one set by the user) + does not match the one required by the attn backend. + """ + if self._physical_blocks_per_logical_kv_block == 1: + # Noop when physical and logical block sizes are the same + return block_ids + block_ids_np = np.array(block_ids) + block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape( + 1, -1 + ) + return BlockTable.map_to_kernel_blocks( + block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange + ).tolist() + def get_backend_aware_kv_block_len(self, layer_idx: int): """ Get the block length for one K/V element (K and V have the same size). 
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c78e6a32733c1..852c4c644433f 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -385,6 +385,33 @@ class GroupCoordinator: torch.ops._C, "init_shm_manager" ) + def create_mq_broadcaster( + self, writer_rank=0, external_writer_handle=None, blocking=True + ): + from vllm.distributed.device_communicators.shm_broadcast import MessageQueue + + return MessageQueue.create_from_process_group( + self.cpu_group, + 1 << 22, + 6, + writer_rank=writer_rank, + external_writer_handle=external_writer_handle, + blocking=blocking, + ) + + def create_single_reader_mq_broadcasters( + self, reader_rank_in_group=0, blocking=False + ): + from vllm.distributed.device_communicators.shm_broadcast import MessageQueue + + return MessageQueue.create_from_process_group_single_reader( + self.cpu_group, + 1 << 22, + 6, + reader_rank=self.ranks[reader_rank_in_group], + blocking=blocking, + ) + @property def first_rank(self): """Return the global rank of the first process in the group""" @@ -997,6 +1024,7 @@ class GroupCoordinator: _WORLD: GroupCoordinator | None = None +_INNER_DP_WORLD: GroupCoordinator | None = None _NODE_COUNT: int | None = None @@ -1005,6 +1033,11 @@ def get_world_group() -> GroupCoordinator: return _WORLD +def get_inner_dp_world_group() -> GroupCoordinator: + assert _INNER_DP_WORLD is not None, "inner dp world group is not initialized" + return _INNER_DP_WORLD + + def init_world_group( ranks: list[int], local_rank: int, backend: str ) -> GroupCoordinator: @@ -1023,12 +1056,13 @@ def init_model_parallel_group( backend: str, use_message_queue_broadcaster: bool = False, group_name: str | None = None, + use_device_communicator: bool = True, ) -> GroupCoordinator: return GroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, - use_device_communicator=True, + use_device_communicator=use_device_communicator, 
use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, ) @@ -1143,7 +1177,14 @@ def init_distributed_environment( from vllm.config import get_current_vllm_config config = get_current_vllm_config() - if ( + if config is not None and config.parallel_config.nnodes > 1: + parallel_config = config.parallel_config + ip = parallel_config.master_addr + rank = parallel_config.data_parallel_rank * world_size + rank + world_size = parallel_config.world_size_across_dp + port = parallel_config.master_port + distributed_init_method = get_distributed_init_method(ip, port) + elif ( config is not None and config.parallel_config.data_parallel_size > 1 and config.parallel_config.distributed_executor_backend != "external_launcher" @@ -1164,6 +1205,14 @@ def init_distributed_environment( distributed_init_method, ) if not torch.distributed.is_initialized(): + logger.info( + "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " "distributed environment" @@ -1192,16 +1241,36 @@ def init_distributed_environment( # local rank not set, this usually happens in single-node # setting, where we can use rank as local rank local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank - global _WORLD, _NODE_COUNT + global _WORLD, _NODE_COUNT, _INNER_DP_WORLD if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) _WORLD = init_world_group(ranks, local_rank, backend) - _NODE_COUNT = _node_count(_WORLD.cpu_group) + if config.parallel_config.nnodes > 1: + _NODE_COUNT = config.parallel_config.nnodes + else: + _NODE_COUNT = _node_count(_WORLD.cpu_group) logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT) else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized 
with a different world size" ) + if config.parallel_config.nnodes_within_dp > 1: + if parallel_config.data_parallel_size > 1: + world_size_inner_dp = parallel_config.world_size + group_ranks = [ + [dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)] + for dp_rank in range(parallel_config.data_parallel_size) + ] + _INNER_DP_WORLD = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="inner_dp_world", + use_device_communicator=False, + ) + else: + _INNER_DP_WORLD = _WORLD def initialize_model_parallel( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 13c7704f5bf3d..d011dfdbfbb2e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -384,6 +384,10 @@ class EngineArgs: ) = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size + master_addr: str = ParallelConfig.master_addr + master_port: int = ParallelConfig.master_port + nnodes: int = ParallelConfig.nnodes + node_rank: int = ParallelConfig.node_rank tensor_parallel_size: int = ParallelConfig.tensor_parallel_size decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size @@ -394,6 +398,7 @@ class EngineArgs: data_parallel_address: str | None = None data_parallel_rpc_port: int | None = None data_parallel_hybrid_lb: bool = False + data_parallel_external_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel all2all_backend: str | None = ParallelConfig.all2all_backend @@ -428,11 +433,11 @@ class EngineArgs: cpu_offload_gb: float = CacheConfig.cpu_offload_gb gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization kv_cache_memory_bytes: int | None = 
CacheConfig.kv_cache_memory_bytes - max_num_batched_tokens: int | None = SchedulerConfig.max_num_batched_tokens + max_num_batched_tokens: int | None = None max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold - max_num_seqs: int | None = SchedulerConfig.max_num_seqs + max_num_seqs: int | None = None max_logprobs: int = ModelConfig.max_logprobs logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode disable_log_stats: bool = False @@ -485,7 +490,7 @@ class EngineArgs: model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config") ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns") - enable_chunked_prefill: bool | None = SchedulerConfig.enable_chunked_prefill + enable_chunked_prefill: bool | None = None disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input disable_hybrid_kv_cache_manager: bool = ( @@ -558,12 +563,15 @@ class EngineArgs: async_scheduling: bool | None = SchedulerConfig.async_scheduling + stream_interval: int = SchedulerConfig.stream_interval + kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill kv_offloading_size: float | None = CacheConfig.kv_offloading_size kv_offloading_backend: KVOffloadingBackend | None = ( CacheConfig.kv_offloading_backend ) + tokens_only: bool = False def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -746,6 +754,10 @@ class EngineArgs: "-pp", **parallel_kwargs["pipeline_parallel_size"], ) + parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"]) + parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"]) + parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"]) + parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"]) parallel_group.add_argument( 
"--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"] ) @@ -800,7 +812,14 @@ class EngineArgs: help='Backend for data parallel, either "mp" or "ray".', ) parallel_group.add_argument( - "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"] + "--data-parallel-hybrid-lb", + "-dph", + **parallel_kwargs["data_parallel_hybrid_lb"], + ) + parallel_group.add_argument( + "--data-parallel-external-lb", + "-dpe", + **parallel_kwargs["data_parallel_external_lb"], ) parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"] @@ -1067,6 +1086,9 @@ class EngineArgs: scheduler_group.add_argument( "--async-scheduling", **scheduler_kwargs["async_scheduling"] ) + scheduler_group.add_argument( + "--stream-interval", **scheduler_kwargs["stream_interval"] + ) # Compilation arguments compilation_kwargs = get_kwargs(CompilationConfig) @@ -1422,12 +1444,56 @@ class EngineArgs: assert not headless or not self.data_parallel_hybrid_lb, ( "data_parallel_hybrid_lb is not applicable in headless mode" ) - - data_parallel_external_lb = self.data_parallel_rank is not None + assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), ( + "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True." + ) + assert self.data_parallel_backend == "mp" or self.nnodes == 1, ( + "nnodes > 1 is only supported with data_parallel_backend=mp" + ) + inferred_data_parallel_rank = 0 + if self.nnodes > 1: + world_size = ( + self.data_parallel_size + * self.pipeline_parallel_size + * self.tensor_parallel_size + ) + world_size_within_dp = ( + self.pipeline_parallel_size * self.tensor_parallel_size + ) + local_world_size = world_size // self.nnodes + assert world_size % self.nnodes == 0, ( + f"world_size={world_size} must be divisible by nnodes={self.nnodes}." + ) + assert self.node_rank < self.nnodes, ( + f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}." 
+ ) + inferred_data_parallel_rank = ( + self.node_rank * local_world_size + ) // world_size_within_dp + if self.data_parallel_size > 1 and self.data_parallel_external_lb: + self.data_parallel_rank = inferred_data_parallel_rank + logger.info( + "Inferred data_parallel_rank %d from node_rank %d for external lb", + self.data_parallel_rank, + self.node_rank, + ) + elif self.data_parallel_size_local is None: + # Infer data parallel size local for internal dplb: + self.data_parallel_size_local = max( + local_world_size // world_size_within_dp, 1 + ) + data_parallel_external_lb = ( + self.data_parallel_external_lb or self.data_parallel_rank is not None + ) # Local DP rank = 1, use pure-external LB. if data_parallel_external_lb: + assert self.data_parallel_rank is not None, ( + "data_parallel_rank or node_rank must be spefified if " + "data_parallel_external_lb is enable." + ) assert self.data_parallel_size_local in (1, None), ( - "data_parallel_size_local must be 1 when data_parallel_rank is set" + "data_parallel_size_local must be 1 or None when data_parallel_rank " + "is set" ) data_parallel_size_local = 1 # Use full external lb if we have local_size of 1. @@ -1441,6 +1507,11 @@ class EngineArgs: if self.data_parallel_hybrid_lb and data_parallel_size_local == 1: # Use full external lb if we have local_size of 1. + logger.warning( + "data_parallel_hybrid_lb is not eligible when " + "data_parallel_size_local = 1, autoswitch to " + "data_parallel_external_lb." 
+ ) data_parallel_external_lb = True self.data_parallel_hybrid_lb = False @@ -1448,7 +1519,15 @@ class EngineArgs: # Disable hybrid LB mode if set for a single node self.data_parallel_hybrid_lb = False - self.data_parallel_rank = self.data_parallel_start_rank or 0 + self.data_parallel_rank = ( + self.data_parallel_start_rank or inferred_data_parallel_rank + ) + if self.nnodes > 1: + logger.info( + "Inferred data_parallel_rank %d from node_rank %d", + self.data_parallel_rank, + self.node_rank, + ) else: assert not self.data_parallel_hybrid_lb, ( "data_parallel_size_local must be set to use data_parallel_hybrid_lb." @@ -1478,7 +1557,9 @@ class EngineArgs: "data_parallel_backend can only be ray or mp, got %s", self.data_parallel_backend, ) - data_parallel_address = ParallelConfig.data_parallel_master_ip + data_parallel_address = ( + self.master_addr or ParallelConfig.data_parallel_master_ip + ) else: data_parallel_address = self.data_parallel_address @@ -1490,6 +1571,10 @@ class EngineArgs: else ParallelConfig.data_parallel_rpc_port ) + if self.tokens_only and not model_config.skip_tokenizer_init: + model_config.skip_tokenizer_init = True + logger.info("Skipping tokenizer initialization for tokens-only mode.") + # Forward the deprecated CLI args to the EPLB config. 
if self.num_redundant_experts is not None: self.eplb_config.num_redundant_experts = self.num_redundant_experts @@ -1507,6 +1592,10 @@ class EngineArgs: data_parallel_rank=self.data_parallel_rank or 0, data_parallel_external_lb=data_parallel_external_lb, data_parallel_size_local=data_parallel_size_local, + master_addr=self.master_addr, + master_port=self.master_port, + nnodes=self.nnodes, + node_rank=self.node_rank, data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, data_parallel_backend=self.data_parallel_backend, @@ -1562,6 +1651,7 @@ class EngineArgs: long_prefill_token_threshold=self.long_prefill_token_threshold, disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager, async_scheduling=self.async_scheduling, + stream_interval=self.stream_interval, ) if not model_config.is_multimodal_model and self.default_mm_loras: @@ -1631,40 +1721,39 @@ class EngineArgs: ) observability_config = ObservabilityConfig( - show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version), + show_hidden_metrics_for_version=self.show_hidden_metrics_for_version, otlp_traces_endpoint=self.otlp_traces_endpoint, collect_detailed_traces=self.collect_detailed_traces, ) # Compilation config overrides + compilation_config = copy.deepcopy(self.compilation_config) if self.cuda_graph_sizes is not None: logger.warning( "--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or " "v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes " "instead." ) - if self.compilation_config.cudagraph_capture_sizes is not None: + if compilation_config.cudagraph_capture_sizes is not None: raise ValueError( "cuda_graph_sizes and compilation_config." 
"cudagraph_capture_sizes are mutually exclusive" ) - self.compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes + compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes if self.cudagraph_capture_sizes is not None: - if self.compilation_config.cudagraph_capture_sizes is not None: + if compilation_config.cudagraph_capture_sizes is not None: raise ValueError( "cudagraph_capture_sizes and compilation_config." "cudagraph_capture_sizes are mutually exclusive" ) - self.compilation_config.cudagraph_capture_sizes = ( - self.cudagraph_capture_sizes - ) + compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes if self.max_cudagraph_capture_size is not None: - if self.compilation_config.max_cudagraph_capture_size is not None: + if compilation_config.max_cudagraph_capture_size is not None: raise ValueError( "max_cudagraph_capture_size and compilation_config." "max_cudagraph_capture_size are mutually exclusive" ) - self.compilation_config.max_cudagraph_capture_size = ( + compilation_config.max_cudagraph_capture_size = ( self.max_cudagraph_capture_size ) @@ -1679,7 +1768,7 @@ class EngineArgs: load_config=load_config, structured_outputs_config=self.structured_outputs_config, observability_config=observability_config, - compilation_config=self.compilation_config, + compilation_config=compilation_config, kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, ec_transfer_config=self.ec_transfer_config, @@ -1733,41 +1822,41 @@ class EngineArgs: ) _raise_unsupported_error(feature_name=name) - def _set_default_args( - self, usage_context: UsageContext, model_config: ModelConfig - ) -> None: - """Set Default Arguments for V1 Engine.""" - - # V1 uses chunked prefills and prefix caching by default - # for non-pooling tasks. 
- # For pooling tasks the default is False + @classmethod + def get_chunked_prefill_prefix_caching_defaults( + cls, + model_config: ModelConfig, + ) -> tuple[bool, bool]: if model_config.runner_type != "pooling": - self.enable_chunked_prefill = True + default_chunked_prefill = True - if self.enable_prefix_caching is None: - # Disable prefix caching default for hybrid models - # since the feature is still experimental. - if model_config.is_hybrid: - self.enable_prefix_caching = False - else: - self.enable_prefix_caching = True + # Disable prefix caching default for hybrid models + # since the feature is still experimental. + default_prefix_caching = not model_config.is_hybrid else: + assert model_config.pooler_config is not None + pooling_type = model_config.pooler_config.pooling_type - is_causal = getattr(model_config.hf_config, "is_causal", True) incremental_prefill_supported = ( pooling_type is not None and pooling_type.lower() == "last" - and bool(is_causal) + and getattr(model_config.hf_config, "is_causal", True) ) - action = "Enabling" if incremental_prefill_supported else "Disabling" + default_chunked_prefill = incremental_prefill_supported + default_prefix_caching = incremental_prefill_supported - if self.enable_chunked_prefill is None: - self.enable_chunked_prefill = incremental_prefill_supported - logger.info("(%s) chunked prefill by default", action) - if self.enable_prefix_caching is None: - self.enable_prefix_caching = incremental_prefill_supported - logger.info("(%s) prefix caching by default", action) + return default_chunked_prefill, default_prefix_caching + + @classmethod + def get_batch_defaults( + cls, + world_size: int, + ) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]: + from vllm.usage.usage_lib import UsageContext + + default_max_num_batched_tokens: dict[UsageContext | None, int] + default_max_num_seqs: dict[UsageContext | None, int] # When no user override, set the default values based on the usage # context. 
@@ -1788,8 +1877,6 @@ class EngineArgs: # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces # throughput, see PR #17885 for more details. # So here we do an extra device name check to prevent such regression. - from vllm.usage.usage_lib import UsageContext - if device_memory >= 70 * GiB_bytes and "a100" not in device_name: # For GPUs like H100 and MI300x, use larger default values. default_max_num_batched_tokens = { @@ -1813,22 +1900,26 @@ class EngineArgs: # tpu specific default values. if current_platform.is_tpu(): - default_max_num_batched_tokens_tpu = { - UsageContext.LLM_CLASS: { - "V6E": 2048, - "V5E": 1024, - "V5P": 512, - }, - UsageContext.OPENAI_API_SERVER: { - "V6E": 1024, - "V5E": 512, - "V5P": 256, - }, - } + chip_name = current_platform.get_device_name() + + if chip_name == "V6E": + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 2048, + UsageContext.OPENAI_API_SERVER: 1024, + } + elif chip_name == "V5E": + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 1024, + UsageContext.OPENAI_API_SERVER: 512, + } + elif chip_name == "V5P": + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 512, + UsageContext.OPENAI_API_SERVER: 256, + } # cpu specific default values. 
if current_platform.is_cpu(): - world_size = self.pipeline_parallel_size * self.tensor_parallel_size default_max_num_batched_tokens = { UsageContext.LLM_CLASS: 4096 * world_size, UsageContext.OPENAI_API_SERVER: 2048 * world_size, @@ -1838,44 +1929,104 @@ class EngineArgs: UsageContext.OPENAI_API_SERVER: 128 * world_size, } - use_context_value = usage_context.value if usage_context else None - if ( - self.max_num_batched_tokens is None - and usage_context in default_max_num_batched_tokens + return default_max_num_batched_tokens, default_max_num_seqs + + def _set_default_args( + self, usage_context: UsageContext, model_config: ModelConfig + ) -> None: + """Set Default Arguments for V1 Engine.""" + ( + default_chunked_prefill, + default_prefix_caching, + ) = self.get_chunked_prefill_prefix_caching_defaults(model_config) + + if self.enable_chunked_prefill is None: + self.enable_chunked_prefill = default_chunked_prefill + + logger.debug( + "%s chunked prefill by default", + "Enabling" if default_chunked_prefill else "Disabling", + ) + elif ( + model_config.runner_type == "pooling" + and self.enable_chunked_prefill + and not default_chunked_prefill ): - if current_platform.is_tpu(): - chip_name = current_platform.get_device_name() - if chip_name in default_max_num_batched_tokens_tpu[usage_context]: - self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[ - usage_context - ][chip_name] - else: - self.max_num_batched_tokens = default_max_num_batched_tokens[ - usage_context - ] - else: - if not self.enable_chunked_prefill: - self.max_num_batched_tokens = model_config.max_model_len - else: - self.max_num_batched_tokens = default_max_num_batched_tokens[ - usage_context - ] + logger.warning( + "This model does not officially support chunked prefill. 
" + "Enabling this manually may cause the engine to crash " + "or produce incorrect outputs.", + ) + + if self.enable_prefix_caching is None: + self.enable_prefix_caching = default_prefix_caching + logger.debug( - "Setting max_num_batched_tokens to %d for %s usage context.", + "%s prefix caching by default", + "Enabling" if default_prefix_caching else "Disabling", + ) + elif ( + model_config.runner_type == "pooling" + and self.enable_prefix_caching + and not default_prefix_caching + ): + logger.warning( + "This model does not officially support prefix caching. " + "Enabling this manually may cause the engine to crash " + "or produce incorrect outputs.", + ) + + world_size = self.pipeline_parallel_size * self.tensor_parallel_size + ( + default_max_num_batched_tokens, + default_max_num_seqs, + ) = self.get_batch_defaults(world_size) + + orig_max_num_batched_tokens = self.max_num_batched_tokens + orig_max_num_seqs = self.max_num_seqs + + if self.max_num_batched_tokens is None: + self.max_num_batched_tokens = default_max_num_batched_tokens.get( + usage_context, + SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS, + ) + + if self.max_num_seqs is None: + self.max_num_seqs = default_max_num_seqs.get( + usage_context, + SchedulerConfig.DEFAULT_MAX_NUM_SEQS, + ) + + if orig_max_num_batched_tokens is None: + if not self.enable_chunked_prefill: + # If max_model_len is too short, use the default for higher throughput. + self.max_num_batched_tokens = max( + model_config.max_model_len, + self.max_num_batched_tokens, + ) + + # When using default settings, + # Ensure max_num_batched_tokens does not exceed model limit. + # Some models (e.g., Whisper) have embeddings tied to max length. 
+ self.max_num_batched_tokens = min( + self.max_num_seqs * model_config.max_model_len, self.max_num_batched_tokens, - use_context_value, - ) - - if self.max_num_seqs is None and usage_context in default_max_num_seqs: - self.max_num_seqs = min( - default_max_num_seqs[usage_context], - self.max_num_batched_tokens or sys.maxsize, ) logger.debug( - "Setting max_num_seqs to %d for %s usage context.", + "Defaulting max_num_batched_tokens to %d for %s usage context.", + self.max_num_batched_tokens, + usage_context.value if usage_context else None, + ) + + if orig_max_num_seqs is None: + assert self.max_num_batched_tokens is not None # For type checking + self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens) + + logger.debug( + "Defaulting max_num_seqs to %d for %s usage context.", self.max_num_seqs, - use_context_value, + usage_context.value if usage_context else None, ) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 24fcd9fe1cab9..462d2c4e50e73 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -125,7 +125,7 @@ class EngineClient(ABC): ... @abstractmethod - async def reset_prefix_cache(self, device: Device | None = None) -> None: + async def reset_prefix_cache(self) -> None: """Reset the prefix cache""" ... 
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index d7d6419d643b0..3b722c2d92770 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -240,6 +240,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None """The tool calls generated by the model, such as function calls.""" + reasoning: str | None + """The reasoning content for interleaved thinking.""" + ChatCompletionMessageParam: TypeAlias = ( OpenAIChatCompletionMessageParam @@ -265,6 +268,12 @@ class ConversationMessage(TypedDict, total=False): tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None """The tool calls generated by the model, such as function calls.""" + reasoning: str | None + """The reasoning content for interleaved thinking.""" + + reasoning_content: str | None + """Deprecated: The reasoning content for interleaved thinking.""" + # Passed in by user ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] @@ -1374,7 +1383,7 @@ def _parse_chat_message_content( ) -> list[ConversationMessage]: role = message["role"] content = message.get("content") - + reasoning = message.get("reasoning") or message.get("reasoning_content") if content is None: content = [] elif isinstance(content, str): @@ -1396,6 +1405,12 @@ def _parse_chat_message_content( # follow the OpenAI spec. if "tool_calls" in parsed_msg and parsed_msg["tool_calls"] is not None: result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + # Include reasoning if present for interleaved thinking. 
+ if reasoning is not None: + result_msg["reasoning"] = cast(str, reasoning) + result_msg["reasoning_content"] = cast( + str, reasoning + ) # keep compatibility elif role == "tool": parsed_msg = _ToolParser(message) if "tool_call_id" in parsed_msg: diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 2678658dd1262..96608f360e17b 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -24,6 +24,7 @@ from vllm.utils.system_utils import decorate_logs, set_process_title from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines from vllm.v1.executor import Executor +from vllm.v1.executor.multiproc_executor import MultiprocExecutor from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure @@ -97,18 +98,40 @@ def run_headless(args: argparse.Namespace): if local_engine_count <= 0: raise ValueError("data_parallel_size_local must be > 0 in headless mode") - host = parallel_config.data_parallel_master_ip - port = engine_args.data_parallel_rpc_port # add to config too - handshake_address = get_tcp_uri(host, port) + shutdown_requested = False # Catch SIGTERM and SIGINT to allow graceful shutdown. def signal_handler(signum, frame): + nonlocal shutdown_requested logger.debug("Received %d signal.", signum) - raise SystemExit + if not shutdown_requested: + shutdown_requested = True + raise SystemExit signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + if parallel_config.node_rank_within_dp > 0: + from vllm.version import __version__ as VLLM_VERSION + + # Run headless workers (for multi-node PP/TP). 
+ host = parallel_config.master_addr + head_node_address = f"{host}:{parallel_config.master_port}" + logger.info( + "Launching vLLM (v%s) headless multiproc executor, " + "with head node address %s for torch.distributed process group.", + VLLM_VERSION, + head_node_address, + ) + + executor = MultiprocExecutor(vllm_config, monitor_workers=False) + executor.start_worker_monitor(inline=True) + return + + host = parallel_config.data_parallel_master_ip + port = parallel_config.data_parallel_rpc_port + handshake_address = get_tcp_uri(host, port) + logger.info( "Launching %d data parallel engine(s) in headless mode, " "with head node address %s.", diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 62717a7eacdf0..b0786bd355aa6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -32,7 +32,6 @@ from vllm.config.model import ( TokenizerMode, ) from vllm.engine.arg_utils import EngineArgs -from vllm.engine.protocol import Device from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, ChatTemplateContentFormatOption, @@ -1499,8 +1498,8 @@ class LLM: def stop_profile(self) -> None: self.llm_engine.stop_profile() - def reset_prefix_cache(self, device: Device | None = None) -> None: - self.llm_engine.reset_prefix_cache(device) + def reset_prefix_cache(self) -> None: + self.llm_engine.reset_prefix_cache() def sleep(self, level: int = 1): """ diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index fbb2d32a229da..3cf66fcd27e2a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -39,7 +39,7 @@ from typing_extensions import assert_never import vllm.envs as envs from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.protocol import Device, EngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.anthropic.protocol import ( AnthropicError, AnthropicErrorResponse, @@ -65,6 +65,8 @@ 
from vllm.entrypoints.openai.protocol import ( EmbeddingResponse, ErrorInfo, ErrorResponse, + GenerateRequest, + GenerateResponse, IOProcessorResponse, PoolingBytesResponse, PoolingRequest, @@ -96,6 +98,7 @@ from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses from vllm.entrypoints.openai.serving_score import ServingScores from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization +from vllm.entrypoints.openai.serving_tokens import ServingTokens from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation, @@ -357,6 +360,10 @@ def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client +def generate_tokens(request: Request) -> ServingTokens | None: + return request.app.state.serving_tokens + + @router.get("/health", response_class=Response) async def health(raw_request: Request) -> Response: """Health check.""" @@ -1062,12 +1069,8 @@ if envs.VLLM_SERVER_DEV_MODE: Reset the prefix cache. Note that we currently do not check if the prefix cache is successfully reset in the API server. 
""" - device = None - device_str = raw_request.query_params.get("device") - if device_str is not None: - device = Device[device_str.upper()] - logger.info("Resetting prefix cache with specific %s...", str(device)) - await engine_client(raw_request).reset_prefix_cache(device) + logger.info("Resetting prefix cache...") + await engine_client(raw_request).reset_prefix_cache() return Response(status_code=200) @router.post("/reset_mm_cache") @@ -1228,6 +1231,41 @@ INVOCATION_VALIDATORS = [ ] +@router.post( + "/inference/v1/generate", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def generate(request: GenerateRequest, raw_request: Request): + handler = generate_tokens(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support generate tokens API" + ) + try: + generator = await handler.serve_tokens(request, raw_request) + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + + elif isinstance(generator, GenerateResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + if envs.VLLM_TORCH_PROFILER_DIR: logger.warning_once( "Torch Profiler is enabled in the API server. 
This should ONLY be " @@ -1629,6 +1667,31 @@ def build_app(args: Namespace) -> FastAPI: ) app = sagemaker_standards.bootstrap(app) + # Optional endpoints + if args.tokens_only: + + @app.post("/abort_requests") + async def abort_requests(raw_request: Request): + """ + Abort one or more requests. To be used in a + Disaggregated Everything setup. + """ + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + request_ids = body.get("request_ids") + if request_ids is None: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'request_ids' in request body", + ) + # Abort requests in background + asyncio.create_task(engine_client(raw_request).abort(request_ids)) + return Response(status_code=200) return app @@ -1784,6 +1847,9 @@ async def init_app_state( engine_client, state.openai_serving_models, request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, log_error_stack=args.log_error_stack, ) if "classify" in supported_tasks @@ -1848,6 +1914,20 @@ async def init_app_state( if "generate" in supported_tasks else None ) + state.serving_tokens = ( + ServingTokens( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + log_error_stack=args.log_error_stack, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_log_outputs=args.enable_log_outputs, + force_no_detokenize=args.tokens_only, + ) + if "generate" in supported_tasks + else None + ) state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 476587c178237..946362ce2ef0a 100644 
--- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -189,6 +189,11 @@ class FrontendArgs: Helps mitigate header abuse. Default: 256.""" log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE """If set to True, log the stack trace of error responses""" + tokens_only: bool = False + """ + If set to True, only enable the Tokens In<>Out endpoint. + This is intended for use in a Disaggregated Everything setup. + """ @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 69e757d4764d2..65bd15ba387b9 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -2000,10 +2000,10 @@ class ScoreResponse(OpenAIBaseModel): usage: UsageInfo -class ClassificationRequest(OpenAIBaseModel): +class ClassificationCompletionRequest(OpenAIBaseModel): model: str | None = None input: list[str] | str - truncate_prompt_tokens: int | None = None + truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None user: str | None = None # --8<-- [start:classification-extra-params] @@ -2015,7 +2015,21 @@ class ClassificationRequest(OpenAIBaseModel): "if the served model does not use priority scheduling." ), ) - + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt." + ), + ) + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." 
+ ), + ) softmax: bool | None = Field( default=None, description="softmax will be deprecated, please use use_activation instead.", @@ -2040,6 +2054,102 @@ class ClassificationRequest(OpenAIBaseModel): ) +class ClassificationChatRequest(OpenAIBaseModel): + model: str | None = None + messages: list[ChatCompletionMessageParam] + truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None + user: str | None = None + + # --8<-- [start:chat-classification-extra-params] + add_generation_prompt: bool = Field( + default=False, + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), + ) + + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)." + ), + ) + + chat_template: str | None = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one." + ), + ) + + chat_template_kwargs: dict[str, Any] | None = Field( + default=None, + description=( + "Additional keyword args to pass to the template renderer. " + "Will be accessible by the chat template." + ), + ) + + mm_processor_kwargs: dict[str, Any] | None = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) + + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling." 
+ ), + ) + + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." + ), + ) + softmax: bool | None = Field( + default=None, + description="softmax will be deprecated, please use use_activation instead.", + ) + + activation: bool | None = Field( + default=None, + description="activation will be deprecated, please use use_activation instead.", + ) + + use_activation: bool | None = Field( + default=None, + description="Whether to use activation for classification outputs. " + "Default is True.", + ) + # --8<-- [end:chat-classification-extra-params] + + def to_pooling_params(self): + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + use_activation=get_use_activation(self), + ) + + +ClassificationRequest: TypeAlias = ( + ClassificationCompletionRequest | ClassificationChatRequest +) + + class ClassificationData(OpenAIBaseModel): index: int label: str | None @@ -3110,3 +3220,80 @@ class TranslationResponseVerbose(OpenAIBaseModel): words: list[TranslationWord] | None = None """Extracted words and their corresponding timestamps.""" + + +####### Tokens IN <> Tokens OUT ####### +class GenerateRequest(BaseModel): + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." 
+ ), + ) + token_ids: list[int] + """The token ids to generate text from.""" + + # features: MultiModalFeatureSpec + # TODO (NickLucche): implement once Renderer work is completed + features: str | None = None + """The processed MM inputs for the model.""" + + sampling_params: SamplingParams + """The sampling parameters for the model.""" + + model: str | None = None + + stream: bool | None = False + stream_options: StreamOptions | None = None + cache_salt: str | None = Field( + default=None, + description=( + "If specified, the prefix cache will be salted with the provided " + "string to prevent an attacker to guess prompts in multi-user " + "environments. The salt should be random, protected from " + "access by 3rd parties, and long enough to be " + "unpredictable (e.g., 43 characters base64-encoded, corresponding " + "to 256 bit)." + ), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling." + ), + ) + kv_transfer_params: dict[str, Any] | None = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.", + ) + + +class GenerateResponseChoice(BaseModel): + index: int + logprobs: ChatCompletionLogProbs | None = None + # per OpenAI spec this is the default + finish_reason: str | None = "stop" + token_ids: list[int] | None = None + + +class GenerateResponse(BaseModel): + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." 
+ ), + ) + choices: list[GenerateResponseChoice] + + prompt_logprobs: list[dict[int, Logprob] | None] | None = None + + kv_transfer_params: dict[str, Any] | None = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.", + ) diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 45bbe732a680f..167ee152fece3 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -4,13 +4,17 @@ from http import HTTPStatus from typing import cast +import jinja2 import numpy as np from fastapi import Request -from typing_extensions import override from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ClassificationChatRequest, + ClassificationCompletionRequest, ClassificationData, ClassificationRequest, ClassificationResponse, @@ -32,7 +36,10 @@ logger = init_logger(__name__) class ClassificationMixin(OpenAIServing): - @override + chat_template: str | None + chat_template_content_format: ChatTemplateContentFormatOption + trust_request_chat_template: bool + async def _preprocess( self, ctx: ServeContext, @@ -42,31 +49,79 @@ class ClassificationMixin(OpenAIServing): and prepare model-specific inputs. 
""" ctx = cast(ClassificationServeContext, ctx) - if isinstance(ctx.request.input, str) and not ctx.request.input: - return self.create_error_response( - "Input cannot be empty for classification", - status_code=HTTPStatus.BAD_REQUEST, - ) - - if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0: - return None - try: ctx.tokenizer = await self.engine_client.get_tokenizer() - renderer = self._get_renderer(ctx.tokenizer) - ctx.engine_prompts = await renderer.render_prompt( - prompt_or_prompts=ctx.request.input, - config=self._build_render_config(ctx.request), - ) + request_obj = ctx.request + + if isinstance(request_obj, ClassificationChatRequest): + chat_request = request_obj + messages = chat_request.messages + trust_request_chat_template = getattr( + self, + "trust_request_chat_template", + False, + ) + ret = self._validate_chat_template( + request_chat_template=chat_request.chat_template, + chat_template_kwargs=chat_request.chat_template_kwargs, + trust_request_chat_template=trust_request_chat_template, + ) + if ret: + return ret + + ( + _, + _, + engine_prompts, + ) = await self._preprocess_chat( + cast(ChatCompletionRequest, chat_request), + ctx.tokenizer, + messages, + chat_template=( + chat_request.chat_template + or getattr(self, "chat_template", None) + ), + chat_template_content_format=cast( + ChatTemplateContentFormatOption, + getattr(self, "chat_template_content_format", "auto"), + ), + add_generation_prompt=False, + continue_final_message=False, + add_special_tokens=chat_request.add_special_tokens, + ) + ctx.engine_prompts = engine_prompts + + elif isinstance(request_obj, ClassificationCompletionRequest): + completion_request = request_obj + input_data = completion_request.input + if input_data in (None, ""): + return self.create_error_response( + "Input or messages must be provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + if isinstance(input_data, list) and not input_data: + ctx.engine_prompts = [] + return None + + renderer = 
self._get_renderer(ctx.tokenizer) + prompt_input = cast(str | list[str], input_data) + ctx.engine_prompts = await renderer.render_prompt( + prompt_or_prompts=prompt_input, + config=self._build_render_config(completion_request), + ) + else: + return self.create_error_response( + "Invalid classification request type", + status_code=HTTPStatus.BAD_REQUEST, + ) return None - except (ValueError, TypeError) as e: + except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - @override def _build_response( self, ctx: ServeContext, @@ -118,6 +173,7 @@ class ClassificationMixin(OpenAIServing): return RenderConfig( max_length=self.max_model_len, truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, ) @@ -130,6 +186,9 @@ class ServingClassification(ClassificationMixin): models: OpenAIServingModels, *, request_logger: RequestLogger | None, + chat_template: str | None = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", + trust_request_chat_template: bool = False, log_error_stack: bool = False, ) -> None: super().__init__( @@ -139,6 +198,10 @@ class ServingClassification(ClassificationMixin): log_error_stack=log_error_stack, ) + self.chat_template = chat_template + self.chat_template_content_format = chat_template_content_format + self.trust_request_chat_template = trust_request_chat_template + async def create_classify( self, request: ClassificationRequest, @@ -156,7 +219,6 @@ class ServingClassification(ClassificationMixin): return await super().handle(ctx) # type: ignore - @override def _create_pooling_params( self, ctx: ClassificationServeContext, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 1456727a3cdd6..c50b0c4a23e17 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -43,6 +43,8 @@ 
from vllm.entrypoints.openai.protocol import ( ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, + ClassificationChatRequest, + ClassificationCompletionRequest, ClassificationRequest, ClassificationResponse, CompletionRequest, @@ -56,6 +58,8 @@ from vllm.entrypoints.openai.protocol import ( ErrorResponse, FunctionCall, FunctionDefinition, + GenerateRequest, + GenerateResponse, IOProcessorRequest, PoolingResponse, RerankRequest, @@ -114,13 +118,16 @@ CompletionLikeRequest: TypeAlias = ( | DetokenizeRequest | EmbeddingCompletionRequest | RerankRequest - | ClassificationRequest + | ClassificationCompletionRequest | ScoreRequest | TokenizeCompletionRequest ) ChatLikeRequest: TypeAlias = ( - ChatCompletionRequest | EmbeddingChatRequest | TokenizeChatRequest + ChatCompletionRequest + | EmbeddingChatRequest + | TokenizeChatRequest + | ClassificationChatRequest ) SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest AnyRequest: TypeAlias = ( @@ -129,6 +136,7 @@ AnyRequest: TypeAlias = ( | SpeechToTextRequest | ResponsesRequest | IOProcessorRequest + | GenerateRequest ) AnyResponse: TypeAlias = ( @@ -140,6 +148,7 @@ AnyResponse: TypeAlias = ( | PoolingResponse | ClassificationResponse | ScoreResponse + | GenerateResponse ) @@ -814,7 +823,11 @@ class OpenAIServing: if not hasattr(request, "messages"): return message_types - for message in request.messages: + messages = request.messages + if messages is None or isinstance(messages, (str, bytes)): + return message_types + + for message in messages: if ( isinstance(message, dict) and "content" in message @@ -907,7 +920,8 @@ class OpenAIServing: EmbeddingCompletionRequest, ScoreRequest, RerankRequest, - ClassificationRequest, + ClassificationCompletionRequest, + ClassificationChatRequest, ), ): # Note: input length can be up to the entire model context length @@ -915,7 +929,8 @@ class OpenAIServing: if token_num > self.max_model_len: operations: dict[type[AnyRequest], str] = 
{ ScoreRequest: "score", - ClassificationRequest: "classification", + ClassificationCompletionRequest: "classification", + ClassificationChatRequest: "classification", } operation = operations.get(type(request), "embedding generation") raise ValueError( diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 0eade272111f1..ee4c5c8bacaae 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -4,7 +4,7 @@ import asyncio import json import time -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Sequence from typing import Final, cast import jinja2 @@ -122,6 +122,10 @@ class OpenAIServingPooling(OpenAIServing): engine_prompts = await self.io_processor.pre_process_async( prompt=validated_prompt, request_id=request_id ) + if not isinstance(engine_prompts, Sequence) or isinstance( + engine_prompts, (str, bytes, bytearray) + ): + engine_prompts = [engine_prompts] elif isinstance(request, PoolingChatRequest): error_check_ret = self._validate_chat_template( diff --git a/vllm/entrypoints/openai/serving_tokens.py b/vllm/entrypoints/openai/serving_tokens.py new file mode 100644 index 0000000000000..69a526b9b70d2 --- /dev/null +++ b/vllm/entrypoints/openai/serving_tokens.py @@ -0,0 +1,269 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import time +from collections.abc import AsyncGenerator +from collections.abc import Sequence as GenericSequence + +from fastapi import Request + +# yapf: disable +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProb, + ChatCompletionLogProbs, + ChatCompletionLogProbsContent, + ErrorResponse, + GenerateRequest, + GenerateResponse, + GenerateResponseChoice, + PromptTokenUsageInfo, + RequestResponseMetadata, + 
UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.logger import init_logger +from vllm.logprobs import Logprob +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.utils.collection_utils import as_list + +logger = init_logger(__name__) + + +class ServingTokens(OpenAIServing): + """Provides Tokens IN <> Tokens OUT functionality to vLLM API.""" + + def __init__( + self, + engine_client: EngineClient, + models: OpenAIServingModels, + *, + request_logger: RequestLogger | None, + force_no_detokenize: bool = False, + return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, + enable_prompt_tokens_details: bool = False, + enable_log_outputs: bool = False, + ): + super().__init__(engine_client=engine_client, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + log_error_stack=log_error_stack) + self.enable_prompt_tokens_details = enable_prompt_tokens_details + self.enable_log_outputs = enable_log_outputs + self.force_no_detokenize = force_no_detokenize + if force_no_detokenize: + logger.info("Tokens-only mode is enabled, skipping detokenization " + "step for incoming requests.") + + async def serve_tokens( + self, + request: GenerateRequest, + raw_request: Request | None = None + ) -> GenerateResponse | ErrorResponse: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). 
+ if self.engine_client.errored: + raise self.engine_client.dead_error + + lora_request = None + lora_request = self._maybe_get_adapters(request, + supports_default_mm_loras=True) + + model_name = self.models.model_name(lora_request) + + request_id = "generate-tokens-" \ + f"{self._base_request_id(raw_request, request.request_id)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + # TODO(NickLucche): Change to EngineCoreRequest once Renderer work is + # completed + engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids) + if request.features is not None: + engine_prompt["multi_modal_data"] = None + + if hasattr(request, "cache_salt") and request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + + # Schedule the request and get the result generator. + result_generator: AsyncGenerator[RequestOutput, None] | None = None + try: + sampling_params = request.sampling_params + if self.force_no_detokenize: + sampling_params.detokenize = False + + self._log_inputs(request_id, + request.token_ids, + params=sampling_params, + lora_request=lora_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + result_generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + except ValueError as e: + return self.create_error_response(str(e)) + + # TODO(NickLucche): Implement streaming response + + try: + assert result_generator is not None + return await self.serve_tokens_full_generator( + request, result_generator, request_id, model_name, + request_metadata) + except ValueError as e: + return self.create_error_response(str(e)) + + async def serve_tokens_full_generator( + self, + request: GenerateRequest, + result_generator: AsyncGenerator[RequestOutput, None], 
+ request_id: str, + model_name: str, + request_metadata: RequestResponseMetadata, + ) -> ErrorResponse | GenerateResponse: + + created_time = int(time.time()) + final_res: RequestOutput | None = None + sampling_params: SamplingParams = request.sampling_params + + try: + async for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + return self.create_error_response(str(e)) + + assert final_res is not None + + choices: list[GenerateResponseChoice] = [] + num_generated_tokens = 0 + for output in final_res.outputs: + token_ids = output.token_ids + out_logprobs = output.logprobs + + # This is top_logprobs in completions API + if sampling_params.logprobs: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_tokens_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=sampling_params.logprobs, + ) + else: + logprobs = None + + choice_data = GenerateResponseChoice( + index=output.index, + logprobs=logprobs, + finish_reason=output.finish_reason + if output.finish_reason else "stop", + token_ids=as_list(output.token_ids)) + + choices.append(choice_data) + num_generated_tokens += len(output.token_ids) + + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + if final_res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(final_res.encoder_prompt_token_ids) + + usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + + num_generated_tokens) + if self.enable_prompt_tokens_details and final_res.num_cached_tokens: + # This info is not available at the /coordinator level + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=final_res.num_cached_tokens) + + request_metadata.final_usage_info = usage + + response = GenerateResponse( + request_id=request_id, +
created=created_time, + model=model_name, + choices=choices, + usage=usage, + prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + kv_transfer_params=final_res.kv_transfer_params, + ) + + # Log complete response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + for choice in choices: + # Get the corresponding output token IDs + output_token_ids = None + if choice.index < len(final_res.outputs): + output_token_ids = final_res.outputs[ + choice.index].token_ids + + if output_token_ids: + # Log token_ids only. + self.request_logger.log_outputs( + request_id=request_id, + outputs="", + output_token_ids=output_token_ids, + finish_reason=choice.finish_reason, + is_streaming=False, + delta=False, + ) + + return response + + def _create_tokens_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[dict[int, Logprob] | None], + num_output_top_logprobs: int | None = None, + ) -> ChatCompletionLogProbs: + """Create OpenAI-style logprobs.""" + logprobs_content: list[ChatCompletionLogProbsContent] = [] + + for i, token_id in enumerate(token_ids): + token = f"token_id:{token_id}" + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None or step_top_logprobs.get( + token_id) is None: + logprobs_content.append( + ChatCompletionLogProbsContent(token=token, )) + else: + step_token = step_top_logprobs[token_id] + + logprobs_content.append( + ChatCompletionLogProbsContent( + token=token, + logprob=max(step_token.logprob, -9999.0), + top_logprobs=[ + ChatCompletionLogProb( + token=token, + logprob=max(p[1].logprob, -9999.0), + ) for i, p in enumerate(step_top_logprobs.items()) + if num_output_top_logprobs + and i < num_output_top_logprobs + ])) + + return ChatCompletionLogProbs(content=logprobs_content) diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py index 0453db58361a9..a84c9e4547168 100644 --- 
a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py @@ -34,8 +34,27 @@ class KimiK2ToolParser(ToolParser): str ] = [] # map what has been streamed for each tool so far to a list + # Section-level state management to prevent token leakage + self.in_tool_section: bool = False + self.token_buffer: str = "" + # Buffer size: empirical worst-case for longest marker (~30 chars) * 2 + # + safety margin for unicode + partial overlap. Prevents unbounded growth. + self.buffer_max_size: int = 1024 + self.section_char_count: int = 0 # Track characters processed in tool section + self.max_section_chars: int = 8192 # Force exit if section exceeds this + self._buffer_overflow_logged: bool = False # Log overflow once per session + + # Support both singular and plural variants self.tool_calls_start_token: str = "<|tool_calls_section_begin|>" self.tool_calls_end_token: str = "<|tool_calls_section_end|>" + self.tool_calls_start_token_variants: list[str] = [ + "<|tool_calls_section_begin|>", + "<|tool_call_section_begin|>", # singular variant + ] + self.tool_calls_end_token_variants: list[str] = [ + "<|tool_calls_section_end|>", + "<|tool_call_section_end|>", # singular variant + ] self.tool_call_start_token: str = "<|tool_call_begin|>" self.tool_call_end_token: str = "<|tool_call_end|>" @@ -58,6 +77,18 @@ class KimiK2ToolParser(ToolParser): self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token) self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token) + # Get token IDs for all variants + self.tool_calls_start_token_ids: list[int] = [ + tid + for variant in self.tool_calls_start_token_variants + if (tid := self.vocab.get(variant)) is not None + ] + self.tool_calls_end_token_ids: list[int] = [ + tid + for variant in self.tool_calls_end_token_variants + if (tid := self.vocab.get(variant)) is not None + ] + self.tool_call_start_token_id = 
self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) @@ -70,6 +101,51 @@ class KimiK2ToolParser(ToolParser): "tokens in the tokenizer!" ) + def _check_and_strip_markers(self, text: str) -> tuple[str, bool, bool]: + """ + Check for section begin/end markers in text and strip them. + Returns: (cleaned_text, found_section_begin, found_section_end) + """ + found_begin = False + found_end = False + cleaned = text + + # Check for section begin markers (any variant) + for variant in self.tool_calls_start_token_variants: + if variant in cleaned: + cleaned = cleaned.replace(variant, "") + found_begin = True + + # Check for section end markers (any variant) + for variant in self.tool_calls_end_token_variants: + if variant in cleaned: + cleaned = cleaned.replace(variant, "") + found_end = True + + return cleaned, found_begin, found_end + + def _reset_section_state(self) -> None: + """Reset state when exiting tool section.""" + self.in_tool_section = False + self.token_buffer = "" + self.section_char_count = 0 + + def reset_streaming_state(self) -> None: + """ + Reset all streaming state. Call this between requests to prevent + state leakage when parser instance is reused. 
+ """ + # Reset section state + self._reset_section_state() + + # Reset parent class state + self.current_tool_name_sent = False + self.prev_tool_call_arr = [] + self.current_tool_id = -1 + self.streamed_args_for_tool = [] + + logger.debug("Streaming state reset") + def extract_tool_calls( self, model_output: str, @@ -131,13 +207,94 @@ class KimiK2ToolParser(ToolParser): ) -> DeltaMessage | None: logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) - # check to see if we should be streaming a tool call - is there a - if self.tool_calls_start_token_id not in current_token_ids: - logger.debug("No tool call tokens found!") - return DeltaMessage(content=delta_text) - delta_text = delta_text.replace(self.tool_calls_start_token, "").replace( - self.tool_calls_end_token, "" + + # Flag to defer section exit until after tool parsing completes + deferred_section_exit = False + + # Add delta to buffer for split marker detection + self.token_buffer += delta_text + + # Enforce buffer size limit to prevent memory issues + if len(self.token_buffer) > self.buffer_max_size: + if not self._buffer_overflow_logged: + logger.warning( + "Token buffer exceeded max size (%d bytes), flushing excess. 
" + "This may indicate very long markers or unusual tokenization.", + self.buffer_max_size, + ) + self._buffer_overflow_logged = True + # Keep only the most recent content that might contain partial markers + self.token_buffer = self.token_buffer[-self.buffer_max_size // 2 :] + + # Check buffer for section markers (handles split tokens) + buffered_text, found_section_begin, found_section_end = ( + self._check_and_strip_markers(self.token_buffer) ) + + # Track section state transitions + if found_section_begin and not self.in_tool_section: + logger.debug("Entering tool section") + self.in_tool_section = True + self.token_buffer = buffered_text # Use cleaned buffer + self.section_char_count = 0 # Reset counter for new section + if found_section_end and self.in_tool_section: + logger.debug("Detected section end marker") + # CRITICAL: Don't exit early if tool_call_end is in this chunk. + # Tool parser must emit final arguments/close first to avoid dropping + # the final tool update and leaking tokens into reasoning channel. 
+ has_tool_end = self.tool_call_end_token_id in delta_token_ids + if has_tool_end: + # Defer exit until after tool parsing completes + deferred_section_exit = True + logger.debug("Deferring section exit: tool_call_end in same chunk") + self.token_buffer = buffered_text + else: + # No tool call ending, safe to exit immediately + logger.debug("Exiting tool section") + remaining = buffered_text + self._reset_section_state() + # Return remaining text as reasoning content if non-empty + if remaining.strip(): + return DeltaMessage(content=remaining) + # Return empty delta to maintain function contract + # (always returns DeltaMessage) + return DeltaMessage(content="") + else: + self.token_buffer = buffered_text + + # Check if any variant of section start token is in current_token_ids + has_section_token = any( + tid in current_token_ids for tid in self.tool_calls_start_token_ids + ) + + # Early return: if no section token detected yet, return as reasoning content + if not has_section_token and not self.in_tool_section: + logger.debug("No tool call tokens found!") + # Don't clear buffer - it needs to accumulate partial markers across deltas + # Buffer overflow is already protected by lines 215-224 + return DeltaMessage(content=delta_text) + + # Strip section markers from delta_text for subsequent processing + # NOTE: This preprocessing happens BEFORE the regex-based tool call + # parsing (from PR #24847) to ensure markers are removed cleanly + # before pattern matching. No double-stripping occurs because + # section markers and tool call markers are distinct. + delta_text, _, _ = self._check_and_strip_markers(delta_text) + + # Error recovery: If in tool section for too long, force exit + if self.in_tool_section: + self.section_char_count += len(delta_text) + if self.section_char_count > self.max_section_chars: + logger.warning( + "Tool section exceeded max length (%d chars), forcing exit. 
" + "This may indicate malformed model output.", + self.max_section_chars, + ) + self._reset_section_state() + # Deferred exit already handled by forced exit above + # Return remaining content as reasoning (or empty delta if no content) + return DeltaMessage(content=delta_text if delta_text.strip() else "") + try: # figure out where we are in the parsing by counting tool call # start & end tags @@ -158,6 +315,16 @@ class KimiK2ToolParser(ToolParser): and prev_tool_end_count == cur_tool_end_count and self.tool_call_end_token not in delta_text ): + # CRITICAL FIX: Suppress content if in tool section but + # no tool calls started + if self.in_tool_section and cur_tool_start_count == 0: + logger.debug( + "In tool section but no tool calls started yet. " + "Suppressing: %s", + delta_text, + ) + # Return empty delta to maintain iterator contract + return DeltaMessage(content="") logger.debug("Generating text content! skipping tool parsing.") return DeltaMessage(content=delta_text) @@ -209,6 +376,9 @@ class KimiK2ToolParser(ToolParser): ): if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: logger.debug("attempting to close tool call, but no tool call") + # Handle deferred section exit before returning + if deferred_section_exit and self.in_tool_section: + self._reset_section_state() return None diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: @@ -218,6 +388,9 @@ class KimiK2ToolParser(ToolParser): else diff ) if '"}' not in delta_text: + # Handle deferred section exit before returning + if deferred_section_exit and self.in_tool_section: + self._reset_section_state() return None end_loc = delta_text.rindex('"}') diff = delta_text[:end_loc] + '"}' @@ -227,6 +400,10 @@ class KimiK2ToolParser(ToolParser): diff, ) self.streamed_args_for_tool[self.current_tool_id] += diff + # Handle deferred section exit before returning + if deferred_section_exit and self.in_tool_section: + logger.debug("Completing deferred section exit") + 
self._reset_section_state() return DeltaMessage( tool_calls=[ DeltaToolCall( @@ -240,9 +417,19 @@ class KimiK2ToolParser(ToolParser): # case -- otherwise we're just generating text else: + # Check if we're in tool section - if so, suppress + if self.in_tool_section: + logger.debug("In tool section, suppressing text generation") + # Handle deferred section exit before returning + if deferred_section_exit: + self._reset_section_state() + return DeltaMessage(content="") text = delta_text.replace(self.tool_call_start_token, "") text = text.replace(self.tool_call_end_token, "") delta = DeltaMessage(tool_calls=[], content=text) + # Handle deferred section exit before returning + if deferred_section_exit and self.in_tool_section: + self._reset_section_state() return delta current_tool_call = dict() @@ -390,6 +577,11 @@ class KimiK2ToolParser(ToolParser): else: self.prev_tool_call_arr.append(current_tool_call) + # Handle deferred section exit after tool parsing completes + if deferred_section_exit and self.in_tool_section: + logger.debug("Completing deferred section exit") + self._reset_section_state() + return delta except Exception: diff --git a/vllm/envs.py b/vllm/envs.py index 0530938c32f9e..6bf05803e14ef 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -92,6 +92,7 @@ if TYPE_CHECKING: VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False VLLM_USE_AOT_COMPILE: bool = False + VLLM_USE_BYTECODE_HOOK: bool = False VLLM_FORCE_AOT_LOAD: bool = False VLLM_TORCH_PROFILER_WITH_STACK: bool = True VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False @@ -422,7 +423,7 @@ def get_vllm_port() -> int | None: raise ValueError(f"VLLM_PORT '{port}' must be a valid integer") from err -# The begin-* and end* here are used by the documentation generator +# The start-* and end* here are used by the documentation generator # to extract the used env vars. 
# --8<-- [start:env-vars-definition] @@ -556,6 +557,11 @@ environment_variables: dict[str, Callable[[], Any]] = { # compilation is done in warmup phase and the compilation will be # reused in subsequent calls. "VLLM_USE_AOT_COMPILE": use_aot_compile, + # Feature flag to enable/disable bytecode in + # TorchCompileWithNoGuardsWrapper. + "VLLM_USE_BYTECODE_HOOK": lambda: bool( + int(os.environ.get("VLLM_USE_BYTECODE_HOOK", "1")) + ), # Force vllm to always load AOT compiled models from disk. Failure # to load will result in a hard error when this is enabled. # Will be ignored when VLLM_USE_AOT_COMPILE is disabled. diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 44bc2a4cda311..25fb7181a8f29 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -221,6 +221,10 @@ def get_forward_context() -> ForwardContext: return _forward_context +def is_forward_context_available() -> bool: + return _forward_context is not None + + def create_forward_context( attn_metadata: Any, vllm_config: VllmConfig, diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 80d5322a34c3a..839c13868a16c 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -348,18 +348,15 @@ class InputPreprocessor: ) inputs: TokenInputs | MultiModalInputs - if self.model_config.is_multimodal_model: + if multi_modal_data := parsed_content.get("multi_modal_data"): inputs = self._process_multimodal( prompt_token_ids, - parsed_content.get("multi_modal_data") or {}, + multi_modal_data, parsed_content.get("mm_processor_kwargs") or {}, tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) else: - if parsed_content.get("multi_modal_data"): - raise ValueError("This model does not support multimodal inputs") - inputs = token_inputs(prompt_token_ids) if cache_salt := parsed_content.get("cache_salt"): @@ -377,18 +374,15 @@ class InputPreprocessor: prompt_text = parsed_content["prompt"] inputs: TokenInputs | MultiModalInputs - if 
self.model_config.is_multimodal_model: + if multi_modal_data := parsed_content.get("multi_modal_data"): inputs = self._process_multimodal( prompt_text, - parsed_content.get("multi_modal_data") or {}, + multi_modal_data, parsed_content.get("mm_processor_kwargs") or {}, tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) else: - if parsed_content.get("multi_modal_data"): - raise ValueError("This model does not support multimodal inputs") - prompt_token_ids = self._tokenize_prompt( prompt_text, tokenization_kwargs=tokenization_kwargs, diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index d619a0edc1241..3db4165e20176 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -121,7 +121,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - # In transformers backend, x and output have extra batch dimension like + # In Transformers modeling backend, x and output have extra batch dimension like # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), # therefore we need to flatten the batch dimensions. 
if x.ndim == 3 and output.ndim == 3: diff --git a/vllm/model_executor/layers/conv.py b/vllm/model_executor/layers/conv.py new file mode 100644 index 0000000000000..e6f2d2990c241 --- /dev/null +++ b/vllm/model_executor/layers/conv.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Conv Layer Class.""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.model_executor.custom_op import CustomOp +from vllm.utils.torch_utils import is_torch_equal + + +class ConvLayerBase(CustomOp): + """Conv layer base class.""" + + num_dim: int + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, ...], + stride: int | tuple[int, ...] = 1, + padding: int | tuple[int, ...] = 0, + dilation: int | tuple[int, ...] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + *, + params_dtype: torch.dtype | None = None, + ) -> None: + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + kernel_size = ( + (kernel_size,) * self.num_dim + if isinstance(kernel_size, int) + else kernel_size + ) + stride = (stride,) * self.num_dim if isinstance(stride, int) else stride + padding = (padding,) * self.num_dim if isinstance(padding, int) else padding + dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.padding_mode = padding_mode + + self.enable_linear = ( + (self.kernel_size == self.stride) + and not any(self.padding) + and self.groups == 1 and all(d == 1 for d in self.dilation) + ) + self.input_size = in_channels * math.prod(self.kernel_size) + + self.weight = nn.Parameter( + torch.empty( + out_channels, + in_channels // groups, + *kernel_size, + dtype=params_dtype, + ),
) + + if bias: + self.bias = nn.Parameter(torch.empty(self.out_channels, dtype=params_dtype)) + else: + self.register_parameter("bias", None) + + def extra_repr(self) -> str: + s = f"in_channels={self.in_channels}, " + s += f"out_channels={self.out_channels}, " + s += f"kernel_size={self.kernel_size}, " + s += f"stride={self.stride}, " + s += f"padding={self.padding}, " + s += f"bias={self.bias is not None}" + return s + + +@CustomOp.register("conv2d") +class Conv2dLayer(ConvLayerBase): + """Conv layer with Conv2d.""" + + num_dim = 2 + + def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor: + assert x.dim() == 4 + B, C, H, W = x.shape + K1, K2 = self.kernel_size + H, W = H // K1, W // K2 + x = x.unfold(2, K1, K1).unfold(3, K2, K2) + x = x.permute(0, 2, 3, 1, 4, 5).reshape(-1, self.input_size) + x = F.linear( + x, + self.weight.view(self.out_channels, self.input_size), + self.bias, + ) + x = x.view(B, H, W, self.out_channels).permute(0, 3, 1, 2) + return x + + def _forward_conv(self, x: torch.Tensor) -> torch.Tensor: + assert x.dim() == 4 + x = F.conv2d( + x, + self.weight, + self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + return x + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """Expected input shape: (batch_size, in_channels, height, width)""" + assert x.dim() == 4 + if self.enable_linear: + return self._forward_mulmat(x) + else: + return self._forward_conv(x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + # By default, we use CUDNN's convolution ops with optimization. 
+ return self._forward_conv(x) + + +class CausalConv2dLayer(Conv2dLayer): + """ + A causal version of nn.Conv2d where each location in the 2D matrix would + have no access to locations on its right or down + All arguments are the same as nn.Conv2d except padding which should be + set as None + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + *, + params_dtype: torch.dtype | None = None, + ) -> None: + if padding is not None: + raise ValueError( + "Argument padding should be set to None for CausalConv2dLayer." + ) + self._left_padding: int = kernel_size - 1 + self._right_padding: int = stride - 1 + padding = 0 + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + params_dtype=params_dtype, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + x = F.pad(x, pad=(self._left_padding, self._right_padding, 0, 0)) + x = super().forward(x) + return x + + +@CustomOp.register("conv3d") +class Conv3dLayer(ConvLayerBase): + """Conv layer with Conv3d.""" + + num_dim = 3 + + def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor: + assert x.dim() == 5 + B, C, T, H, W = x.shape + K1, K2, K3 = self.kernel_size + T, H, W = T // K1, H // K2, W // K3 + x = x.unfold(2, K1, K1).unfold(3, K2, K2).unfold(4, K3, K3) + x = x.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(-1, self.input_size) + x = F.linear( + x, + self.weight.view(self.out_channels, self.input_size), + self.bias, + ) + x = x.view(B, T, H, W, self.out_channels).permute(0, 4, 1, 2, 3) + return x + + def _forward_conv(self, x: torch.Tensor) -> torch.Tensor: + assert x.dim() == 5 + x = F.conv3d( + x, + self.weight, + self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + return x + + def forward_native(self, 
x: torch.Tensor) -> torch.Tensor: + """Expected input shape: (batch_size, in_channels, time, height, width)""" + if self.enable_linear: + return self._forward_mulmat(x) + else: + return self._forward_conv(x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + # PyTorch2.9.0 disabled CUDNN's Conv3D, which caused a + # significant performance regression. + # See: https://github.com/vllm-project/vllm/issues/27406 + # and https://github.com/pytorch/pytorch/issues/166122 + # By default, we use CUDNN's convolution ops with optimization. + if self.enable_linear and is_torch_equal("2.9.0"): + return self._forward_mulmat(x) + return self._forward_conv(x) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 869082f8231d1..53362277dae8a 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( @@ -13,14 +15,33 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import ( + DeepGemmQuantScaleFMT, fp8_m_grouped_gemm_nt_masked, get_mk_alignment_for_contiguous_layout, is_deep_gemm_e8m0_used, ) +from vllm.utils.math_utils import cdiv, round_up logger = init_logger(__name__) +def scales_shape_stride_dtype( + E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT +) -> tuple[tuple[int, ...], 
tuple[int, ...], torch.dtype]: + shape = (E, T, G) + strides = (T * G, 1, T) + if quant_scale_fmt in [ + DeepGemmQuantScaleFMT.FLOAT32, + DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0, + ]: + return shape, strides, torch.float32 + + assert quant_scale_fmt == DeepGemmQuantScaleFMT.UE8M0 + shape = (E, T, cdiv(G, 4)) + strides = (T * cdiv(G, 4), 1, T) + return shape, strides, torch.int32 + + @triton.jit def _silu_mul_fp8_quant_deep_gemm( # Pointers ------------------------------------------------------------ @@ -49,7 +70,7 @@ def _silu_mul_fp8_quant_deep_gemm( eps: tl.constexpr, fp8_min: tl.constexpr, fp8_max: tl.constexpr, - use_ue8m0: tl.constexpr, + ceil_ue8m0: tl.constexpr, # Meta --------------------------------------------------------------- BLOCK: tl.constexpr, NUM_STAGES: tl.constexpr, @@ -86,7 +107,7 @@ def _silu_mul_fp8_quant_deep_gemm( y = gate * up y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max - if use_ue8m0: + if ceil_ue8m0: y_s = tl.exp2(tl.ceil(tl.log2(y_s))) y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) @@ -100,7 +121,7 @@ def persistent_masked_m_silu_mul_quant( tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert num_parallel_tokens=16, group_size: int = 128, - use_ue8m0: bool | None = None, + quant_scale_fmt: DeepGemmQuantScaleFMT = DeepGemmQuantScaleFMT.FLOAT32, ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales y has shape (E, T, 2*H). 
The first half of the last dimension is @@ -137,7 +158,13 @@ def persistent_masked_m_silu_mul_quant( Returns `(y_q, y_s)` where * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] - * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) + * `y_s` depends on quant_scale_fmt, + - quant_scale_fmt == FLOAT32, + `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) + - quant_scale_fmt == E8M0, + `y_s`: Int32 tensor, shape (E, T, H // group_size // 4), strides (T*G, 1, T) + - quant_scale_fmt == E8M0_FLOAT32_SPARSE + `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) Let NUM_WARPS be the number of warps in a single thread block and `GROUP_SIZE = 128` be the size of the quantization group. """ @@ -155,17 +182,18 @@ def persistent_masked_m_silu_mul_quant( fp8_dtype = torch.float8_e4m3fn y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) - stride_ys_e = T * G - stride_ys_t = 1 - stride_ys_g = T + ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt) y_s = torch.empty_strided( - (E, T, G), - (stride_ys_e, stride_ys_t, stride_ys_g), - dtype=torch.float32, + ys_shape, + ys_strides, + dtype=ys_dtype, device=y.device, ) - use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used() + ceil_ue8m0 = quant_scale_fmt in [ + DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0, + DeepGemmQuantScaleFMT.UE8M0, + ] cuda_arch = current_platform.get_device_capability( device_id=y.device.index @@ -173,7 +201,7 @@ def persistent_masked_m_silu_mul_quant( if cuda_arch >= 80: torch.ops._C.persistent_masked_m_silu_mul_quant( - y, tokens_per_expert, y_q, y_s, use_ue8m0 + y, tokens_per_expert, y_q, y_s, ceil_ue8m0 ) else: stride_cnt_e = tokens_per_expert.stride()[0] @@ -189,6 +217,10 @@ def persistent_masked_m_silu_mul_quant( fp8_max = f_info.max fp8_min = f_info.min eps: float = 1e-10 + assert y_s.dtype == torch.float32, ( + "_silu_mul_fp8_quant_deep_gemm does" + "not support 
{y_s.dtype} scales. Only torch.float32 supported." + ) _silu_mul_fp8_quant_deep_gemm[grid]( y, y_q, @@ -202,14 +234,14 @@ def persistent_masked_m_silu_mul_quant( stride_yq_e, stride_yq_t, stride_yq_h, - stride_ys_e, - stride_ys_t, - stride_ys_g, + ys_strides[0], + ys_strides[1], + ys_strides[2], stride_cnt_e, eps, fp8_min, fp8_max, - is_deep_gemm_e8m0_used(), + ceil_ue8m0, BLOCK=group_size, NUM_STAGES=4, num_warps=1, @@ -255,7 +287,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): """ DeepGemm supports packed ue8m0 activation scales format in devices == sm100 """ - return current_platform.is_device_capability(100) + return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100) def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. @@ -282,6 +314,33 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): output = (num_experts, max_num_tokens * num_dispatchers, K) return (workspace13, workspace2, output) + def estimate_expected_m( + self, global_num_experts: int, max_tokens_per_expert: int, topk: int + ) -> int: + dp_meta = ( + get_forward_context().dp_metadata + if is_forward_context_available() + else None + ) + if dp_meta is None: + logger.warning_once( + "DPMetadata unavailable. 
Defaulting expected_m to " + f"{max_tokens_per_expert}.", + scope="local", + ) + return max_tokens_per_expert + + total_num_tokens = dp_meta.num_tokens_across_dp_cpu.sum().item() + total_num_tokens_replicated = total_num_tokens * topk + + # Assume even load balancing + assert global_num_experts != 0 + estimate = round_up(int(total_num_tokens_replicated // global_num_experts), 16) + # clamp estimate + estimate = max(estimate, 16) + estimate = min(max_tokens_per_expert, estimate) + return estimate + def apply( self, output: torch.Tensor, @@ -317,10 +376,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) - # (from deepgemm docs) : A value hint (which is a value on CPU) - # for the M expectation of each batch, correctly setting this value - # may lead to better performance. - expected_m = max_num_tokens + expected_m = self.estimate_expected_m( + global_num_experts=global_num_experts, + max_tokens_per_expert=max_num_tokens, + topk=topk_ids.size(-1), + ) + fp8_m_grouped_gemm_nt_masked( (a1q, a1q_scale), (w1, self.w1_scale), @@ -329,10 +390,17 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): expected_m, ) + quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle() a2q, a2q_scale = persistent_masked_m_silu_mul_quant( - workspace1, expert_num_tokens + workspace1, + expert_num_tokens, + quant_scale_fmt=quant_scale_fmt, ) fp8_m_grouped_gemm_nt_masked( - (a2q, a2q_scale), (w2, self.w2_scale), output, expert_num_tokens, expected_m + (a2q, a2q_scale), + (w2, self.w2_scale), + output, + expert_num_tokens, + expected_m, ) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..555d173644522 --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 85ce77fb1f7f7..f864634c66176 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -57,6 +57,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): tp_rank: int = 0, tp_size: int = 1, use_dp: bool = False, + use_deepseek_fp8_block_scale: bool = False, ): super().__init__(quant_config) assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), ( @@ -69,6 +70,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): self.tp_size = tp_size self.out_dtype = out_dtype self.use_dp = use_dp + # Enables DeepSeek-style FP8 block-scale path: + # - pass per-block weight scales to the kernel + # - skip input activation quantization (kernel applies scaling) + 
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale @property def activation_formats( @@ -143,11 +148,22 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, ): - assert activation == "silu", ( - "Only activation silu is supported in FlashInferExperts" + from flashinfer.fused_moe.core import ActivationType + + activation_str_to_value_map = { + "silu": ActivationType.Swiglu, # This is the default + "relu2_no_mul": ActivationType.Relu2, + } + assert activation in activation_str_to_value_map, ( + f"{activation=} missing from {activation_str_to_value_map.keys()=}" ) - if self.quant_dtype == torch.float8_e4m3fn: + # Select quantization metadata based on FP8 format/path + if ( + self.quant_dtype == torch.float8_e4m3fn + and not self.use_deepseek_fp8_block_scale + ): + # FP8 per-tensor path: use global alphas/scales; do not pass input_sf quant_scales = [ self.g1_alphas, self.a2_gscale, @@ -176,6 +192,15 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): # FlashInfer API requires weight to be long for nvfp4 fc1_expert_weights = w1.view(torch.long) fc2_expert_weights = w2.view(torch.long) + elif self.use_deepseek_fp8_block_scale: + # FP8 block-scale path: provide block-scale weights, omit a1q_scale + quant_scales = [ + self.w1_scale, + self.w2_scale, + ] + a1q_scale = None + fc1_expert_weights = w1 + fc2_expert_weights = w2 else: quant_scales = None a1q_scale = None @@ -196,6 +221,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ep_size=self.ep_size, ep_rank=self.ep_rank, output=output, + activation_type=activation_str_to_value_map[activation], + # Informs FlashInfer to use the block-scale decoding path when True + use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale, ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py 
b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index bc9aab5208d9a..762890867e605 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -28,11 +28,15 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): self, use_dp: bool, num_dispatchers: int = 1, + use_deepseek_fp8_block_scale: bool = False, ): super().__init__() self.num_dispatchers_ = num_dispatchers self.use_dp = use_dp self.local_tokens = None + # Toggle for DeepSeek-style FP8 block-scale path where activations are + # not quantized here and weight block scales are consumed by the kernel. + self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -73,8 +77,9 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina self, use_dp: bool, num_dispatchers: int = 1, + use_deepseek_fp8_block_scale: bool = False, ): - super().__init__(use_dp, num_dispatchers) + super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale) self.alltoall_info = None # Initialize all2all_manager only for DP case @@ -97,15 +102,19 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina ) if not self.use_dp: - # Non-DP case: standard quantization - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=not self.use_dp, - ) + # Non-DP case: quantize activations unless using block-scale path + if not self.use_deepseek_fp8_block_scale: + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=not self.use_dp, + ) + else: + a1q = a1 + a1q_scale = None else: 
# DP case: use FlashInfer AllToAll global_num_tokens_cpu = get_local_sizes() @@ -122,6 +131,7 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina top_k, num_experts, quant_config, + use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale, ) ) @@ -154,8 +164,9 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin self, use_dp: bool, num_dispatchers: int = 1, + use_deepseek_fp8_block_scale: bool = False, ): - super().__init__(use_dp, num_dispatchers) + super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale) def prepare( self, @@ -173,22 +184,42 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin if not self.use_dp and quant_config.quant_dtype == "nvfp4": return a1, None, None, topk_ids, topk_weights - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=not self.use_dp, - ) + if not self.use_deepseek_fp8_block_scale: + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=not self.use_dp, + ) + else: + # Block-scale path: pass activations through, omit per-token scales + a1q = a1 + a1q_scale = None if self.use_dp: - topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv( - [topk_weights, topk_ids, a1q, a1q_scale], - dim=0, - sizes=get_local_sizes(), - ) - if quant_config.quant_dtype == "nvfp4": + # Build gather list conditionally - omit a1q_scale if None + # (block-scale path) + gather_list = [topk_weights, topk_ids, a1q] + if a1q_scale is not None: + gather_list.append(a1q_scale) + gathered = get_dp_group().all_gatherv( + gather_list, + dim=0, + sizes=get_local_sizes(), + ) + topk_weights, topk_ids, a1q, a1q_scale = gathered + else: + gathered = 
get_dp_group().all_gatherv( + gather_list, + dim=0, + sizes=get_local_sizes(), + ) + topk_weights, topk_ids, a1q = gathered + a1q_scale = None + + if quant_config.quant_dtype == "nvfp4" and a1q_scale is not None: a1q_scale = nvfp4_block_scale_interleave(a1q_scale) return a1q, a1q_scale, None, topk_ids, topk_weights @@ -221,6 +252,7 @@ def flashinfer_alltoall_dispatch( top_k: int, num_experts: int, quant_config: FusedMoEQuantConfig, + use_deepseek_fp8_block_scale: bool = False, ): from flashinfer.comm.trtllm_alltoall import MnnvlMoe @@ -250,30 +282,42 @@ def flashinfer_alltoall_dispatch( ) topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype) - x, x_sf = moe_kernel_quantize_input( - x, - gs, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=False, # delay swizzle to after comm - ) - x = MnnvlMoe.mnnvl_moe_alltoallv( - x, - alltoall_info, - all2all_manager.workspace_tensor, - ep_rank, - ep_size, - ) + if not use_deepseek_fp8_block_scale: + x, x_sf = moe_kernel_quantize_input( + x, + gs, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=False, # delay swizzle to after comm + ) + x = MnnvlMoe.mnnvl_moe_alltoallv( + x, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank, + ep_size, + ) - x_sf = MnnvlMoe.mnnvl_moe_alltoallv( - x_sf, - alltoall_info, - all2all_manager.workspace_tensor, - ep_rank, - ep_size, - ) - x_sf = nvfp4_block_scale_interleave(x_sf) + x_sf = MnnvlMoe.mnnvl_moe_alltoallv( + x_sf, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank, + ep_size, + ) + if quant_config.quant_dtype == "nvfp4": + x_sf = nvfp4_block_scale_interleave(x_sf) + else: + # Block-scale path: pass activations through without quantization + x_sf = None + x = MnnvlMoe.mnnvl_moe_alltoallv( + x, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank, + ep_size, + ) return alltoall_info, topk_ids, topk_weights, x, x_sf 
@@ -304,6 +348,7 @@ def create_flashinfer_prepare_finalize( use_dp: bool, use_nvfp4: bool = False, enable_alltoallv: bool = False, + use_deepseek_fp8_block_scale: bool = False, ) -> FlashInferCutlassMoEPrepareAndFinalize: """Factory function to create the appropriate FlashInfer implementation.""" if use_nvfp4: @@ -311,5 +356,7 @@ def create_flashinfer_prepare_finalize( return FlashInferAllToAllMoEPrepareAndFinalize(use_dp) else: return FlashInferAllGatherMoEPrepareAndFinalize(use_dp) - # Fp8 only supports AllGather - return FlashInferAllGatherMoEPrepareAndFinalize(use_dp) + # FP8 path currently supported via AllGather; optionally enable block-scale + return FlashInferAllGatherMoEPrepareAndFinalize( + use_dp=use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale + ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index aed8245cbd830..023132acfed3f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1749,14 +1749,16 @@ class FusedMoE(CustomOp): with sp_ctx: if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( + hidden_states_combined, router_logits = get_ep_group().dispatch( hidden_states, router_logits, self.is_sequence_parallel ) # Matrix multiply. 
final_hidden_states = self.quant_method.apply( layer=self, - x=hidden_states, + x=hidden_states_combined + if do_naive_dispatch_combine + else hidden_states, router_logits=router_logits, top_k=self.top_k, renormalize=self.renormalize, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a3142f37053f9..093affe51f503 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1060,7 +1060,7 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts=global_num_experts, expert_map=expert_map, a1q_scale=_slice_scales(a1q_scale, s, e), - a2_scale=_slice_scales(self.fused_experts.a2_scale, e, e), + a2_scale=_slice_scales(self.fused_experts.a2_scale, s, e), workspace13=workspace13, workspace2=workspace2, expert_tokens_meta=c_expert_tokens_meta, diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py index 26458f2e3c4da..2e7500bac7188 100644 --- a/vllm/model_executor/layers/kda.py +++ b/vllm/model_executor/layers/kda.py @@ -44,7 +44,6 @@ def kda_attention( k_proj_states: torch.Tensor, v_proj_states: torch.Tensor, g1: torch.Tensor, - g2: torch.Tensor, beta: torch.Tensor, core_attn_out: torch.Tensor, layer_name: str, @@ -56,7 +55,6 @@ def kda_attention( k_proj_states=k_proj_states, v_proj_states=v_proj_states, g1=g1, - g2=g2, beta=beta, core_attn_out=core_attn_out, ) @@ -67,7 +65,6 @@ def kda_attention_fake( k_proj_states: torch.Tensor, v_proj_states: torch.Tensor, g1: torch.Tensor, - g2: torch.Tensor, beta: torch.Tensor, core_attn_out: torch.Tensor, layer_name: str, @@ -284,7 +281,6 @@ class KimiDeltaAttention(nn.Module, MambaBase): k, v, g1, - g2, beta, core_attn_out, self.prefix, @@ -299,7 +295,6 @@ class KimiDeltaAttention(nn.Module, MambaBase): k_proj_states: torch.Tensor, v_proj_states: torch.Tensor, g1: torch.Tensor, - g2: torch.Tensor, beta: torch.Tensor, core_attn_out: torch.Tensor, 
) -> None: @@ -316,8 +311,15 @@ class KimiDeltaAttention(nn.Module, MambaBase): has_initial_state = attn_metadata.has_initial_state non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + num_actual_tokens = attn_metadata.num_actual_tokens constant_caches = self.kv_cache[forward_context.virtual_engine] + q_proj_states = q_proj_states[:num_actual_tokens] + k_proj_states = k_proj_states[:num_actual_tokens] + v_proj_states = v_proj_states[:num_actual_tokens] + g1 = g1[:num_actual_tokens] + beta = beta[:num_actual_tokens] + (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches # deal with strides conv_state_q = conv_state_q.transpose(-1, -2) @@ -372,7 +374,7 @@ class KimiDeltaAttention(nn.Module, MambaBase): ).transpose(0, 1) else: decode_conv_indices = non_spec_state_indices_tensor[ - : attn_metadata.num_decodes + : attn_metadata.num_actual_tokens ] q = causal_conv1d_update( q_proj_states, @@ -438,8 +440,9 @@ class KimiDeltaAttention(nn.Module, MambaBase): beta=beta, initial_state=recurrent_state, use_qk_l2norm_in_kernel=True, - cu_seqlens=non_spec_query_start_loc, + cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], ssm_state_indices=non_spec_state_indices_tensor, ) - assert core_attn_out_non_spec.shape == core_attn_out.shape - core_attn_out[:] = core_attn_out_non_spec + core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[ + 0, :num_actual_tokens + ] diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index e5a5c9dd6f712..661c884627b00 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -245,7 +245,7 @@ def _chunk_scan_fwd_kernel( ) if not HAS_INITSTATES and (seq_idx != seq_idx_prev): prev_states = tl.zeros( - (BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), 
dtype=C_ptr.dtype.element_ty + (BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty ) else: prev_states = tl.load( diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index b92fb8d266b73..bb42b10f87186 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -3,8 +3,11 @@ from typing import Literal, get_args +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +logger = init_logger(__name__) + QuantizationMethods = Literal[ "awq", "deepspeedfp", @@ -70,15 +73,20 @@ def register_quantization_config(quantization: str): def _wrapper(quant_config_cls): if quantization in QUANTIZATION_METHODS: - raise ValueError( - f"The quantization method `{quantization}` is already exists." + logger.warning( + "The quantization method '%s' already exists and will be " + "overwritten by the quantization config %s.", + quantization, + quant_config_cls, ) + else: + QUANTIZATION_METHODS.append(quantization) + if not issubclass(quant_config_cls, QuantizationConfig): raise ValueError( "The quantization config must be a subclass of `QuantizationConfig`." 
) _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls - QUANTIZATION_METHODS.append(quantization) return quant_config_cls return _wrapper diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 225ed9499fd4d..e3a7b2a8dffe9 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -3,6 +3,7 @@ from collections.abc import Callable from enum import Enum +from functools import partial from typing import TYPE_CHECKING, Any, Optional import torch @@ -126,10 +127,13 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend: Select the primary FP8 MoE backend Note: Shape-specific fallbacks may still occur at runtime. """ - # prefer FlashInfer backends when available and enabled on supported GPUs + # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100. if ( current_platform.is_cuda() - and current_platform.is_device_capability(100) + and ( + current_platform.is_device_capability(100) + or current_platform.is_device_capability(90) + ) and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe() ): @@ -138,14 +142,14 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend: logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100") return Fp8MoeBackend.FLASHINFER_TRTLLM else: - if block_quant: + if block_quant and current_platform.is_device_capability(100): raise ValueError( "FlashInfer FP8 MoE throughput backend does not " "support block quantization. Please use " "VLLM_FLASHINFER_MOE_BACKEND=latency " "instead." 
) - logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100") + logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100") return Fp8MoeBackend.FLASHINFER_CUTLASS # weight-only path for older GPUs without native FP8 @@ -644,6 +648,16 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS + if self.block_quant: + assert self.weight_block_size == [128, 128], ( + f"Only support weight_block_size == [128, 128], " + f"got {self.weight_block_size}" + ) + self.flashinfer_moe_fn = partial( + flashinfer_cutlass_moe_fp8, + moe=self.moe, + use_deepseek_fp8_block_scale=self.block_quant, + ) self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM self.allow_cutlass_block_scaled_grouped_gemm = ( @@ -1015,8 +1029,15 @@ class Fp8MoEMethod(FusedMoEMethodBase): ): return None elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + if self.block_quant: + assert self.weight_block_size == [128, 128], ( + f"Only support weight_block_size == [128, 128], " + f"got {self.weight_block_size}" + ) + # Wire block-scale flag through prepare/finalize when using CUTLASS prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - self.moe + self.moe, + use_deepseek_fp8_block_scale=self.block_quant, ) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize @@ -1065,9 +1086,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + # Select GEMM experts with block-scale when weights are block-quantized experts = select_cutlass_fp8_gemm_impl( self.moe, self.moe_quant_config, + use_deepseek_fp8_block_scale=self.block_quant, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts @@ -1254,16 +1277,17 @@ class Fp8MoEMethod(FusedMoEMethodBase): workspace=layer.workspace, 
) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - assert not self.block_quant - assert not renormalize and custom_routing_function is not None assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) - assert scoring_func == "sigmoid", ( - f"Expected 'sigmoid' scoring func but got {scoring_func}" - ) - - result = flashinfer_cutlass_moe_fp8( + if not self.block_quant: + assert not renormalize and custom_routing_function is not None + assert scoring_func == "sigmoid", ( + f"Expected 'sigmoid' scoring func but got {scoring_func}" + ) + # Delegate to CUTLASS FlashInfer path; function already bound with + # use_deepseek_fp8_block_scale for block-quant when applicable + result = self.flashinfer_moe_fn( x, layer, topk_weights, diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index e0234191c62bf..5ca9167faec80 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -52,6 +52,7 @@ class IPEXConfig(QuantizationConfig): modules_to_not_convert: list[str] | None = None, desc_act: bool | None = None, lm_head_quantized: bool | None = None, + is_sym: bool | None = None, ) -> None: super().__init__() self.method = method @@ -60,6 +61,7 @@ class IPEXConfig(QuantizationConfig): self.modules_to_not_convert = modules_to_not_convert or [] self.desc_act = desc_act self.lm_head_quantized = lm_head_quantized + self.is_sym = is_sym self.pack_factor = 32 // self.weight_bits if self.weight_bits not in [4]: @@ -108,15 +110,25 @@ class IPEXConfig(QuantizationConfig): modules_to_not_convert = cls.get_from_keys_or( config, ["modules_to_not_convert"], None ) + is_sym = not cls.get_from_keys_or(config, ["zero_point"], default=False) return cls( - method, weight_bits, group_size, modules_to_not_convert, False, False + method, + weight_bits, + group_size, + modules_to_not_convert, + False, + False, + is_sym, 
) # otherwise for gptq weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False) - return cls(method, weight_bits, group_size, [], desc_act, lm_head_quantized) + is_sym = cls.get_from_keys_or(config, ["sym"], default=True) + return cls( + method, weight_bits, group_size, [], desc_act, lm_head_quantized, is_sym + ) @classmethod def override_quantization_method( @@ -180,6 +192,7 @@ class IPEXGPTQLinearMethod(GPTQLinearMethod): # The float activation will be quantized (dynamic, per-token) to INT8. act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK + assert isinstance(self.quant_config, IPEXConfig) qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( weight_dtype=weight_dtype, lowp_mode=lowp_mode, @@ -200,6 +213,7 @@ class IPEXGPTQLinearMethod(GPTQLinearMethod): bias=bias, group_size=self.quant_config.group_size, quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"], + weight_qscheme="sym" if self.quant_config.is_sym else "asym", ) ) @@ -250,6 +264,7 @@ class IPEXAWQLinearMethod(AWQLinearMethod): # The float activation will be quantized (dynamic, per-token) to INT8. 
act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH + assert isinstance(self.quant_config, IPEXConfig) qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( weight_dtype=weight_dtype, lowp_mode=lowp_mode, @@ -269,6 +284,7 @@ class IPEXAWQLinearMethod(AWQLinearMethod): bias=bias, group_size=self.quant_config.group_size, quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"], # type: ignore + weight_qscheme="sym" if self.quant_config.is_sym else "asym", ) ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 79bcb61dc5060..9309f4f150e45 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -15,6 +15,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, + RoutingMethodType, fp8_w8a8_moe_quant_config, nvfp4_moe_quant_config, ) @@ -353,12 +354,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): self.cutlass_fp8_supported = cutlass_fp8_supported() self.flashinfer_moe_backend: FlashinferMoeBackend | None = None - if ( - envs.VLLM_USE_FLASHINFER_MOE_FP8 - and has_flashinfer_moe() - and self.moe.is_act_and_mul - ): + if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): self.flashinfer_moe_backend = get_flashinfer_moe_backend() + if ( + self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + and not self.moe.is_act_and_mul + ): + logger.info_once( + "Non-gated MoE is not supported for min-latency mode," + "falling back to high-throughput mode" + ) + self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS + logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" ) @@ -556,10 +563,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ) if self.flashinfer_moe_backend is not None: - layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - register_moe_scaling_factors(layer) + if 
self.moe.is_act_and_mul: + layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) + register_moe_scaling_factors(layer) def get_fused_moe_quant_config( self, layer: torch.nn.Module @@ -569,13 +577,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, - g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(), + g1_alphas=layer.output1_scales_gate_scalar.squeeze(), w2_scale=layer.w2_weight_scale, - g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(), + g2_alphas=layer.output2_scales_scalar.squeeze(), a1_scale=layer.w13_input_scale, a1_gscale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - a2_gscale=1.0 / layer.w2_input_scale, + a2_gscale=layer.w2_input_scale_inv, per_act_token_quant=False, ) @@ -641,9 +649,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ) if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - assert not renormalize - assert activation == "silu", ( - f"Expected 'silu' activation but got {activation}" + assert activation in ("silu", "relu2_no_mul"), ( + "Expected activation to be in ('silu', 'relu2_no_mul')," + f"but got {activation}" ) return flashinfer_cutlass_moe_fp8( x, @@ -1649,16 +1657,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): use_llama4_routing = ( custom_routing_function is Llama4MoE.custom_routing_function ) - routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3 + routing_method_type = layer.routing_method_type if use_llama4_routing: - routing_method_type = flashinfer.RoutingMethodType.Llama4 + routing_method_type = RoutingMethodType.Llama4 + router_logits = ( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ) routing_bias = e_score_correction_bias if routing_bias is not None: routing_bias = 
routing_bias.to(torch.bfloat16) out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( - routing_logits=router_logits - if use_llama4_routing - else router_logits.to(torch.float32), + routing_logits=router_logits, routing_bias=routing_bias, hidden_states=hidden_states_fp4, hidden_states_scale=hidden_states_scale_linear_fp4.view( @@ -1682,8 +1693,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): output2_scale_scalar=layer.g2_alphas.data, num_experts=global_num_experts, top_k=top_k, - n_group=num_expert_group if num_expert_group is not None else 0, - topk_group=topk_group if topk_group is not None else 0, + n_group=num_expert_group, + topk_group=topk_group, intermediate_size=layer.intermediate_size_per_partition, local_expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 0f69a18a1f3fd..b95d1a6b3a1f5 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -755,8 +755,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self.w13_weight = w13_weight self.w2_weight = w2_weight - layer.w13_weight = w13_weight - layer.w2_weight = w2_weight + layer.w13_weight = Parameter(w13_weight.storage.data, requires_grad=False) + layer.w2_weight = Parameter(w2_weight.storage.data, requires_grad=False) else: raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 095a66ef10f9a..1bb698faf46df 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import fnmatch -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast import torch @@ 
-34,6 +34,9 @@ from vllm.model_executor.layers.quantization.quark.utils import ( ) from vllm.platforms import current_platform +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + __all__ = ["QuarkLinearMethod"] logger = init_logger(__name__) @@ -54,6 +57,7 @@ class QuarkConfig(QuantizationConfig): self.kv_cache_group = kv_cache_group self.kv_cache_config = kv_cache_config self.pack_method = pack_method + self.ignore: list[str] = cast(list[str], self.quant_config.get("exclude", [])) def get_linear_method(self) -> "QuarkLinearMethod": return QuarkLinearMethod(self) @@ -74,9 +78,8 @@ class QuarkConfig(QuantizationConfig): from vllm.attention.layer import Attention # Avoid circular import # Check if the layer is skipped for quantization. - exclude_layers = cast(list[str], self.quant_config.get("exclude")) if should_ignore_layer( - prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping + prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping ): return UnquantizedLinearMethod() if isinstance(layer, LinearBase): @@ -90,6 +93,9 @@ class QuarkConfig(QuantizationConfig): return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix) return None + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) + @classmethod def from_config(cls, config: dict[str, Any]) -> "QuarkConfig": export_config = config.get("export") diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index e49d374f154d8..f22e17945d1f6 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 
create_flashinfer_prepare_finalize, ) +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -190,17 +191,22 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None: def build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe: FusedMoEConfig | None, + moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False ) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False - return create_flashinfer_prepare_finalize(use_dp) + # Propagate block-scale flag so prepare/finalize can skip act quantization + # and inform the kernel to consume per-block weight scales. + return create_flashinfer_prepare_finalize( + use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale + ) def select_cutlass_fp8_gemm_impl( moe: FusedMoEConfig | None, quant_config: FusedMoEQuantConfig, out_dtype: torch.dtype | None = None, + use_deepseek_fp8_block_scale: bool = False, ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return a GEMM *experts* implementation for fused-MoE layers""" @@ -212,12 +218,14 @@ def select_cutlass_fp8_gemm_impl( ep_size=moe.moe_parallel_config.ep_size, tp_rank=moe.moe_parallel_config.tp_rank, tp_size=moe.moe_parallel_config.tp_size, + use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, ) assert out_dtype is not None, "If moe config is None, out_dtype must be passed" return FlashInferExperts( out_dtype=out_dtype, quant_config=quant_config, + use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, ) @@ -231,14 +239,22 @@ def flashinfer_cutlass_moe_fp8( global_num_experts: int = -1, expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, + use_deepseek_fp8_block_scale: bool = False, + moe: FusedMoEConfig | None = None, ) -> torch.Tensor: quant_config = layer.quant_method.get_fused_moe_quant_config(layer) assert quant_config is not None + # Construct modular kernel with 
block-scale support when requested. fused_experts = mk.FusedMoEModularKernel( - build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None), + build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale + ), select_cutlass_fp8_gemm_impl( - moe=None, quant_config=quant_config, out_dtype=hidden_states.dtype + moe=moe, + quant_config=quant_config, + out_dtype=hidden_states.dtype, + use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, ), ) @@ -258,7 +274,10 @@ def flashinfer_cutlass_moe_fp8( def get_flashinfer_moe_backend() -> FlashinferMoeBackend: flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND - if flashinfer_moe_backend == "throughput": + # Prefer CUTLASS on SM90 to cover both SM90/SM100 generations + if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability( + 90 + ): return FlashinferMoeBackend.CUTLASS elif flashinfer_moe_backend == "latency": return FlashinferMoeBackend.TENSORRT_LLM @@ -272,5 +291,8 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend: def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool: # TODO(shuw@nvidia): Update when new backends are added. 
- backends_supporting_global_sf = (FlashinferMoeBackend.CUTLASS,) + backends_supporting_global_sf = ( + FlashinferMoeBackend.CUTLASS, + FlashinferMoeBackend.TENSORRT_LLM, + ) return backend in backends_supporting_global_sf diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 541c6c631053d..ae63b4a767268 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -342,7 +342,7 @@ class W8A8BlockFp8LinearOp: ) # MI300 uses tuned AITER ASM/C++ kernel else: - q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d) + q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d) return gemm_a8w8_blockscale_op( q_input, diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index f742090df71fd..a9cc49451a1d3 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -429,7 +429,7 @@ def load_weights_using_from_2_way_softmax( if text_config.tie_word_embeddings: # embed_tokens is the assumed name for input embeddings. If the model does not # have this attribute, we fallback to get_input_embeddings(), which is used by - # the Transformers backend. + # the Transformers modeling backend. embed_tokens = ( model.model.embed_tokens if hasattr(model.model, "embed_tokens") @@ -487,7 +487,7 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te if text_config.tie_word_embeddings: # embed_tokens is the assumed name for input embeddings. If the model does not # have this attribute, we fallback to get_input_embeddings(), which is used by - # the Transformers backend. + # the Transformers modeling backend. 
embed_tokens = ( model.model.embed_tokens if hasattr(model.model, "embed_tokens") diff --git a/vllm/model_executor/models/afmoe.py b/vllm/model_executor/models/afmoe.py new file mode 100644 index 0000000000000..6f654f47495f7 --- /dev/null +++ b/vllm/model_executor/models/afmoe.py @@ -0,0 +1,711 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only AfMoE model compatible with HuggingFace weights.""" + +import typing +from collections.abc import Callable, Iterable +from itertools import islice +from typing import Any + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.llama import LlamaMLP as AfmoeMLP +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + extract_layer_index, + is_pp_missing_parameter, + 
make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) +from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + + +class AfmoeMoE(nn.Module): + def __init__( + self, + config, # AfmoeConfig + quant_config: QuantizationConfig | None = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.route_scale = config.route_scale + self.score_func = config.score_func + self.route_norm = config.route_norm + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.num_experts + self.n_shared_experts: int = config.num_shared_experts + + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) + + # Router gate + self.gate = nn.Linear( + config.hidden_size, + config.num_experts, + bias=False, + dtype=torch.float32, + ) + self.expert_bias = nn.Parameter( + torch.empty(config.num_experts, dtype=torch.float32) + ) + + # Load balancing settings + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.shared_experts = None + # Shared experts + if config.num_shared_experts > 0: + intermediate_size = config.moe_intermediate_size * config.num_shared_experts + self.shared_experts = AfmoeMLP( + hidden_size=config.hidden_size, + 
intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + # Routed experts using SharedFusedMoE + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.route_norm if self.score_func == "sigmoid" else False, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=self.score_func, + routed_scaling_factor=self.route_scale, + e_score_correction_bias=self.expert_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + router_logits = self.gate(hidden_states.to(dtype=torch.float32)) + + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.shared_experts is not None: + shared_output, final_hidden_states = fused_moe_out + final_hidden_states = final_hidden_states + shared_output + else: + final_hidden_states = fused_moe_out + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class AfmoeAttention(nn.Module): + def __init__( + self, + config, # AfmoeConfig + layer_idx: int, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: dict[str, Any] | None = None, + max_position_embeddings: int = 131072, + head_dim: int | None = None, + rms_norm_eps: float = 1e-05, + cache_config: CacheConfig | None = None, + 
quant_config: QuantizationConfig | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + # Check if this is a local attention layer + self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention" + self.sliding_window = config.sliding_window if self.is_local_attention else None + + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # Gating projection + self.gate_proj = ColumnParallelLinear( + hidden_size, + self.total_num_heads * self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_proj", + ) + + # Q/K normalization + self.q_norm = 
RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + # Only create rotary embeddings for local attention + if self.is_local_attention: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + else: + self.rotary_emb = None + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn", + attn_type=attn_type, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + gate, _ = self.gate_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Apply Q/K normalization + q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(q.shape) + k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape( + k.shape + ) + + # Apply rotary embeddings only for local attention + if self.is_local_attention and self.rotary_emb is not None: + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v) + + # Apply gating + attn_output = attn_output * torch.sigmoid(gate) + output, _ = self.o_proj(attn_output) + return output + + +class AfmoeDecoderLayer(nn.Module): + def __init__( + self, + config, # AfmoeConfig + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + 
rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 131072) + + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + self.layer_idx = extract_layer_index(prefix) + + self.self_attn = AfmoeAttention( + config=config, + layer_idx=self.layer_idx, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + head_dim=config.head_dim, + rms_norm_eps=config.rms_norm_eps, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # MoE or dense FFN + self.moe_enabled = self.layer_idx >= config.num_dense_layers + if self.moe_enabled: + self.mlp = AfmoeMoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) + else: + self.mlp = AfmoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + 
hidden_states = self.post_attention_layernorm(hidden_states) # attn norm b + + # Fully Connected + hidden_states, residual = self.pre_mlp_layernorm( # ffn norm a + hidden_states, residual + ) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_layernorm(hidden_states) # ffn norm b + + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + } +) +class AfmoeModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + enable_eplb = vllm_config.parallel_config.enable_eplb + self.config = config + + self.vocab_size = config.vocab_size + self.mup_enabled = config.mup_enabled + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, prefix=f"{prefix}.embed_tokens" + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: AfmoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + enable_eplb=enable_eplb, + ), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | 
IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + + # Apply muP input scaling if enabled + if self.mup_enabled: + hidden_states = hidden_states * (self.config.hidden_size**0.5) + + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + 
expert_params_mapping = self.get_expert_mapping() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if (weight_name not in name) or ("self_attn.gate_proj" in name): + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. 
+ weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_suffix={ + ".router.gate.weight": ".gate.weight", + }, + ) + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = AfmoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + self.expert_weights = [] + + # Set MoE 
hyperparameters + self.num_moe_layers = config.num_hidden_layers - config.num_dense_layers + self.num_expert_groups = config.n_group + + self.moe_layers: list[SharedFusedMoE] = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, AfmoeDecoderLayer) + if layer.moe_enabled: + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None and self.num_moe_layers > 0: + raise RuntimeError("No AfmoeMoE layer found in model.layers.") + + if example_moe is not None: + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index a878134022565..024425bb24406 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -39,7 +39,6 @@ from vllm.distributed import ( get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import SharedFusedMoE @@ -330,7 +329,9 @@ class BailingMoE(nn.Module): final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + 
final_hidden_states + ) return final_hidden_states.view(num_tokens, hidden_size) @@ -598,7 +599,7 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.lm_head", + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) else: diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 50f476dfd185b..5d611deb942d1 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -20,6 +20,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -315,7 +316,7 @@ class CLIPVisionEmbeddings(nn.Module): self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index 9e834a73f8e5e..3fb04c3b70dd1 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -26,7 +26,7 @@ from vllm.model_executor.models.deepseek_v2 import ( ) from vllm.utils import init_logger -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight logger = init_logger(__name__) @@ -250,6 +250,7 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): name, loaded_weight = inputs if "lm_head" not in name: name = "model." 
+ name + process_eagle_weight(self, name) return name, loaded_weight loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py index c89caab93a1ee..8179f916ff417 100644 --- a/vllm/model_executor/models/deepseek_ocr.py +++ b/vllm/model_executor/models/deepseek_ocr.py @@ -161,7 +161,7 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor): ) def is_argmax_invariant(self) -> bool: - return True + return False def new_req_logits_processor( self, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 115818d903a6d..e8ee9951d6119 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -85,7 +85,7 @@ from vllm.v1.attention.backends.mla.indexer import ( ) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec -from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP +from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP from .utils import ( PPMissingLayer, is_pp_missing_parameter, @@ -1311,7 +1311,7 @@ class DeepseekV2MixtureOfExperts(MixtureOfExperts): class DeepseekV2ForCausalLM( - nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA + nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle ): packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 25e5588961a63..f46caaa095c6a 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -39,8 +39,8 @@ from vllm.model_executor.models.interfaces import ( ) from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM -from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention from vllm.model_executor.models.qwen2_vl import ( + 
Qwen2VisionAttention, Qwen2VLDummyInputsBuilder, Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, @@ -328,7 +328,7 @@ class DotsVisionAttention(nn.Module): # [S, C] -> [S, B=1, C] x = hidden_states.unsqueeze(1) x, _ = self.qkv(x) - q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x) + q, k, v = Qwen2VisionAttention.split_qkv(self, x) bs = q.shape[1] # [S,B,H,D] -> [B,S,H,D] q = q.permute(1, 0, 2, 3).contiguous() @@ -780,6 +780,10 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA architectures=["Qwen2ForCausalLM"], ) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + def _parse_and_validate_image_input( self, **kwargs: object ) -> DotsOCRImageInputs | None: diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 02fb7ef31dc94..8e2bbe8f7990c 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal +from typing import Annotated, Any, Literal, TypeAlias import torch from torch import nn @@ -20,7 +20,12 @@ from vllm.multimodal.inputs import ( MultiModalFieldConfig, MultiModalKwargsItems, ) -from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) from vllm.multimodal.processing import ( BaseMultiModalProcessor, BaseProcessingInfo, @@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema): num_patches: Annotated[torch.Tensor, TensorShape("bn")] -Gemma3ImageInputs = Gemma3ImagePixelInputs +class Gemma3ImageEmbeddingInputs(TensorSchema): + type: Literal["image_embeds"] = "image_embeds" + image_embeds: Annotated[ + torch.Tensor, + TensorShape("ni", "nf", 
"hs"), + ] + + +Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs class Gemma3ProcessingInfo(BaseProcessingInfo): @@ -178,8 +191,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): def get_image_repl( self, *, - image_width: int, - image_height: int, + image_width: int | None, + image_height: int | None, + num_crops: int | None = None, processor: Gemma3Processor | None, ) -> PromptUpdateDetails[str]: if processor is None: @@ -187,11 +201,13 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): boi_token = processor.boi_token - num_crops = self.get_num_crops( - image_width=image_width, - image_height=image_height, - processor=processor, - ) + if num_crops is None: + assert image_width is not None and image_height is not None + num_crops = self.get_num_crops( + image_width=image_width, + image_height=image_height, + processor=processor, + ) if num_crops == 0: image_text = boi_token @@ -321,6 +337,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), num_patches=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), ) def _get_prompt_updates( @@ -333,7 +350,19 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): image_token = hf_processor.boi_token def get_replacement_gemma3(item_idx: int): - images = mm_items.get_items("image", ImageProcessorItems) + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) + + if isinstance(images, ImageEmbeddingItems): + # For image embedding inputs, only support no crops cases + # since it's not supported in hf processor anyway + return self.info.get_image_repl( + image_width=None, + image_height=None, + num_crops=0, + processor=hf_processor, + ) image_size = images.get_image_size(item_idx) return self.info.get_image_repl( @@ -557,17 +586,19 @@ class Gemma3ForConditionalGeneration( 
pixel_values = kwargs.pop("pixel_values", None) num_patches = kwargs.pop("num_patches", None) image_embeds = kwargs.pop("image_embeds", None) - assert image_embeds is None, "Gemma3 does not support image_embeds." - if pixel_values is None: - return None - image_size = self.config.vision_config.image_size - - return Gemma3ImagePixelInputs( - pixel_values=pixel_values, - num_patches=num_patches, - resolve_bindings={"h": image_size, "w": image_size}, - ) + if pixel_values is not None: + image_size = self.config.vision_config.image_size + return Gemma3ImagePixelInputs( + pixel_values=pixel_values, + num_patches=num_patches, + resolve_bindings={"h": image_size, "w": image_size}, + ) + elif image_embeds is not None: + return Gemma3ImageEmbeddingInputs( + image_embeds=image_embeds, + type="image_embeds", + ) def _image_pixels_to_features( self, @@ -579,7 +610,9 @@ class Gemma3ForConditionalGeneration( def _process_image_input( self, image_input: Gemma3ImageInputs, - ) -> list[torch.Tensor]: + ) -> torch.Tensor | list[torch.Tensor]: + if image_input["type"] == "image_embeds": + return image_input["image_embeds"] assert self.vision_tower is not None pixel_values = image_input["pixel_values"] diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index b2d4fe0c0139b..6953b805653b4 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -56,12 +56,12 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger +from vllm.model_executor.layers.conv import Conv3dLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization 
import QuantizationConfig @@ -103,7 +103,6 @@ from .utils import ( maybe_prefix, ) from .vision import ( - conv3d_to_linear_weight, get_vit_attn_backend, run_dp_sharded_mrope_vision_model, ) @@ -486,15 +485,18 @@ class Glm4vVisionPatchEmbed(nn.Module): self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = ReplicatedLinear( - in_channels * math.prod(kernel_size), + self.proj = Conv3dLayer( + in_channels, hidden_size, + kernel_size=kernel_size, + stride=kernel_size, bias=True, - return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.proj(x) + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + x = self.proj(x).view(L, self.hidden_size) return x @@ -893,9 +895,6 @@ class Glm4vVisionTransformer(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index 110ed0a646334..e34ae6c85a4f8 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -256,13 +256,12 @@ class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() - spec_layer = self.model.mtp_start_layer_idx for name, loaded_weight in weights: if name == "lm_head.weight": - name = f"model.layers.{spec_layer}.shard_head.head.weight" + spec_layer = self.model.mtp_start_layer_idx + name = f"model.layers.{spec_layer}.shared_head.head.weight" elif name == "model.embed_tokens.weight": - # This name is same with local model, rewriting is not needed. 
- pass + spec_layer = self.model.mtp_start_layer_idx else: spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is None: diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 692ef605fe175..7df3b087ccb88 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -494,8 +494,8 @@ class GptOssModel(nn.Module): def _load_weights_other( self, - ep_rank_start: int, ep_rank_end: int, + ep_rank_start: int, heads_per_rank: int, head_start: int, weights: Iterable[tuple[str, torch.Tensor]], @@ -641,8 +641,8 @@ class GptOssModel(nn.Module): ) else: return self._load_weights_other( - ep_rank_end, ep_rank_start, + ep_rank_end, heads_per_rank, head_start, weights, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 929bfaaee5cbb..dc4caf2f02f9d 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -932,13 +932,73 @@ def supports_transcription( @runtime_checkable -class SupportsEagle3(Protocol): +class SupportsEagleBase(Protocol): + """Base interface for models that support EAGLE-based speculative decoding.""" + + has_own_lm_head: bool = False + """ + A flag that indicates this model has trained its own lm_head. + """ + + has_own_embed_tokens: bool = False + """ + A flag that indicates this model has trained its own input embeddings. + """ + + +@overload +def supports_any_eagle(model: type[object]) -> TypeIs[type[SupportsEagleBase]]: ... + + +@overload +def supports_any_eagle(model: object) -> TypeIs[SupportsEagleBase]: ... 
+ + +def supports_any_eagle( + model: type[object] | object, +) -> TypeIs[type[SupportsEagleBase]] | TypeIs[SupportsEagleBase]: + """Check if model supports any EAGLE variant (1, 2, or 3).""" + return supports_eagle(model) or supports_eagle3(model) + + +@runtime_checkable +class SupportsEagle(SupportsEagleBase, Protocol): """The interface required for models that support - EAGLE3 speculative decoding.""" + EAGLE-1 and EAGLE-2 speculative decoding.""" + + supports_eagle: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports EAGLE-1 and EAGLE-2 + speculative decoding. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + +@overload +def supports_eagle(model: type[object]) -> TypeIs[type[SupportsEagle]]: ... + + +@overload +def supports_eagle(model: object) -> TypeIs[SupportsEagle]: ... + + +def supports_eagle( + model: type[object] | object, +) -> TypeIs[type[SupportsEagle]] | TypeIs[SupportsEagle]: + return isinstance(model, SupportsEagle) + + +@runtime_checkable +class SupportsEagle3(SupportsEagleBase, Protocol): + """The interface required for models that support + EAGLE-3 speculative decoding.""" supports_eagle3: ClassVar[Literal[True]] = True """ - A flag that indicates this model supports EAGLE3 + A flag that indicates this model supports EAGLE-3 speculative decoding. Note: @@ -949,7 +1009,7 @@ class SupportsEagle3(Protocol): def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: """ Set which layers should output auxiliary - hidden states for EAGLE3. + hidden states for EAGLE-3. Args: layers: Tuple of layer indices that should output auxiliary @@ -960,7 +1020,7 @@ class SupportsEagle3(Protocol): def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: """ Get the layer indices that should output auxiliary hidden states - for EAGLE3. + for EAGLE-3. Returns: Tuple of layer indices for auxiliary hidden state outputs. 
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c49a1ea817f91..0a3f37c30ab5f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -58,7 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ) from vllm.sequence import IntermediateTensors -from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, PPMissingLayer, @@ -529,7 +529,9 @@ class LlamaModel(nn.Module): return loaded_params -class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): +class LlamaForCausalLM( + nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3 +): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index e8716d652415e..660c8f1bb5226 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -35,7 +35,7 @@ from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausa from vllm.model_executor.models.utils import extract_layer_index from .interfaces import SupportsMultiModal -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight logger = init_logger(__name__) @@ -212,6 +212,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): name, weight = self.permute_qk_weight_for_rotary(name, loaded_weight) if "lm_head" not in name: name = "model." 
+ name + process_eagle_weight(self, name) return name, weight loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index ab2a9f6f06dbe..90ab5c50361b6 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -11,13 +11,22 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import ( + AutoWeightsLoader, + get_draft_quant_config, + maybe_prefix, + process_eagle_weight, +) logger = init_logger(__name__) @@ -40,14 +49,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer): def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None: """Use drafter's quantization config instead of verifier's.""" - draft_model_config = vllm_config.speculative_config.draft_model_config - draft_load_config = vllm_config.load_config - - return ( - VllmConfig.get_quantization_config(draft_model_config, draft_load_config) - if draft_model_config - else None - ) + return get_draft_quant_config(vllm_config) @support_torch_compile @@ -63,6 +65,9 @@ class LlamaModel(nn.Module): self.config = vllm_config.speculative_config.draft_model_config.hf_config self.vocab_size = 
self.config.vocab_size + # Get drafter's quantization config + self.quant_config = get_draft_quant_config(vllm_config) + self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, self.config.hidden_size, @@ -80,8 +85,14 @@ class LlamaModel(nn.Module): for i in range(self.config.num_hidden_layers) ] ) - self.fc = torch.nn.Linear( - self.config.hidden_size * 2, self.config.hidden_size, bias=False + self.fc = ReplicatedLinear( + input_size=self.config.hidden_size * 2, + output_size=self.config.hidden_size, + bias=False, + params_dtype=vllm_config.model_config.dtype, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "fc"), + return_bias=False, ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -117,6 +128,24 @@ class LlamaModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: + # Handle kv cache quantization scales + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + # Remapping the name FP8 kv-scale + if "scale" in name: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -179,6 +208,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): name, loaded_weight = inputs if "lm_head" not in name: name = "model." 
+ name + process_eagle_weight(self, name) return name, loaded_weight loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 6edc9519dfbbf..75c671311b491 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -11,19 +11,27 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import QKVParallelLinear +from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import ( + AutoWeightsLoader, + get_draft_quant_config, + maybe_prefix, + process_eagle_weight, +) logger = init_logger(__name__) @@ -66,14 +74,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer): def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None: """Use drafter's quantization config instead of verifier's.""" - draft_model_config = vllm_config.speculative_config.draft_model_config - draft_load_config = vllm_config.load_config - - return ( - VllmConfig.get_quantization_config(draft_model_config, draft_load_config) - if draft_model_config - 
else None - ) + return get_draft_quant_config(vllm_config) def _norm_before_residual( self, hidden_states: torch.Tensor @@ -140,6 +141,9 @@ class LlamaModel(nn.Module): self.config = vllm_config.speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size + # Get drafter's quantization config + self.quant_config = get_draft_quant_config(vllm_config) + current_vllm_config = get_current_vllm_config() self.embed_tokens = VocabParallelEmbedding( @@ -160,13 +164,19 @@ class LlamaModel(nn.Module): ] ) if hasattr(self.config, "target_hidden_size"): - self.fc = torch.nn.Linear( - self.config.target_hidden_size * 3, self.config.hidden_size, bias=False - ) + fc_input_size = self.config.target_hidden_size * 3 else: - self.fc = torch.nn.Linear( - self.config.hidden_size * 3, self.config.hidden_size, bias=False - ) + fc_input_size = self.config.hidden_size * 3 + self.fc = ReplicatedLinear( + input_size=fc_input_size, + output_size=self.config.hidden_size, + bias=False, + params_dtype=vllm_config.model_config.dtype, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "fc"), + return_bias=False, + ) + self.norm = RMSNorm( self.config.hidden_size, eps=self.config.rms_norm_eps, @@ -211,6 +221,24 @@ class LlamaModel(nn.Module): for name, loaded_weight in weights: if "midlayer." 
in name: name = name.replace("midlayer.", "layers.0.") + # Handle kv cache quantization scales + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + # Remapping the name FP8 kv-scale + if "scale" in name: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -324,6 +352,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): if "embed_tokens" in name: includes_embed_tokens = True model_weights[name] = loaded_weight + process_eagle_weight(self, name) skip_substrs = [] if not includes_draft_id_mapping: diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py index 0ca31913485db..d0cdb70aa8574 100644 --- a/vllm/model_executor/models/minicpm_eagle.py +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle, SupportsLoRA, SupportsPP from .minicpm import MiniCPMAttention as EagleMiniCPMAttention from .minicpm import MiniCPMMLP as EagleMiniCPMMLP from .minicpm import MiniCPMMoE as EagleMiniCPMMoE @@ -52,6 +52,7 @@ from .utils import ( is_pp_missing_parameter, make_empty_intermediate_tensors_factory, maybe_prefix, + process_eagle_weight, ) @@ -289,7 +290,7 @@ class EagleMiniCPMModel(nn.Module): return loaded_params -class 
EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -376,8 +377,13 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + def transform(inputs): + name, loaded_weight = inputs + process_eagle_weight(self, name) + return name, loaded_weight + loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + return loader.load_weights(map(transform, weights)) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 14e741f322582..e25a104d822a7 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -35,6 +35,7 @@ from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -45,6 +46,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, @@ -68,11 +70,15 @@ from .interfaces import ( MixtureOfExperts, MultiModalEmbeddings, SupportsEagle3, + SupportsLoRA, SupportsMultiModal, SupportsPP, ) from .llama4 import Llama4ForCausalLM -from .utils import 
AutoWeightsLoader, maybe_prefix +from .utils import ( + AutoWeightsLoader, + maybe_prefix, +) from .vision import run_dp_sharded_vision_model @@ -724,7 +730,12 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): dummy_inputs=Mllama4DummyInputsBuilder, ) class Llama4ForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3 + nn.Module, + SupportsMultiModal, + SupportsPP, + MixtureOfExperts, + SupportsEagle3, + SupportsLoRA, ): merge_by_field_config = True @@ -1067,6 +1078,17 @@ class Llama4ForConditionalGeneration( return updated_params + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.text_config.num_local_experts, + num_redundant_experts=self.num_redundant_experts, + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -1113,3 +1135,13 @@ class Llama4ForConditionalGeneration( ) return updated_params + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="multi_modal_projector.", + tower_model="vision_model.", + ) diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 183f458658aa3..3ef6470070d18 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -232,8 +232,7 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo): # Find factors of max_num_tokens close to its square root # to create a dummy image with a reasonable aspect ratio. 
h_patches = int(math.sqrt(max_num_tokens)) - while max_num_tokens % h_patches != 0: - h_patches -= 1 + max_num_tokens -= max_num_tokens % h_patches w_patches = max_num_tokens // h_patches return ImageSize(height=h_patches * factor, width=w_patches * factor) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 23591480b160e..897dd7ef29f12 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -26,7 +26,6 @@ # limitations under the License. """Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" -import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import lru_cache, partial from typing import Annotated, Any, Literal, TypeAlias @@ -56,12 +55,12 @@ from vllm.distributed import utils as dist_utils from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_and_mul_fn +from vllm.model_executor.layers.conv import Conv3dLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig @@ -110,7 +109,6 @@ from .utils import ( maybe_prefix, ) from .vision import ( - conv3d_to_linear_weight, get_vit_attn_backend, run_dp_sharded_mrope_vision_model, ) @@ -361,23 +359,6 @@ class Qwen2_5_VisionAttention(nn.Module): AttentionBackendEnum.ROCM_AITER_FA, } - def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: - # [s, b, 3 * head * head_dim] - seq_len, bs, _ = qkv.shape - - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] - q, k, v = qkv.chunk(3, dim=2) - - # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] - new_shape = ( - seq_len, - bs, - self.num_attention_heads_per_partition, - 
self.hidden_size_per_attention_head, - ) - q, k, v = (x.view(*new_shape) for x in (q, k, v)) - return q, k, v - def forward( self, x: torch.Tensor, @@ -388,17 +369,32 @@ class Qwen2_5_VisionAttention(nn.Module): ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) + seq_len, batch_size, _ = x.shape - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] - q, k, v = self.split_qkv(x) - batch_size = q.shape[1] + qkv = einops.rearrange( + x, + "s b (three head head_dim) -> b s three head head_dim", + three=3, + head=self.num_attention_heads_per_partition, + ) - q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v)) if rotary_pos_emb is not None: - # [2 * b, s, heads, head_dim] - qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) - q, k = torch.chunk(qk_rotated, 2, dim=0) + qk, v = qkv[:, :, :2], qkv[:, :, 2] + + qk_reshaped = einops.rearrange( + qk, "b s two head head_dim -> (two b) s head head_dim", two=2 + ) + qk_rotated = apply_rotary_pos_emb_vision(qk_reshaped, rotary_pos_emb) + qk_rotated = qk_rotated.view( + 2, + batch_size, + seq_len, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + q, k = qk_rotated.unbind(dim=0) + else: + q, k, v = qkv.unbind(dim=2) if self.is_flash_attn_backend: context_layer = vit_flash_attn_wrapper( @@ -525,15 +521,18 @@ class Qwen2_5_VisionPatchEmbed(nn.Module): self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = ReplicatedLinear( - in_channels * math.prod(kernel_size), + self.proj = Conv3dLayer( + in_channels, hidden_size, + kernel_size=kernel_size, + stride=kernel_size, bias=False, - return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.proj(x) + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + x = self.proj(x).view(L, self.hidden_size) return x @@ -957,9 +956,6 
@@ class Qwen2_5_VisionTransformer(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 13b54bbe17488..5d21e249fc4cc 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -25,7 +25,6 @@ # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from typing import Annotated, Any, Literal, TypeAlias @@ -54,9 +53,9 @@ from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.conv import Conv3dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig @@ -107,7 +106,6 @@ from .utils import ( maybe_prefix, ) from .vision import ( - conv3d_to_linear_weight, get_vit_attn_backend, run_dp_sharded_mrope_vision_model, ) @@ -566,15 +564,18 @@ class Qwen2VisionPatchEmbed(nn.Module): self.embed_dim = embed_dim kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = ReplicatedLinear( - in_channels * math.prod(kernel_size), + self.proj = Conv3dLayer( + in_channels, embed_dim, + kernel_size=kernel_size, + stride=kernel_size, bias=False, - return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.proj(x) + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + x = self.proj(x).view(L, 
self.embed_dim) return x @@ -844,9 +845,6 @@ class Qwen2VisionTransformer(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 5df2372a842cf..40b80ce2387c8 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -22,7 +22,6 @@ # limitations under the License. """Inference-only Qwen3-Omni-Moe model (thinker part).""" -import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from typing import Any @@ -54,9 +53,9 @@ from vllm.config import VllmConfig from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.conv import Conv3dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -102,7 +101,6 @@ from .utils import ( maybe_prefix, ) from .vision import ( - conv3d_to_linear_weight, get_llm_pos_ids_for_vision, get_vit_attn_backend, ) @@ -138,16 +136,18 @@ class Qwen3_VisionPatchEmbed(nn.Module): self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = ReplicatedLinear( - in_channels * math.prod(kernel_size), + self.proj = Conv3dLayer( + in_channels, hidden_size, + kernel_size=kernel_size, + stride=kernel_size, bias=True, - return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = self.proj(x) + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, 
self.patch_size) + x = self.proj(x).view(L, self.hidden_size) return x @@ -566,9 +566,6 @@ class Qwen3Omni_VisionTransformer(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 5f5bde1dd72d3..7f0c9372991d1 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -24,9 +24,8 @@ # limitations under the License. """Inference-only Qwen3VL model compatible with HuggingFace weights.""" -import math -from collections.abc import Callable, Iterable, Mapping, Sequence -from functools import partial +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from functools import lru_cache, partial from itertools import islice from typing import Any @@ -57,9 +56,9 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.conv import Conv3dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -114,7 +113,6 @@ from .utils import ( maybe_prefix, ) from .vision import ( - conv3d_to_linear_weight, get_vit_attn_backend, run_dp_sharded_mrope_vision_model, ) @@ -139,15 +137,18 @@ class Qwen3_VisionPatchEmbed(nn.Module): self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = ReplicatedLinear( - in_channels * math.prod(kernel_size), + self.proj = Conv3dLayer( + in_channels, hidden_size, + kernel_size=kernel_size, + 
stride=kernel_size, bias=True, - return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.proj(x) + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + x = self.proj(x).view(L, self.hidden_size) return x @@ -415,30 +416,41 @@ class Qwen3_VisionTransformer(nn.Module): def device(self) -> torch.device: return self.patch_embed.proj.weight.device - def rot_pos_emb(self, grid_thw: list[list[int]]): - pos_ids = [] - max_grid_size = max(max(h, w) for _, h, w in grid_thw) - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() + @staticmethod + @lru_cache(maxsize=1024) + def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor: + hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w)) + h_div = h // spatial_merge_size + w_div = w // spatial_merge_size + hpos_ids = hpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + hpos_ids = hpos_ids.transpose(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w)) + wpos_ids = wpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + wpos_ids = wpos_ids.transpose(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + + return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1)) + + def rot_pos_emb(self, grid_thw: list[list[int]]): + 
max_grid_size = max(max(h, w) for _, h, w in grid_thw) + pos_ids = [ + self.rot_pos_ids(h, w, self.spatial_merge_size) + if t == 1 + else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1) + for t, h, w in grid_thw + ] pos_ids = torch.cat(pos_ids, dim=0) rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) @@ -579,9 +591,6 @@ class Qwen3_VisionTransformer(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -1140,7 +1149,9 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM): self.config = config self.quant_config = quant_config - self.model = Qwen3LLMModel(vllm_config=vllm_config, prefix=prefix) + self.model = Qwen3LLMModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: @@ -1412,72 +1423,47 @@ class Qwen3VLForConditionalGeneration( ) return mm_input_by_modality + def iter_mm_grid_hw( + self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec] + ) -> Iterator[tuple[int, int, int]]: + video_token_id = self.config.video_token_id + spatial_merge_size = self.config.vision_config.spatial_merge_size + for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): + offset = mm_feature.mm_position.offset + if mm_feature.modality == "image": + t, h, w = mm_feature.data["image_grid_thw"].data.tolist() + assert t == 1, f"Image must have 1 frame, got {t}" + yield offset, h // spatial_merge_size, w // spatial_merge_size + elif mm_feature.modality == "video": + t, h, w = mm_feature.data["video_grid_thw"].data.tolist() + llm_grid_h = h // spatial_merge_size + llm_grid_w = w // spatial_merge_size + for _ in range(t): + offset = input_tokens.index(video_token_id, offset) + 
yield offset, llm_grid_h, llm_grid_w + offset += llm_grid_h * llm_grid_w + else: + raise ValueError(f"Unsupported modality: {mm_feature.modality}") + def get_mrope_input_positions( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - kwargs = MultiModalFeatureSpec.gather_kwargs( - mm_features, - {"image_grid_thw", "video_grid_thw"}, - ) - image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] - video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] - - video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] - - hf_config = self.config - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - vision_start_token_id = hf_config.vision_start_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - - input_tokens_array = np.array(input_tokens) - vision_start_mask = input_tokens_array == vision_start_token_id - vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1] - image_nums = np.count_nonzero(vision_tokens == image_token_id) - video_nums = np.count_nonzero(vision_tokens == video_token_id) - llm_pos_ids_list: list = [] - + llm_pos_ids_list = [] st = 0 - remain_images, remain_videos = image_nums, video_nums - - image_index, video_index = 0, 0 - for _ in range(image_nums + video_nums): - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = image_grid_thw[image_index] - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = video_grid_thw[video_index] - video_index += 1 - remain_videos -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - 
w // spatial_merge_size, - ) - text_len = ed - st - + for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw( + input_tokens, mm_features + ): + text_len = offset - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)) - llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w + grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) + llm_pos_ids_list.append(grid_indices + text_len + st_idx) + st = offset + llm_grid_h * llm_grid_w if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4af8fa01f562b..6e9790de49bfa 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -56,6 +56,7 @@ logger = init_logger(__name__) _TEXT_GENERATION_MODELS = { # [Decoder-only] + "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"), "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 diff --git a/vllm/model_executor/models/transformers/__init__.py b/vllm/model_executor/models/transformers/__init__.py index 365b5eb08893d..93cd8ff507669 100644 --- a/vllm/model_executor/models/transformers/__init__.py +++ b/vllm/model_executor/models/transformers/__init__.py @@ -120,8 +120,8 @@ def __getattr__(name: str): """Handle imports of non-existent classes with a helpful error message.""" if name not in globals(): raise AttributeError( - "The Transformers backend does not currently have a class to handle " - f"the requested model type: {name}. 
Please open an issue at " + "The Transformers modeling backend does not currently have a class to " + f"handle the requested model type: {name}. Please open an issue at " "https://github.com/vllm-project/vllm/issues/new" ) return globals()[name] diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py index 63096e57f8eee..f4ba4758bcc46 100644 --- a/vllm/model_executor/models/transformers/base.py +++ b/vllm/model_executor/models/transformers/base.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Transformers backend base class.""" +"""Transformers modeling backend base class.""" from collections.abc import Iterable from typing import TYPE_CHECKING @@ -118,7 +118,7 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): super().__init__() - logger.info("Using Transformers backend.") + logger.info("Using Transformers modeling backend.") self.config = vllm_config.model_config.hf_config self.text_config = self.config.get_text_config() @@ -147,7 +147,8 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): # Check for unsupported quantization methods. if quant_method_name == "mxfp4": raise NotImplementedError( - "Transformers backend does not support MXFP4 quantization yet." + "Transformers modeling backend does " + "not support MXFP4 quantization yet." ) # Skip loading extra bias for GPTQ models. 
if "gptq" in quant_method_name: @@ -458,6 +459,6 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): required = Version(min_version) if installed < required: raise ImportError( - f"Transformers backend requires transformers>={required} " + f"Transformers modeling backend requires transformers>={required} " f"for {feature}, but got {installed}" ) diff --git a/vllm/model_executor/models/transformers/causal.py b/vllm/model_executor/models/transformers/causal.py index 42fd11117c737..b2865ed0c7ff5 100644 --- a/vllm/model_executor/models/transformers/causal.py +++ b/vllm/model_executor/models/transformers/causal.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Transformers backend mixin for causal language models.""" +"""Transformers modeling backend mixin for causal language models.""" from typing import TYPE_CHECKING diff --git a/vllm/model_executor/models/transformers/legacy.py b/vllm/model_executor/models/transformers/legacy.py index a453870a2687f..aca630be56154 100644 --- a/vllm/model_executor/models/transformers/legacy.py +++ b/vllm/model_executor/models/transformers/legacy.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Transformers backend mixin for legacy models.""" +"""Transformers modeling backend mixin for legacy models.""" from typing import TYPE_CHECKING diff --git a/vllm/model_executor/models/transformers/moe.py b/vllm/model_executor/models/transformers/moe.py index 8e39eb0b9902c..4973014c3d4ed 100644 --- a/vllm/model_executor/models/transformers/moe.py +++ b/vllm/model_executor/models/transformers/moe.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -"""Transformers backend mixin for Mixture of Experts (MoE) models.""" +"""Transformers modeling backend mixin for Mixture of Experts (MoE) models.""" from typing import TYPE_CHECKING, Any @@ -39,7 +39,7 @@ if TYPE_CHECKING: @CustomOp.register("transformers_fused_moe") class TransformersFusedMoE(FusedMoE): - """Custom FusedMoE for the Transformers backend.""" + """Custom FusedMoE for the Transformers modeling backend.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index 9b0463f41fa87..ccf6053719871 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Transformers backend mixin for multi-modal models.""" +"""Transformers modeling backend mixin for multi-modal models.""" from collections.abc import Mapping from typing import TYPE_CHECKING @@ -310,9 +310,9 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): return model_output def get_language_model(self) -> torch.nn.Module: - """Transformers backend multimodal classes do not contain a separate vLLM - language model class. Therefore, in order to return a language model vLLM class, - we use a wrapper to give `self` the same interface as a text model.""" + """Transformers modeling backend multimodal classes do not contain a separate + vLLM language model class. 
Therefore, in order to return a language model vLLM + class, we use a wrapper to give `self` the same interface as a text model.""" # Exclude self and object bases = self.__class__.mro()[1:-1] @@ -385,7 +385,9 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): for k, v in kwargs.items() if k not in {"image_grid_thw", "video_grid_thw"} ): - raise NotImplementedError("Transformers backend only supports images.") + raise NotImplementedError( + "Transformers modeling backend only supports images." + ) image_grid_thw = kwargs.get("image_grid_thw", []) video_grid_thw = kwargs.get("video_grid_thw", []) diff --git a/vllm/model_executor/models/transformers/pooling.py b/vllm/model_executor/models/transformers/pooling.py index 8117bbac013ea..4c2a74bccb6a9 100644 --- a/vllm/model_executor/models/transformers/pooling.py +++ b/vllm/model_executor/models/transformers/pooling.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Transformers backend mixins for pooling models.""" +"""Transformers modeling backend mixins for pooling models.""" from typing import TYPE_CHECKING diff --git a/vllm/model_executor/models/transformers/utils.py b/vllm/model_executor/models/transformers/utils.py index 267a6e06e6bbf..517eb54d53ac6 100644 --- a/vllm/model_executor/models/transformers/utils.py +++ b/vllm/model_executor/models/transformers/utils.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Transformers backend utilities.""" +"""Transformers modeling backend utilities.""" from contextlib import contextmanager from pathlib import Path diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index f14b79f2886c4..ca5af358e2eed 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -18,7 +18,11 @@ from vllm.distributed import ( get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import supports_any_eagle from vllm.multimodal import NestedTensors from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import cdiv @@ -117,9 +121,10 @@ class AutoWeightsLoader: environment variable `VLLM_LOGGING_LEVEL=DEBUG`. """ - # Models trained using early version ColossalAI - # may include these tensors in checkpoint. Skip them. + # Models trained using early version ColossalAI or quantized by + # GPTQModel may include these tensors in checkpoint. Skip them. ROTARY_EMBEDS_UNUSED_WEIGHTS = [ + "rotary_pos_emb.inv_freq", "rotary_emb.inv_freq", "rotary_emb.cos_cached", "rotary_emb.sin_cached", @@ -713,6 +718,30 @@ def maybe_prefix(prefix: str, name: str) -> str: return name if not prefix else f"{prefix}.{name}" +def get_draft_quant_config( + vllm_config: VllmConfig, +) -> QuantizationConfig | None: + """Get quantization config for Draft models. + + Draft models should use their own quantization config instead of the verifier/target + model's config. This helper retrieves the draft model's quantization config. + + Args: + vllm_config: The vLLM configuration object. + + Returns: + The draft model's config if available, None otherwise. 
+ """ + draft_model_config = vllm_config.speculative_config.draft_model_config + draft_load_config = vllm_config.load_config + + return ( + VllmConfig.get_quantization_config(draft_model_config, draft_load_config) + if draft_model_config + else None + ) + + def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int: """ Extract the layer index from the module name. @@ -824,3 +853,25 @@ direct_register_custom_op( fake_impl=sequence_parallel_chunk_impl_fake, tags=(torch.Tag.needs_fixed_stride_order,), ) + + +def process_eagle_weight( + model: nn.Module, + name: str, +) -> None: + """ + Update EAGLE model flags based on loaded weight name. + This should be called during weight loading to detect if a model + has its own lm_head or embed_tokens weight. + Args: + model: The model instance (must support EAGLE) + name: The name of the weight to process + """ + if not supports_any_eagle(model): + return + + # To prevent overriding with target model's layers + if "lm_head" in name: + model.has_own_lm_head = True + if "embed_tokens" in name: + model.has_own_embed_tokens = True diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 0e814e5c86ad4..e5d70eb7bc2fc 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -550,19 +550,3 @@ def get_llm_pos_ids_for_vision( llm_pos_ids_list.append(_llm_pos_ids + start_idx) llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) return llm_pos_ids - - -# Due to a performance regression with Conv3D in PyTorch2.9, we reshape -# Conv3D weights to Linear weights for better performance. -# See: https://github.com/vllm-project/vllm/issues/27406 -# and https://github.com/pytorch/pytorch/issues/166122 -# FIXME(Isotr0py): Revert the PR introduces this workaround -# (https://github.com/vllm-project/vllm/pull/27418), -# once the performance issue is resolved in PyTorch. 
-def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor: - """ - Reshape Conv3D weight to Linear weight. Only work when kernel_size==stride. - """ - out_channels, in_channels, kt, kh, kw = conv3d_weight.shape - linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw) - return linear_weight diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 2fa3f6ebcc114..810f29072a0fe 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -359,8 +359,9 @@ class MultiModalDataParser: ) self.video_needs_metadata = video_needs_metadata - def _is_embeddings( - self, data: object + @classmethod + def is_embeddings( + cls, data: object ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]: if isinstance(data, torch.Tensor): return data.ndim == 3 @@ -420,7 +421,7 @@ class MultiModalDataParser: ): return None - if self._is_embeddings(data): + if self.is_embeddings(data): return AudioEmbeddingItems(data) data_items: list[AudioItem] @@ -458,7 +459,7 @@ class MultiModalDataParser: if self._is_empty(data): return None - if self._is_embeddings(data): + if self.is_embeddings(data): return ImageEmbeddingItems(data) if ( @@ -484,7 +485,7 @@ class MultiModalDataParser: if self._is_empty(data): return None - if self._is_embeddings(data): + if self.is_embeddings(data): return VideoEmbeddingItems(data) data_items: list[VideoItem] diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 8b3b8d4cb44fc..ed655912d3964 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -15,7 +15,6 @@ import torch from vllm import envs from vllm.logger import init_logger -from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import CpuArchEnum, Platform, PlatformEnum @@ -134,6 +133,7 @@ class CpuPlatform(Platform): use_mla: bool, has_sink: bool, use_sparse: bool, + attn_type: str | None = None, ) -> str: from vllm.attention.backends.registry import AttentionBackendEnum @@ -192,7 +192,7 @@ class 
CpuPlatform(Platform): scheduler_config = vllm_config.scheduler_config if ( - scheduler_config.chunked_prefill_enabled + scheduler_config.enable_chunked_prefill or cache_config.enable_prefix_caching ) and cache_config.cache_dtype != "auto": raise RuntimeError( @@ -338,10 +338,9 @@ class CpuPlatform(Platform): "prefill and prefix caching to be disabled." ) vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( - vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS, + vllm_config.model_config.max_model_len, + vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, ) @classmethod diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ebcc290a64cd7..2e4dd8bb808b4 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -298,6 +298,7 @@ class CudaPlatformBase(Platform): has_sink, use_sparse, device_capability, + attn_type, ) -> tuple[ list[tuple["AttentionBackendEnum", int]], dict["AttentionBackendEnum", list[str]], @@ -318,6 +319,7 @@ class CudaPlatformBase(Platform): has_sink, use_sparse, device_capability, + attn_type, ) except ImportError: invalid_reasons_i = ["ImportError"] @@ -339,7 +341,13 @@ class CudaPlatformBase(Platform): use_mla: bool, has_sink: bool, use_sparse: bool, + attn_type: str | None = None, ) -> str: + from vllm.attention import AttentionType + + if attn_type is None: + attn_type = AttentionType.DECODER + device_capability = cls.get_device_capability() assert device_capability is not None @@ -356,6 +364,7 @@ class CudaPlatformBase(Platform): has_sink, use_sparse, device_capability, + attn_type, ) except ImportError: invalid_reasons = ["ImportError"] @@ -379,6 +388,7 @@ class CudaPlatformBase(Platform): has_sink, use_sparse, device_capability, + attn_type, ) reasons_str = ( "{" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 
12c377384270e..0471c20429b1d 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -222,6 +222,7 @@ class Platform: use_mla: bool, has_sink: bool, use_sparse: bool, + attn_type: str | None = None, ) -> str: """Get the attention backend class of a device.""" return "" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d20dc9e6b0674..788f9d69c357a 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -216,6 +216,7 @@ class RocmPlatform(Platform): use_mla, has_sink, use_sparse, + attn_type: str | None = None, ) -> str: from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.registry import AttentionBackendEnum diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 4773fef6829d1..944344a229578 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -9,21 +9,25 @@ from tpu_info import device from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import Platform, PlatformEnum if TYPE_CHECKING: + from typing import TypeAlias + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig from vllm.config.cache import BlockSize from vllm.pooling_params import PoolingParams + from vllm.sampling_params import SamplingParams + + ParamsType: TypeAlias = SamplingParams | PoolingParams else: BlockSize = None VllmConfig = None PoolingParams = None AttentionBackendEnum = None + ParamsType = None logger = init_logger(__name__) @@ -61,6 +65,7 @@ class TpuPlatform(Platform): use_mla: bool, has_sink, use_sparse, + attn_type: str | None = None, ) -> str: from vllm.attention.backends.registry import AttentionBackendEnum @@ -185,10 +190,9 @@ class TpuPlatform(Platform): "prefill and prefix caching to be disabled." 
) vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( - vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS, + vllm_config.model_config.max_model_len, + vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, ) @classmethod @@ -204,10 +208,12 @@ class TpuPlatform(Platform): def validate_request( cls, prompt: PromptType, - params: SamplingParams | PoolingParams, + params: ParamsType, processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" + from vllm.sampling_params import SamplingParams, SamplingType + if ( isinstance(params, SamplingParams) and params.sampling_type == SamplingType.RANDOM_SEED diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index c629325f76a32..65516827a16da 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -9,7 +9,6 @@ import torch import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import DeviceCapability, Platform, PlatformEnum @@ -51,6 +50,7 @@ class XPUPlatform(Platform): use_mla: bool, has_sink: bool, use_sparse, + attn_type: str | None = None, ) -> str: from vllm.v1.attention.backends.utils import set_kv_cache_layout @@ -184,10 +184,9 @@ class XPUPlatform(Platform): "prefill and prefix caching to be disabled." 
) vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( - vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS, + vllm_config.model_config.max_model_len, + vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, ) @classmethod diff --git a/vllm/plugins/lora_resolvers/README.md b/vllm/plugins/lora_resolvers/README.md deleted file mode 100644 index 48f27dddea07e..0000000000000 --- a/vllm/plugins/lora_resolvers/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# LoRA Resolver Plugins - -This directory contains vLLM general plugins for dynamically discovering and loading LoRA adapters -via the LoRAResolver plugin framework. - -Note that `VLLM_ALLOW_RUNTIME_LORA_UPDATING` must be set to true to allow LoRA resolver plugins -to work, and `VLLM_PLUGINS` must be set to include the desired resolver plugins. - -## lora_filesystem_resolver - -This LoRA Resolver is installed with vLLM by default. -To use, set `VLLM_PLUGIN_LORA_CACHE_DIR` to a local directory. When vLLM receives a request -for a LoRA adapter `foobar` it doesn't currently recognize, it will look in that local directory -for a subdirectory `foobar` containing a LoRA adapter. If such an adapter exists, it will -load that adapter, and then service the request as normal. That adapter will then be available -for future requests as normal. 
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 72a8320cc1bf8..5c3dfa8ac9cbc 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -57,6 +57,7 @@ class PoolingParams( ## Internal use only task: PoolingTask | None = None requires_token_ids: bool = False + skip_reading_prefix_cache: bool = None extra_kwargs: dict[str, Any] | None = None output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY @@ -93,6 +94,8 @@ class PoolingParams( # plugin task uses io_processor.parse_request to verify inputs, # skipping PoolingParams verify if self.task == "plugin": + if self.skip_reading_prefix_cache is None: + self.skip_reading_prefix_cache = True return # NOTE: Task validation needs to done against the model instance, @@ -122,6 +125,15 @@ class PoolingParams( if getattr(self, k, None) is None: setattr(self, k, getattr(pooler_config, k)) + if self.skip_reading_prefix_cache is None: + # If prefix caching is enabled, + # the output of all pooling may less than n_prompt_tokens, + # we need to skip reading cache at this request. + if self.task in ["token_embed", "token_classify"]: + self.skip_reading_prefix_cache = True + else: + self.skip_reading_prefix_cache = False + self._verify_step_pooling(pooler_config, valid_parameters) def _verify_step_pooling( diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 4b2a3bc4dbaa6..901d661634527 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -15,6 +15,7 @@ from pydantic.dataclasses import dataclass from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.v1.serial_utils import PydanticMsgspecMixin logger = init_logger(__name__) @@ -122,6 +123,7 @@ class RequestOutputKind(Enum): class SamplingParams( + PydanticMsgspecMixin, msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] # required for @cached_property. 
@@ -252,6 +254,8 @@ class SamplingParams( generated token can complete the sequence.""" _bad_words_token_ids: list[list[int]] | None = None + skip_reading_prefix_cache: bool = None + @staticmethod def from_optional( n: int | None = 1, @@ -412,6 +416,12 @@ class SamplingParams( self.structured_outputs = self.guided_decoding self.guided_decoding = None + if self.skip_reading_prefix_cache is None: + # If prefix caching is enabled, + # the output of prompt logprobs may less than n_prompt_tokens, + # we need to skip reading cache at this request. + self.skip_reading_prefix_cache = self.prompt_logprobs is not None + def _verify_args(self) -> None: if not isinstance(self.n, int): raise ValueError(f"n must be an int, but is of type {type(self.n)}") diff --git a/vllm/sequence.py b/vllm/sequence.py index 6bcc94ad5c625..6d20ca9aac225 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -60,12 +60,17 @@ class IntermediateTensors: tensors: dict[str, torch.Tensor] kv_connector_output: KVConnectorOutput | None - def __init__(self, tensors): + def __init__( + self, + tensors: dict[str, torch.Tensor], + kv_connector_output: KVConnectorOutput | None = None, + ) -> None: # manually define this function, so that # Dynamo knows `IntermediateTensors()` comes from this file. # Otherwise, dataclass will generate this function by evaluating # a string, and we will lose the information about the source file. 
self.tensors = tensors + self.kv_connector_output = kv_connector_output def __getitem__(self, key: str | slice): if isinstance(key, str): diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 14cae2b168e19..49250e071eab2 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -77,6 +77,7 @@ class LazyConfigDict(dict): _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( + afmoe="AfmoeConfig", chatglm="ChatGLMConfig", deepseek_vl_v2="DeepseekVLV2Config", deepseek_v32=DeepseekV3Config, @@ -472,8 +473,7 @@ def is_interleaved(config: PretrainedConfig) -> bool: """ text_config = config.get_text_config() if layer_types := getattr(text_config, "layer_types", None): - interleaved_types = {"full_attention", "sliding_attention"} - return interleaved_types.issubset(layer_types) + return len(set(layer_types)) > 1 return False diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index ac612b255143c..dcae05a15fec3 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -7,6 +7,7 @@ Model configs may be defined in this directory for the following reasons: - There is a need to override the existing config to support vLLM. 
""" +from vllm.transformers_utils.configs.afmoe import AfmoeConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig @@ -40,6 +41,7 @@ from vllm.transformers_utils.configs.step3_vl import ( from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ + "AfmoeConfig", "ChatGLMConfig", "DeepseekVLV2Config", "DotsOCRConfig", diff --git a/vllm/transformers_utils/configs/afmoe.py b/vllm/transformers_utils/configs/afmoe.py new file mode 100644 index 0000000000000..9b634fd037a33 --- /dev/null +++ b/vllm/transformers_utils/configs/afmoe.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers.configuration_utils import PretrainedConfig + + +class AfmoeConfig(PretrainedConfig): + model_type = "afmoe" + + def __init__( + self, + vocab_size: int = 200_192, + hidden_size: int = 2048, + intermediate_size: int = 6144, + moe_intermediate_size: int = 1408, + num_hidden_layers: int = 32, + num_dense_layers: int = 1, + num_attention_heads: int = 16, + num_key_value_heads: int | None = None, + head_dim: int = 128, + hidden_act: str = "silu", + max_position_embeddings: int = 131072, + initializer_range: float = 0.02, + rms_norm_eps: float = 1e-5, + use_cache: bool = True, + tie_word_embeddings: bool = False, + rope_theta: float = 10000.0, + rope_scaling: dict | None = None, + num_experts: int = 64, + num_experts_per_tok: int = 6, + num_shared_experts: int = 2, + num_expert_groups: int = 1, + num_limited_groups: int = 1, + score_func: str = "sigmoid", + route_norm: bool = True, + route_scale: float = 1.0, + global_attn_every_n_layers: int = 4, + sliding_window: int = 2048, + layer_types: list[str] | None = None, + attention_dropout: float = 0.0, + mup_enabled: bool = False, + n_group: int = 1, + topk_group: int = 1, + 
**kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_dense_layers = num_dense_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.moe_intermediate_size = moe_intermediate_size + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.num_shared_experts = num_shared_experts + self.num_expert_groups = num_expert_groups + self.num_limited_groups = num_limited_groups + self.score_func = score_func + self.route_norm = route_norm + self.route_scale = route_scale + + self.global_attn_every_n_layers = global_attn_every_n_layers + self.sliding_window = sliding_window + self.layer_types = layer_types + self.attention_dropout = attention_dropout + + self.mup_enabled = mup_enabled + self.n_group = n_group + self.topk_group = topk_group + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = ["AfmoeConfig"] diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 9b0045279a67e..3ef44e7703204 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -1,11 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import inspect import uuid import warnings -from functools import wraps -from typing import Any, TypeVar +from typing import Any import torch @@ -41,12 +39,6 @@ def __dir__() -> list[str]: logger = init_logger(__name__) -# This value is chosen to have a balance between ITL and TTFT. Note it is -# not optimized for throughput. 
-DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 -POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 -MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 - # Constants related to forcing the attention backend selection # String name of register which may be set in order to @@ -62,56 +54,10 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" -T = TypeVar("T") - - def random_uuid() -> str: return str(uuid.uuid4().hex) -def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: - """ - A replacement for `abc.ABC`. - When we use `abc.ABC`, subclasses will fail to instantiate - if they do not implement all abstract methods. - Here, we only require `raise NotImplementedError` in the - base class, and log a warning if the method is not implemented - in the subclass. - """ - - original_init = cls.__init__ - - def find_unimplemented_methods(self: object): - unimplemented_methods = [] - for attr_name in dir(self): - # bypass inner method - if attr_name.startswith("_"): - continue - - try: - attr = getattr(self, attr_name) - # get the func of callable method - if callable(attr): - attr_func = attr.__func__ - except AttributeError: - continue - src = inspect.getsource(attr_func) - if "NotImplementedError" in src: - unimplemented_methods.append(attr_name) - if unimplemented_methods: - method_names = ",".join(unimplemented_methods) - msg = f"Methods {method_names} not implemented in {self}" - logger.debug(msg) - - @wraps(original_init) - def wrapped_init(self, *args, **kwargs) -> None: - original_init(self, *args, **kwargs) - find_unimplemented_methods(self) - - type.__setattr__(cls, "__init__", wrapped_init) - return cls - - def length_from_prompt_token_ids_or_embeds( prompt_token_ids: list[int] | None, prompt_embeds: torch.Tensor | None, diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 4c15baf7a8f93..b5ab37534dd78 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -9,6 +9,7 @@ import functools import importlib import os from 
collections.abc import Callable +from enum import Enum from typing import Any, NoReturn import torch @@ -20,6 +21,28 @@ from vllm.utils.import_utils import has_deep_gemm from vllm.utils.math_utils import cdiv +class DeepGemmQuantScaleFMT(Enum): + # Float32 scales in Float32 tensor + FLOAT32 = 0 + # Compute float32 scales and ceil the scales to UE8M0. + # Keep the scales in Float32 tensor. + FLOAT32_CEIL_UE8M0 = 1 + # Compute float32 scales and ceil the scales to UE8M0. + # Pack the scales into a int32 tensor where each int32 + # element contains 4 scale values. + UE8M0 = 2 + + @staticmethod + def from_oracle() -> "DeepGemmQuantScaleFMT": + if not is_deep_gemm_e8m0_used(): + return DeepGemmQuantScaleFMT.FLOAT32 + return ( + DeepGemmQuantScaleFMT.UE8M0 + if current_platform.is_device_capability(100) + else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0 + ) + + @functools.cache def is_deep_gemm_supported() -> bool: """Return `True` if DeepGEMM is supported on the current platform. diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 62af39513d651..1209d64901bf5 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -259,6 +259,7 @@ def use_trtllm_attention( num_kv_heads: int, num_tokens: int, max_seq_len: int, + dcp_world_size: int, kv_cache_dtype: str, q_dtype: torch.dtype, is_prefill: bool, @@ -272,6 +273,14 @@ def use_trtllm_attention( if force_use_trtllm is not None and not force_use_trtllm: return False + # Decode context parallel is not supported + if dcp_world_size > 1: + logger.warning_once( + "Trtllm does not support returning LSE and as a result " + "does not support DCP, reverting to FlashInfer" + ) + return False + # The platform is not supported if not supports_trtllm_attention(): if force_use_trtllm: @@ -310,14 +319,12 @@ def use_trtllm_attention( # Environment variable not set - use auto-detection if is_prefill: # Prefill auto-detection - use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto" + use_trtllm = 
kv_cache_dtype == "auto" if use_trtllm: logger.warning_once("Using TRTLLM prefill attention (auto-detected).") else: # Decode auto-detection - use_trtllm = ( - num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto" - ) + use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto" if use_trtllm: logger.warning_once("Using TRTLLM decode attention (auto-detected).") return use_trtllm diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 674398e19c4ce..f1254352c0585 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -48,6 +48,17 @@ class CPUAttentionBackend(AttentionBackend): def get_name() -> str: return "CPU_ATTN" + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + """CPU attention supports decoder and encoder-only attention.""" + from vllm.attention import AttentionType + + return attn_type in ( + AttentionType.DECODER, + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ) + @staticmethod def get_impl_cls() -> type["CPUAttentionBackendImpl"]: return CPUAttentionBackendImpl diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d9bd52d8f9800..a5d4435000d4d 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend): def get_name() -> str: return "FLASH_ATTN" + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + """FlashAttention supports all attention types.""" + from vllm.attention import AttentionType + + return attn_type in ( + AttentionType.DECODER, + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + AttentionType.ENCODER_DECODER, + ) + @staticmethod def get_impl_cls() -> type["FlashAttentionImpl"]: return FlashAttentionImpl @@ -118,6 +130,12 @@ class FlashAttentionBackend(AttentionBackend): return flash_attn_supports_fp8() return kv_cache_dtype in 
["auto"] + @classmethod + def supports_sink(cls) -> bool: + if not is_flash_attn_varlen_func_available(): + return False + return flash_attn_supports_sinks() + @classmethod def supports_compute_capability(cls, capability: DeviceCapability) -> bool: return capability >= DeviceCapability(8, 0) @@ -686,6 +704,7 @@ class FlashAttentionImpl(AttentionImpl): logits_soft_cap=self.logits_soft_cap, block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, + max_num_splits=attn_metadata.max_num_splits, fa_version=self.vllm_flash_attn_version, prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, suffix_scheduler_metadata=attn_metadata.scheduler_metadata, @@ -932,6 +951,7 @@ def cascade_attention( logits_soft_cap: float, block_table: torch.Tensor, common_prefix_len: int, + max_num_splits: int, fa_version: int, prefix_scheduler_metadata: torch.Tensor | None = None, suffix_scheduler_metadata: torch.Tensor | None = None, @@ -976,7 +996,7 @@ def cascade_attention( # s_aux is incorporated into prefix_lse inside the GPU kernel, # enabling its effect during the final attention merge. s_aux=s_aux, - num_splits=1 if vllm_is_batch_invariant() else 0, + num_splits=1 if vllm_is_batch_invariant() else max_num_splits, ) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) @@ -1001,7 +1021,7 @@ def cascade_attention( q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, - num_splits=1 if vllm_is_batch_invariant() else 0, + num_splits=1 if vllm_is_batch_invariant() else max_num_splits, ) # Merge prefix and suffix outputs, and store the result in output. 
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 0b650e2e0d33b..4da1637d96eb6 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -10,6 +10,7 @@ import torch from flashinfer import ( BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, MultiLevelCascadeAttentionWrapper, ) from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache @@ -24,8 +25,11 @@ from vllm.attention.backends.abstract import ( AttentionType, MultipleOf, ) +from vllm.attention.ops.common import cp_lse_ag_out_rs +from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.config import CUDAGraphMode, VllmConfig from vllm.config.cache import CacheDType +from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -50,6 +54,7 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, KVCacheLayoutType, + get_dcp_local_seq_lens, get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, @@ -160,6 +165,113 @@ def trtllm_prefill_attn_kvfp8_dequant( return mock_kv_cache, mock_block_table +class BatchDCPPrefillWrapper: + def __init__( + self, + workspace_buffer: torch.Tensor | None = None, + ): + self._context = BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, get_kv_cache_layout() + ) + self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, get_kv_cache_layout() + ) + + def plan( + self, + qo_indptr_cpu: torch.Tensor, + paged_kv_indptr_cpu: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len_cpu: torch.Tensor, + prefill_start: int, + page_size: int, + num_qo_heads: int, + dcp_world_size: int, + num_kv_heads: int, + head_dim: int, + sm_scale: float, + window_left: int, + 
logits_soft_cap: float | None, + q_data_type: torch.dtype, + kv_cache_dtype: torch.dtype, + prefill_fixed_split_size: int, + disable_split_kv: bool, + ): + """Plan the prefill operation with given parameters.""" + self._context.plan( + qo_indptr_cpu, + paged_kv_indptr_cpu, + paged_kv_indices, + paged_kv_last_page_len_cpu[prefill_start:], + num_qo_heads * dcp_world_size, + num_kv_heads, + head_dim, + page_size, + causal=False, # This is context run + sm_scale=sm_scale, + window_left=window_left, + logits_soft_cap=logits_soft_cap, + q_data_type=q_data_type, + kv_data_type=kv_cache_dtype, + fixed_split_size=prefill_fixed_split_size, + disable_split_kv=disable_split_kv, + ) + self._new_tokens.plan( + qo_indptr=qo_indptr_cpu, + kv_indptr=qo_indptr_cpu, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + causal=True, # This is newtokens run + sm_scale=sm_scale, + window_left=window_left, + logits_soft_cap=logits_soft_cap, + q_data_type=q_data_type, + ) + + def run( + self, + layer: torch.nn.Module, + prefill_query: torch.Tensor, + kv_cache_permute: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + ): + prefill_query_across_dcp = get_dcp_group().all_gather( + prefill_query.contiguous(), dim=1 + ) + output_context_tmp, lse_context_tmp = self._context.run( + prefill_query_across_dcp, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + return_lse=True, + ) + output_context, lse_context = cp_lse_ag_out_rs( + output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True + ) + lse_context = lse_context.transpose(0, 1).contiguous() + + output_query, lse_query = self._new_tokens.run( + prefill_query, + key, + value, + return_lse=True, + ) + lse_query = lse_query.transpose(0, 1).contiguous() + + merge_attn_states( + out, + output_context, + lse_context, + output_query, + lse_query, + ) + return out + + class FlashInferBackend(AttentionBackend): 
accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] @@ -281,7 +393,9 @@ class FlashInferMetadata: # For cascade attention (CPU for planning). use_cascade: bool - prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None + prefill_wrapper: ( + BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None + ) = None decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None @@ -303,7 +417,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.cache_config = vllm_config.cache_config self.model_config = vllm_config.model_config self._workspace_buffer = None - self._prefill_wrapper = None # Wrapper for prefill/append + self._prefill_wrapper: ( + BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None + ) = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) if vllm_is_batch_invariant(): @@ -341,9 +457,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.compilation_config.max_cudagraph_capture_size, ) - self.num_qo_heads = self.model_config.get_num_attention_heads( - self.vllm_config.parallel_config + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + self.dcp_kv_cache_interleave_size = ( + vllm_config.parallel_config.dcp_kv_cache_interleave_size + ) + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.dcp_kv_cache_interleave_size = 1 + + self.num_qo_heads = ( + self.model_config.get_num_attention_heads(self.vllm_config.parallel_config) + * self.dcp_world_size ) + self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.head_dim = self.kv_cache_spec.head_size self.page_size = self.kv_cache_spec.block_size @@ -455,11 +585,19 @@ class 
FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) return self._workspace_buffer - def _get_prefill_wrapper(self): + def _get_prefill_wrapper( + self, + ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper: if self._prefill_wrapper is None: - self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), get_kv_cache_layout() - ) + if self.dcp_world_size > 1: + self._prefill_wrapper = BatchDCPPrefillWrapper( + workspace_buffer=self._get_workspace_buffer(), + ) + else: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), get_kv_cache_layout() + ) + assert self._prefill_wrapper is not None return self._prefill_wrapper def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): @@ -526,9 +664,29 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): max_seq_len = common_attn_metadata.max_seq_len seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu - seq_lens_np = seq_lens_cpu.numpy() block_table_tensor = common_attn_metadata.block_table_tensor + qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu + if self.dcp_world_size > 1: + if num_prefills > 0: + qo_indptr_prefill_cpu = ( + qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes] + ) + query_lens_prefill_cpu = ( + qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1] + ) + seq_lens_cpu[num_decodes:] = ( + seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu + ) + + seq_lens_cpu = get_dcp_local_seq_lens( + seq_lens_cpu, + self.dcp_world_size, + self.dcp_rank, + self.dcp_kv_cache_interleave_size, + ) + + seq_lens_np = seq_lens_cpu.numpy() num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size use_cascade = common_prefix_len > 0 @@ -589,7 +747,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # write self.paged_kv_last_page_len_cpu inplace paged_kv_last_page_len_np = seq_lens_np % 
page_size self.paged_kv_last_page_len_np[:num_reqs] = np.where( - paged_kv_last_page_len_np == 0, + (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0), page_size, paged_kv_last_page_len_np, ) @@ -600,13 +758,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.num_kv_heads, num_prefill_tokens, max_seq_len, + self.dcp_world_size, self.cache_dtype, self.q_data_type, is_prefill=True, has_sinks=self.has_sinks, has_spec=uses_spec_reorder, ) - decode_use_trtllm = self.use_trtllm_decode_attention + decode_use_trtllm = ( + self.use_trtllm_decode_attention and self.dcp_world_size <= 1 + ) if not (prefill_use_trtllm and decode_use_trtllm): if self.has_sinks: @@ -651,7 +812,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): use_cascade=use_cascade, ) - qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] @@ -703,24 +863,52 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) if not attn_metadata.prefill_use_trtllm: - attn_metadata.prefill_wrapper.plan( - qo_indptr_cpu, - paged_kv_indptr_cpu, - paged_kv_indices, - paged_kv_last_page_len_cpu[prefill_start:], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=self.sm_scale, - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.kv_cache_dtype, - fixed_split_size=self.prefill_fixed_split_size, - disable_split_kv=self.disable_split_kv, - ) + if self.dcp_world_size > 1: + assert isinstance( + attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper + ) + attn_metadata.prefill_wrapper.plan( + qo_indptr_cpu=qo_indptr_cpu, + paged_kv_indptr_cpu=paged_kv_indptr_cpu, + paged_kv_indices=paged_kv_indices, + 
paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu, + prefill_start=prefill_start, + page_size=self.page_size, + num_qo_heads=self.num_qo_heads, + dcp_world_size=self.dcp_world_size, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_cache_dtype=self.kv_cache_dtype, + prefill_fixed_split_size=self.prefill_fixed_split_size, + disable_split_kv=self.disable_split_kv, + ) + else: + assert isinstance( + attn_metadata.prefill_wrapper, + BatchPrefillWithPagedKVCacheWrapper, + ) + attn_metadata.prefill_wrapper.plan( + qo_indptr_cpu, + paged_kv_indptr_cpu, + paged_kv_indices, + paged_kv_last_page_len_cpu[prefill_start:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + fixed_split_size=self.prefill_fixed_split_size, + disable_split_kv=self.disable_split_kv, + ) else: attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( self.device, non_blocking=True @@ -770,7 +958,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): paged_kv_indices, self.paged_kv_last_page_len_cpu[:num_input_tokens], seq_lens_cpu[:num_input_tokens], - self.num_qo_heads, + self.num_qo_heads * self.dcp_world_size, self.num_kv_heads, self.head_dim, self.page_size, @@ -797,6 +985,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferImpl(AttentionImpl): + can_return_lse_for_decode: bool = True + def __init__( self, num_heads: int, @@ -989,6 +1179,8 @@ class FlashInferImpl(AttentionImpl): # Inputs and outputs may be padded for CUDA graphs query = query[:num_actual_tokens] + key = key[:num_actual_tokens] + value = value[:num_actual_tokens] output_padded = output output = 
output[:num_actual_tokens] @@ -1015,17 +1207,46 @@ class FlashInferImpl(AttentionImpl): assert prefill_wrapper is not None if not attn_metadata.prefill_use_trtllm: - assert prefill_wrapper._causal - assert prefill_wrapper._window_left == self.window_left - assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) - assert prefill_wrapper._sm_scale == self.scale - prefill_wrapper.run( - prefill_query, - kv_cache_permute, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output[num_decode_tokens:], - ) + if self.dcp_world_size > 1: + assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper) + assert prefill_wrapper._context._window_left == self.window_left + assert prefill_wrapper._context._logits_soft_cap == ( + self.logits_soft_cap or 0.0 + ) + assert prefill_wrapper._context._sm_scale == self.scale + assert not prefill_wrapper._context._causal + assert prefill_wrapper._new_tokens._window_left == self.window_left + assert prefill_wrapper._new_tokens._logits_soft_cap == ( + self.logits_soft_cap or 0.0 + ) + assert prefill_wrapper._new_tokens._sm_scale == self.scale + assert prefill_wrapper._new_tokens._causal + + prefill_wrapper.run( + layer, + prefill_query, + kv_cache_permute, + key[num_decode_tokens:], + value[num_decode_tokens:], + out=output[num_decode_tokens:], + ) + else: + assert isinstance( + prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper + ) + assert prefill_wrapper._window_left == self.window_left + assert prefill_wrapper._logits_soft_cap == ( + self.logits_soft_cap or 0.0 + ) + assert prefill_wrapper._sm_scale == self.scale + assert prefill_wrapper._causal + prefill_wrapper.run( + prefill_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[num_decode_tokens:], + ) else: # prefill_query may be non-contiguous prefill_query = prefill_query.contiguous() @@ -1101,13 +1322,37 @@ class FlashInferImpl(AttentionImpl): assert decode_wrapper._window_left == 
self.window_left assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale - decode_wrapper.run( - decode_query, - kv_cache_permute, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output[:num_decode_tokens], - ) + + if self.dcp_world_size > 1: + decode_query = get_dcp_group().all_gather( + decode_query.contiguous(), dim=-2 + ) + output_tmp = torch.empty_like(decode_query) + lse = torch.empty( + (decode_query.size(0), decode_query.size(1)), + dtype=torch.float32, + device=decode_query.device, + ) + decode_wrapper.run( + decode_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output_tmp, + lse=lse, + return_lse=True, + ) + output[:num_decode_tokens] = cp_lse_ag_out_rs( + output_tmp, lse, get_dcp_group() + ) + else: + decode_wrapper.run( + decode_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[:num_decode_tokens], + ) else: # decode_query may be non-contiguous decode_query = decode_query.contiguous() diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index e53cd0d8af4f2..7768827d26dc3 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend): def get_name() -> str: return "FLEX_ATTENTION" + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + """FlexAttention supports both decoder and encoder-only attention.""" + from vllm.attention import AttentionType + + return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY) + @staticmethod def get_impl_cls() -> type["FlexAttentionImpl"]: return FlexAttentionImpl diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 467c01cd9d069..2ccdd1f143ce8 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ 
b/vllm/v1/attention/backends/mla/common.py @@ -337,6 +337,7 @@ class MLACommonPrefillMetadata: local_context_lens_allranks: list[list[int]] | None = None padded_local_cu_seq_lens: torch.Tensor | None = None cu_seq_lens_lst: list[list[int]] | None = None + chunk_size: int | None = None block_table: torch.Tensor query_start_loc: torch.Tensor @@ -902,6 +903,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): device, non_blocking=True ), cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), + chunk_size=padded_local_max_context_chunk_across_ranks, ) else: chunked_context_metadata = chunked_context_metadata_cls( @@ -986,6 +988,8 @@ def reorg_kvcache( local_context_lens_allranks: list[list[int]], sum_seq_len: int, max_seq_len: int, + chunk_size: int, + chunk_idx: int, toks: int, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -1001,6 +1005,9 @@ def reorg_kvcache( local_context_lens_allranks: local context lengths on each CP rank. sum_seq_len: the sum of cp_chunk_seq_lens_lst. max_seq_len: the max value of cp_chunk_seq_lens_lst. + chunk_size: the local padded max context chunk from + chunked_context_metadata building. + chunk_idx: chunk idx of chunked_prefill. toks: the number of tokens for local gather cache. """ kv_c_segments = [] @@ -1012,20 +1019,31 @@ def reorg_kvcache( ): cur_seq_len = 0 for rank, local_context_len in enumerate(local_context_lens): - if local_context_len != 0: + # Note(qcs): We split the context into multiple chunks, + # depending on the size of the workspace. + # local_context in dcp0: |-----------------| + # local_context in dcp1: |--------------| + # n*padded_local_chunk: |-----|-----|-----| + # local_chunk_len in dcp1: |-----|-----|--| + # so we need update the last chunk length in dcp1. 
+ local_chunk_len = min( + max(0, local_context_len - chunk_idx * chunk_size), + padded_local_chunk_seq_len, + ) + if local_chunk_len != 0: kv_c_segment = allgatered_kv_c_normed[ rank * toks + src_token_idx : rank * toks + src_token_idx - + local_context_len + + local_chunk_len ] k_pe_segment = allgatered_k_pe[ rank * toks + src_token_idx : rank * toks + src_token_idx - + local_context_len + + local_chunk_len ] kv_c_segments.append(kv_c_segment) k_pe_segments.append(k_pe_segment) - cur_seq_len += local_context_len + cur_seq_len += local_chunk_len max_seq_len_check = max(max_seq_len_check, cur_seq_len) src_token_idx += padded_local_chunk_seq_len reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) @@ -1676,6 +1694,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): assert prefill_metadata.chunked_context.local_context_lens_allranks is not None assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None + assert prefill_metadata.chunked_context.chunk_size is not None output = None iters = len(prefill_metadata.chunked_context.seq_tot) @@ -1725,6 +1744,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks, sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], + chunk_size=prefill_metadata.chunked_context.chunk_size, + chunk_idx=i, toks=toks, ) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 5fe9c69d35007..bb8d914d15719 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -40,14 +40,14 @@ logger = init_logger(__name__) """ NOTE: FlashMLA Sparse uses an fp8 cache with the following format -In the "FP8 with scale" format, each token's KV cache is 656 Bytes, +In the 
"FP8 with scale" format, each token's KV cache is 656 Bytes, structured as: -- **First 512 bytes:** The "quantized NoPE" part, containing 512 +- **First 512 bytes:** The "quantized NoPE" part, containing 512 `float8_e4m3` values. -- **Next 16 bytes:** Scale factors, containing 4 `float32` values. - The first `float32` is the scale for the first 128 `float8_e4m3` values, +- **Next 16 bytes:** Scale factors, containing 4 `float32` values. + The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on. -- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This +- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This part is not quantized for accuracy. """ diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index c7f925817a6a8..ea611848b0e81 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -18,6 +18,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv +from vllm.utils.platform_utils import get_cu_count from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -38,7 +39,7 @@ if current_platform.is_rocm(): return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) def num_programs(total_tokens): - return min(total_tokens, current_platform.get_cu_count()) + return min(total_tokens, get_cu_count()) @triton.jit def cp_mha_gather_cache_kernel( @@ -728,7 +729,7 @@ class AiterFlashAttentionImpl(AttentionImpl): cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc, max_seqlen_q=attn_metadata.prefill_metadata.max_query_len, max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len, - min_seqlen_q=attn_metadata.prefill_metadata.min_query_len, + min_seqlen_q=1, dropout_p=0.0, softmax_scale=self.scale, causal=True, @@ -758,7 
+759,7 @@ class AiterFlashAttentionImpl(AttentionImpl): cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc, max_seqlen_q=attn_metadata.extend_metadata.max_query_len, max_seqlen_k=attn_metadata.extend_metadata.max_seq_len, - min_seqlen_q=attn_metadata.extend_metadata.min_query_len, + min_seqlen_q=1, block_table=attn_metadata.block_table[ num_decodes : num_decodes + num_extends ], diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index fd37a665cf05f..578153cda7863 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -965,12 +965,6 @@ def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tens return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3]) -KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ - ("logits_indices_padded", torch.Tensor | None, None), - ("num_logits_indices", int, 0), -] - - def subclass_attention_metadata( name_prefix: str, metadata_cls: Any, @@ -986,8 +980,8 @@ def subclass_attention_metadata( @runtime_checkable class KVSharingFastPrefillMetadata(Protocol): - logits_indices_padded: torch.Tensor - num_logits_indices: int + logits_indices_padded: torch.Tensor | None = None + num_logits_indices: int | None = None def create_fast_prefill_custom_backend( @@ -1019,11 +1013,6 @@ def create_fast_prefill_custom_backend( for _field in fields(metadata.__class__): setattr(self, _field.name, getattr(metadata, _field.name)) - # Set additional fields that will be used in model code - assert ( - common_attn_metadata.logits_indices_padded is not None - and common_attn_metadata.num_logits_indices is not None - ) self.logits_indices_padded = ( common_attn_metadata.logits_indices_padded ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 63a1ff06e4049..7f405fc248ac2 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -185,12 +185,11 @@ class KVCacheManager: - A list of 
blocks that are computed for the request. - The number of computed tokens. """ - # Prefix caching is disabled or - # When the request requires prompt logprobs, we skip prefix caching. - if not self.enable_caching or ( - request.sampling_params is not None - and request.sampling_params.prompt_logprobs is not None - ): + # We skip finding the prefix cache hit when prefix caching is + # disabled or the request is marked as skipping kv cache read + # (which happens when the request requires prompt logprobs + # or calls a pooling model with all pooling). + if not self.enable_caching or request.skip_reading_prefix_cache: return self.empty_kv_cache_blocks, 0 # NOTE: When all tokens hit the cache, we must recompute the last token diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 0ad994c360b01..3214f65a09728 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -16,18 +16,25 @@ class AsyncScheduler(Scheduler): ) -> None: super()._update_after_schedule(scheduler_output) pending_structured_output_tokens = False + spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens for req_id in scheduler_output.num_scheduled_tokens: request = self.requests[req_id] pending_structured_output_tokens |= ( request.use_structured_output and request.num_output_placeholders > 0 ) + cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ())) if ( request.num_computed_tokens - == request.num_tokens + request.num_output_placeholders + == request.num_tokens + + request.num_output_placeholders + + cur_num_spec_tokens ): - # The request will generate a new token in this scheduling step. - # TODO(woosuk): Support speculative decoding. - request.num_output_placeholders += 1 + # The request will generate a new token plus num_spec_tokens + # in this scheduling step. + request.num_output_placeholders += 1 + cur_num_spec_tokens + # Add placeholders for the new tokens in spec_token_ids. 
+ # Wwe will update the actual spec token ids in the worker process. + request.spec_token_ids = [-1] * self.num_spec_tokens scheduler_output.pending_structured_output_tokens = ( pending_structured_output_tokens diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4fcc7955df195..4323141c435b7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -83,7 +83,7 @@ class Scheduler(SchedulerInterface): # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens - self.max_model_len = self.scheduler_config.max_model_len + self.max_model_len = vllm_config.model_config.max_model_len self.enable_kv_cache_events = ( self.kv_events_config is not None and self.kv_events_config.enable_kv_cache_events @@ -348,7 +348,10 @@ class Scheduler(SchedulerInterface): # Speculative decode related. if request.spec_token_ids: num_scheduled_spec_tokens = ( - num_new_tokens + request.num_computed_tokens - request.num_tokens + num_new_tokens + + request.num_computed_tokens + - request.num_tokens + - request.num_output_placeholders ) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. @@ -469,9 +472,9 @@ class Scheduler(SchedulerInterface): num_computed_tokens = ( num_new_local_computed_tokens + num_external_computed_tokens ) - # KVTransfer: WAITING reqs have num_computed_tokens > 0 - # after async KV recvs are completed. else: + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens @@ -480,12 +483,12 @@ class Scheduler(SchedulerInterface): external_load_encoder_input = [] new_encoder_compute_budget = encoder_compute_budget - # KVTransfer: loading remote KV, do not allocate for new work. 
if load_kv_async: + # KVTransfer: loading remote KV, do not allocate for new work. assert num_external_computed_tokens > 0 num_new_tokens = 0 - # Number of tokens to be scheduled. else: + # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. @@ -497,7 +500,7 @@ class Scheduler(SchedulerInterface): # chunked prefill has to be enabled explicitly to allow # pooling requests to be chunked if ( - not self.scheduler_config.chunked_prefill_enabled + not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget ): self.waiting.pop_request() @@ -778,9 +781,7 @@ class Scheduler(SchedulerInterface): assert not scheduled_in_prev_step resumed_req_ids.add(req_id) if not scheduled_in_prev_step: - all_token_ids[req_id] = req.all_token_ids[ - : req.num_computed_tokens + num_tokens - ] + all_token_ids[req_id] = req.all_token_ids.copy() new_block_ids.append( req_to_new_blocks[req_id].get_block_ids(allow_none=True) ) @@ -1010,8 +1011,8 @@ class Scheduler(SchedulerInterface): continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = ( - sampled_token_ids[req_index] if sampled_token_ids else [] + generated_token_ids: list[int] = ( + sampled_token_ids[req_index].tolist() if sampled_token_ids else [] ) scheduled_spec_token_ids = ( @@ -1026,7 +1027,12 @@ class Scheduler(SchedulerInterface): # tokens and rejections. If some tokens are rejected, # num_computed_tokens is decreased by the number of rejected # tokens. - request.num_computed_tokens -= num_rejected + if request.num_computed_tokens > 0: + request.num_computed_tokens -= num_rejected + # If async scheduling, num_output_placeholders also includes + # the scheduled spec tokens count and so is similarly adjusted. 
+ if request.num_output_placeholders > 0: + request.num_output_placeholders -= num_rejected spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, num_draft_tokens=num_draft_tokens, diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 058a4bcaecb58..3f621d77c0241 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -15,6 +15,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import LogprobsLists, LogprobsTensors +from vllm.v1.serial_utils import UtilityResult # These are possible values of RequestOutput.finish_reason, # so form part of the external API. @@ -131,13 +132,6 @@ class EngineCoreOutput( return self.finish_reason is not None -class UtilityResult: - """Wrapper for special handling when serializing/deserializing.""" - - def __init__(self, r: Any = None): - self.result = r - - class UtilityOutput( msgspec.Struct, array_like=True, # type: ignore[call-arg] diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index aee21fb3fffe7..c160c7cbcab4a 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -14,7 +14,7 @@ import torch import vllm.envs as envs from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.protocol import Device, EngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs import PromptType from vllm.logger import init_logger @@ -120,8 +120,9 @@ class AsyncLLM(EngineClient): ) # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). 
+ stream_interval = self.vllm_config.scheduler_config.stream_interval self.output_processor = OutputProcessor( - self.tokenizer, log_stats=self.log_stats + self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval ) endpoint = self.observability_config.otlp_traces_endpoint if endpoint is not None: @@ -671,9 +672,7 @@ class AsyncLLM(EngineClient): self.processor.clear_mm_cache() await self.engine_core.reset_mm_cache_async() - async def reset_prefix_cache(self, device: Device | None = None) -> None: - if device == Device.CPU: - raise ValueError("Not supported on CPU.") + async def reset_prefix_cache(self) -> None: await self.engine_core.reset_prefix_cache_async() async def sleep(self, level: int = 1) -> None: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ffb5232e770d1..97286c6e2e5e4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -63,7 +63,6 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.structured_output import StructuredOutputManager -from vllm.v1.utils import record_function_or_nullcontext from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -124,7 +123,7 @@ class EngineCore: # Encoder models without KV cache don't support # chunked prefill. But do SSM models? 
logger.info("Disabling chunked prefill for model without KVCache") - vllm_config.scheduler_config.chunked_prefill_enabled = False + vllm_config.scheduler_config.enable_chunked_prefill = False scheduler_block_size = ( vllm_config.cache_config.block_size @@ -181,11 +180,13 @@ class EngineCore: logger.info("Batch queue is enabled with size %d", self.batch_queue_size) self.batch_queue = deque(maxlen=self.batch_queue_size) + self.ec_producer = ( + vllm_config.ec_transfer_config is not None + and vllm_config.ec_transfer_config.is_ec_producer + ) + self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None - if ( - self.vllm_config.cache_config.enable_prefix_caching - or kv_connector is not None - ): + if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None: caching_hash_fn = get_hash_fn_by_name( vllm_config.cache_config.prefix_caching_hash_algo ) @@ -198,6 +199,7 @@ class EngineCore: self.step_fn = ( self.step if self.batch_queue is None else self.step_with_batch_queue ) + self.async_scheduling = vllm_config.scheduler_config.async_scheduling # Mark the startup heap as static so that it's ignored by GC. # Reduces pause times of oldest generation collections. @@ -245,7 +247,7 @@ class EngineCore: elapsed = time.time() - start logger.info_once( - ("init engine (profile, create kv cache, warmup model) took %.2f seconds"), + "init engine (profile, create kv cache, warmup model) took %.2f seconds", elapsed, scope="local", ) @@ -311,6 +313,16 @@ class EngineCore: ) raise err + def _log_err_callback(self, scheduler_output: SchedulerOutput): + """Log error details of a future that's not expected to return a result.""" + + def callback(f, sched_output=scheduler_output): + with self.log_error_detail(sched_output): + result = f.result() + assert result is None + + return callback + def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: """Schedule, execute, and make output. 
@@ -322,26 +334,25 @@ class EngineCore: # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): return {}, False - with record_function_or_nullcontext("core step: schedule"): - scheduler_output = self.scheduler.schedule() + scheduler_output = self.scheduler.schedule() + future = self.model_executor.execute_model(scheduler_output, non_block=True) + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + with self.log_error_detail(scheduler_output): + model_output = future.result() + if model_output is None: + model_output = self.model_executor.sample_tokens(grammar_output) - with record_function_or_nullcontext("core step: execute_model"): - future = self.model_executor.execute_model(scheduler_output, non_block=True) - grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) - with self.log_error_detail(scheduler_output): - model_output = future.result() - if model_output is None: - model_output = self.model_executor.sample_tokens(grammar_output) - - with record_function_or_nullcontext("core step: update_from_output"): - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 def post_step(self, model_executed: bool) -> None: - if self.use_spec_decode and model_executed: + # When using async scheduling we can't get draft token ids in advance, + # so we update draft token ids in the worker process and don't + # need to update draft token ids here. + if not self.async_scheduling and self.use_spec_decode and model_executed: # Take the draft token ids. 
draft_token_ids = self.model_executor.take_draft_token_ids() if draft_token_ids is not None: @@ -374,52 +385,34 @@ class EngineCore: model_executed = False deferred_scheduler_output = None if self.scheduler.has_requests(): - with record_function_or_nullcontext("core step_with_batch_queue: schedule"): - scheduler_output = self.scheduler.schedule() - with record_function_or_nullcontext( - "core step_with_batch_queue: execute_model" - ): - exec_future = self.model_executor.execute_model( - scheduler_output, non_block=True - ) - model_executed = scheduler_output.total_num_scheduled_tokens > 0 + scheduler_output = self.scheduler.schedule() + exec_future = self.model_executor.execute_model( + scheduler_output, non_block=True + ) + if not self.ec_producer: + model_executed = scheduler_output.total_num_scheduled_tokens > 0 - if scheduler_output.pending_structured_output_tokens: - with record_function_or_nullcontext( - "core step_with_batch_queue: pending_structured_output_tokens" - ): - # We need to defer sampling until we have processed the model output - # from the prior step. - deferred_scheduler_output = scheduler_output - # Block-wait for execute to return - # (continues running async on the GPU). - with self.log_error_detail(scheduler_output): - exec_result = exec_future.result() - assert exec_result is None + if not model_executed: + # No sampling required (no requests scheduled). + future = cast(Future[ModelRunnerOutput], exec_future) else: - with record_function_or_nullcontext( - "core step_with_batch_queue: get_grammar_bitmask" - ): - # We aren't waiting for any tokens, get any grammar - # output immediately. + exec_future.add_done_callback(self._log_err_callback(scheduler_output)) + + if not scheduler_output.pending_structured_output_tokens: + # We aren't waiting for any tokens, get any grammar output + # and sample immediately. 
grammar_output = self.scheduler.get_grammar_bitmask( scheduler_output ) - # Block-wait for execute to return (continues running async on the GPU). - with self.log_error_detail(scheduler_output): - exec_result = exec_future.result() - - if exec_result is None: - with record_function_or_nullcontext( - "core step_with_batch_queue: sample_tokens" - ): - # Call sample tokens. - future = self.model_executor.sample_tokens( - grammar_output, non_block=True - ) + future = self.model_executor.sample_tokens( + grammar_output, non_block=True + ) else: - # No sampling required (e.g. all requests finished). - future = cast(Future[ModelRunnerOutput], exec_future) + # We need to defer sampling until we have processed the model output + # from the prior step. + deferred_scheduler_output = scheduler_output + + if not deferred_scheduler_output: # Add this step's future to the queue. batch_queue.appendleft((future, scheduler_output)) if ( @@ -436,34 +429,27 @@ class EngineCore: # only be called when the scheduler contains requests or the queue # is non-empty. return None, False - with record_function_or_nullcontext("core step_with_batch_queue: model_output"): - # Block until the next result is available. - future, scheduler_output = batch_queue.pop() - with self.log_error_detail(scheduler_output): - model_output = future.result() - with record_function_or_nullcontext( - "core step_with_batch_queue: update_from_output" - ): - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + + # Block until the next result is available. + future, scheduler_output = batch_queue.pop() + with self.log_error_detail(scheduler_output): + model_output = future.result() + + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) # NOTE(nick): We can either handle the deferred tasks here or save # in a field and do it immediately once step_with_batch_queue is # re-called. The latter slightly favors TTFT over TPOT/throughput. 
if deferred_scheduler_output: - with record_function_or_nullcontext( - "core step_with_batch_queue: deferred_scheduler_output" - ): - # We now have the tokens needed to compute the bitmask for the - # deferred request. Get the bitmask and call sample tokens. - grammar_output = self.scheduler.get_grammar_bitmask( - deferred_scheduler_output - ) - future = self.model_executor.sample_tokens( - grammar_output, non_block=True - ) - batch_queue.appendleft((future, deferred_scheduler_output)) + # We now have the tokens needed to compute the bitmask for the + # deferred request. Get the bitmask and call sample tokens. + grammar_output = self.scheduler.get_grammar_bitmask( + deferred_scheduler_output + ) + future = self.model_executor.sample_tokens(grammar_output, non_block=True) + batch_queue.appendleft((future, deferred_scheduler_output)) return engine_core_outputs, model_executed diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 6224af5700b7b..e403cea87788b 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -14,7 +14,6 @@ from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.distributed.parallel_state import get_dp_group from vllm.engine.arg_utils import EngineArgs -from vllm.engine.protocol import Device from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -96,8 +95,9 @@ class LLMEngine: ) # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). 
+ stream_interval = self.vllm_config.scheduler_config.stream_interval self.output_processor = OutputProcessor( - self.tokenizer, log_stats=self.log_stats + self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval ) endpoint = self.observability_config.otlp_traces_endpoint if endpoint is not None: @@ -320,7 +320,7 @@ class LLMEngine: self.processor.clear_mm_cache() self.engine_core.reset_mm_cache() - def reset_prefix_cache(self, device: Device | None = None): + def reset_prefix_cache(self): self.engine_core.reset_prefix_cache() def sleep(self, level: int = 1): diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index d8d03f19d4663..bdbbfe2595f81 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -104,6 +104,7 @@ class RequestState: arrival_time: float, queue: RequestOutputCollector | None, log_stats: bool, + stream_interval: int, top_p: float | None = None, n: int | None = None, temperature: float | None = None, @@ -131,6 +132,10 @@ class RequestState: self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None + # Stream Interval + self.stream_interval = stream_interval + self.sent_tokens_offset = 0 # Offset of sent tokens + @classmethod def from_new_request( cls, @@ -141,6 +146,7 @@ class RequestState: request_index: int, queue: RequestOutputCollector | None, log_stats: bool, + stream_interval: int, ) -> "RequestState": if sampling_params := request.sampling_params: if not sampling_params.detokenize: @@ -188,6 +194,7 @@ class RequestState: arrival_time=request.arrival_time, queue=queue, log_stats=log_stats, + stream_interval=stream_interval, ) def make_request_output( @@ -205,6 +212,29 @@ class RequestState: # Only the final output is required in FINAL_ONLY mode. return None + if self.stream_interval > 1: + assert self.detokenizer is not None + + # Send output request only when + # 1. It has finished, or + # 2. It is the first token, or + # 3. 
It has reached the stream interval number of tokens + if not ( + finished + or self.sent_tokens_offset == 0 + or len(self.detokenizer.output_token_ids) - self.sent_tokens_offset + >= self.stream_interval + ): + return None + + if self.output_kind == RequestOutputKind.DELTA: + # Send tokens from the offset in DELTA mode, otherwise all + # tokens are sent. + new_token_ids = self.detokenizer.output_token_ids[ + self.sent_tokens_offset : + ] + self.sent_tokens_offset = len(self.detokenizer.output_token_ids) + request_id = self.request_id if pooling_output is not None: return self._new_request_output( @@ -310,9 +340,12 @@ class RequestState: class OutputProcessor: """Process EngineCoreOutputs into RequestOutputs.""" - def __init__(self, tokenizer: AnyTokenizer, log_stats: bool): + def __init__( + self, tokenizer: AnyTokenizer, log_stats: bool, stream_interval: int = 1 + ): self.log_stats = log_stats self.tokenizer = tokenizer + self.stream_interval = stream_interval self.request_states: dict[str, RequestState] = {} self.parent_requests: dict[str, ParentRequest] = {} self.lora_states = LoRARequestStates(log_stats) @@ -385,6 +418,7 @@ class OutputProcessor: request_index=request_index, queue=queue, log_stats=self.log_stats, + stream_interval=self.stream_interval, ) self.request_states[request_id] = req_state if parent_req: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index f2d992403e1a8..4cb911d8e22b7 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.cache import processor_cache_from_config from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict +from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import 
PoolingParams @@ -149,6 +150,23 @@ class Processor: raise ValueError( "vLLM V1 does not support per request user provided logits processors." ) + # Async scheduling + spec decode currently incompatible with some + # sampling parameters. + if ( + self.vllm_config.speculative_config is not None + and self.vllm_config.scheduler_config.async_scheduling + and ( + params.frequency_penalty != 0.0 + or params.presence_penalty != 0.0 + or params.repetition_penalty != 1.0 + or params.bad_words_token_ids + or params.structured_outputs + ) + ): + raise ValueError( + "async scheduling with spec decoding doesn't yet support " + "penalties, bad words or structured outputs in sampling parameters." + ) def _validate_params( self, @@ -340,7 +358,12 @@ class Processor: mm_uuids: dict[str, list[str | None] | str] = {} for modality, data in mm_data.items(): - n = len(data) if isinstance(data, list) else 1 + # Hash each item for embedding inputs. + n = ( + len(data) + if isinstance(data, list) or MultiModalDataParser.is_embeddings(data) + else 1 + ) mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] return mm_uuids @@ -575,6 +598,22 @@ class Processor: # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens + if ( + prompt_len == max_prompt_len + and prompt_type == "decoder" + and not model_config.is_multimodal_model + and self.model_config.runner_type != "pooling" + ): + suggestion = ( + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens (prompt + requested output tokens)." + ) + raise ValueError( + f"The {prompt_type} prompt (length {prompt_len}) plus the number of " + f"requested output tokens (at least 1) is longer than the maximum " + f"model length of {max_prompt_len}. 
{suggestion}" + ) + def stat_mm_cache(self) -> MultiModalCacheStats | None: return self.input_preprocessor.stat_mm_cache() diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index e74519b21aa6e..d65cad7af03d6 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -183,15 +183,19 @@ def set_device_control_env_var( for engine subprocess. """ world_size = vllm_config.parallel_config.world_size + local_world_size = vllm_config.parallel_config.local_world_size evar = current_platform.device_control_env_var - value = get_device_indices(evar, local_dp_rank, world_size) + value = get_device_indices(evar, local_dp_rank, world_size, local_world_size) with patch.dict(os.environ, values=((evar, value),)): yield def get_device_indices( - device_control_env_var: str, local_dp_rank: int, world_size: int + device_control_env_var: str, + local_dp_rank: int, + world_size: int, + local_world_size: int | None = None, ): """ Returns a comma-separated string of device indices for the specified @@ -200,10 +204,15 @@ def get_device_indices( For example, if world_size=2 and local_dp_rank=1, and there are 4 devices, this will select devices 2 and 3 for local_dp_rank=1. 
""" + if local_world_size is None: + local_world_size = world_size try: value = ",".join( str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size) + for i in range( + local_dp_rank * world_size, + local_dp_rank * world_size + local_world_size, + ) ) except IndexError as e: raise Exception( diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 1e249161c6886..ad2ece50f9815 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -10,7 +10,7 @@ import time import traceback import weakref from collections import deque -from collections.abc import Callable +from collections.abc import Callable, Sequence from concurrent.futures import Future, InvalidStateError from contextlib import suppress from dataclasses import dataclass @@ -31,8 +31,10 @@ from vllm.distributed import destroy_distributed_environment, destroy_model_para from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.distributed.parallel_state import ( + get_dcp_group, get_dp_group, get_ep_group, + get_inner_dp_world_group, get_pp_group, get_tp_group, ) @@ -89,6 +91,10 @@ class FutureWrapper(Future): class MultiprocExecutor(Executor): supports_pp: bool = True + def __init__(self, vllm_config: VllmConfig, monitor_workers: bool = True): + self.monitor_workers = monitor_workers + super().__init__(vllm_config) + def _init_executor(self) -> None: # Call self.shutdown at exit to clean up # and ensure workers will be terminated. 
@@ -98,6 +104,12 @@ class MultiprocExecutor(Executor): self.failure_callback: FailureCallback | None = None self.world_size = self.parallel_config.world_size + assert self.world_size % self.parallel_config.nnodes_within_dp == 0, ( + f"global world_size ({self.parallel_config.world_size}) must be " + f"divisible by nnodes_within_dp " + f"({self.parallel_config.nnodes_within_dp}). " + ) + self.local_world_size = self.parallel_config.local_world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size pp_parallel_size = self.parallel_config.pipeline_parallel_size assert self.world_size == tensor_parallel_size * pp_parallel_size, ( @@ -115,27 +127,37 @@ class MultiprocExecutor(Executor): distributed_init_method = get_distributed_init_method( get_loopback_ip(), get_open_port() ) - + self.rpc_broadcast_mq: MessageQueue | None = None + scheduler_output_handle: Handle | None = None # Initialize worker and set up message queues for SchedulerOutputs # and ModelRunnerOutputs - max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 - self.rpc_broadcast_mq = MessageQueue( - self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes - ) - scheduler_output_handle = self.rpc_broadcast_mq.export_handle() - + if self.parallel_config.node_rank_within_dp == 0: + # For leader node within each dp rank, + # each dp will have its own leader multiproc executor. 
+ max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 + self.rpc_broadcast_mq = MessageQueue( + self.world_size, + self.local_world_size, + max_chunk_bytes=max_chunk_bytes, + connect_ip=self.parallel_config.master_addr, + ) + scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers context = get_mp_context() shared_worker_lock = context.Lock() unready_workers: list[UnreadyWorkerProcHandle] = [] success = False try: - for rank in range(self.world_size): + global_start_rank = ( + self.local_world_size * self.parallel_config.node_rank_within_dp + ) + for local_rank in range(self.local_world_size): + global_rank = global_start_rank + local_rank unready_workers.append( WorkerProc.make_worker_process( vllm_config=self.vllm_config, - local_rank=rank, - rank=rank, + local_rank=local_rank, + rank=global_rank, distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, shared_worker_lock=shared_worker_lock, @@ -144,15 +166,38 @@ class MultiprocExecutor(Executor): # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. + + # Wait for all local workers to be ready. self.workers = WorkerProc.wait_for_ready(unready_workers) + # Start background thread to monitor worker health if not in headless mode. + if self.monitor_workers: + self.start_worker_monitor() + + self.response_mqs = [] + # Only leader node have remote response mqs + if self.parallel_config.node_rank_within_dp == 0: + for rank in range(self.world_size): + if rank < self.local_world_size: + local_message_queue = self.workers[rank].worker_response_mq + assert local_message_queue is not None + self.response_mqs.append(local_message_queue) + else: + remote_message_queue = self.workers[0].peer_worker_response_mqs[ + rank + ] + assert remote_message_queue is not None + self.response_mqs.append(remote_message_queue) + # Ensure message queues are ready. 
Will deadlock if re-ordered # Must be kept consistent with the WorkerProc. - self.rpc_broadcast_mq.wait_until_ready() - for w in self.workers: - w.worker_response_mq.wait_until_ready() - self.start_worker_monitor() + # Wait for all input mqs to be ready. + if self.rpc_broadcast_mq is not None: + self.rpc_broadcast_mq.wait_until_ready() + # Wait for all remote response mqs to be ready. + for response_mq in self.response_mqs: + response_mq.wait_until_ready() success = True finally: if not success: @@ -167,7 +212,7 @@ class MultiprocExecutor(Executor): self.output_rank = self._get_output_rank() - def start_worker_monitor(self): + def start_worker_monitor(self, inline=False) -> None: workers = self.workers self_ref = weakref.ref(self) @@ -191,9 +236,13 @@ class MultiprocExecutor(Executor): _self.failure_callback = None callback() - Thread( - target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor" - ).start() + if not inline: + Thread( + target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor" + ).start() + return + + monitor_workers() def register_failure_callback(self, callback: FailureCallback): if self.is_failed: @@ -246,7 +295,9 @@ class MultiprocExecutor(Executor): ) -> Any | list[Any] | Future[Any | list[Any]]: """Returns single result if unique_reply_rank and/or kv_output_aggregator is provided, otherwise list.""" - + assert self.rpc_broadcast_mq is not None, ( + "collective_rpc should not be called on follower node" + ) if self.is_failed: raise RuntimeError("Executor failed.") @@ -268,20 +319,20 @@ class MultiprocExecutor(Executor): send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL) self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank)) - workers = ( - (self.workers[output_rank],) if output_rank is not None else self.workers - ) + response_mqs: Sequence[MessageQueue] = self.response_mqs + if output_rank is not None: + response_mqs = (response_mqs[output_rank],) shutdown_event = self.shutdown_event def 
get_response(): responses = [] - for w in workers: + for mq in response_mqs: dequeue_timeout = ( None if deadline is None else (deadline - time.monotonic()) ) try: - status, result = w.worker_response_mq.dequeue( + status, result = mq.dequeue( timeout=dequeue_timeout, cancel=shutdown_event ) except TimeoutError as e: @@ -390,17 +441,26 @@ class UnreadyWorkerProcHandle: class WorkerProcHandle: proc: BaseProcess rank: int - worker_response_mq: MessageQueue # The worker process writes to this MQ + # The worker process writes to this MQ in single-node mode + worker_response_mq: MessageQueue | None + # This is only non empty on driver node, + # the peer worker process i writes to MQ + # `peer_worker_response_mqs[i]` + peer_worker_response_mqs: list[MessageQueue | None] death_writer: Connection | None = None @classmethod def from_unready_handle( - cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue + cls, + unready_handle: UnreadyWorkerProcHandle, + worker_response_mq: MessageQueue | None, + peer_worker_response_mqs: list[MessageQueue | None], ) -> "WorkerProcHandle": return cls( proc=unready_handle.proc, rank=unready_handle.rank, worker_response_mq=worker_response_mq, + peer_worker_response_mqs=peer_worker_response_mqs, death_writer=unready_handle.death_writer, ) @@ -410,6 +470,38 @@ class WorkerProc: READY_STR = "READY" + def _init_message_queues( + self, input_shm_handle: Handle, vllm_config: VllmConfig + ) -> None: + if vllm_config.parallel_config.nnodes_within_dp == 1: + # Initialize MessageQueue for receiving SchedulerOutput + self.rpc_broadcast_mq = MessageQueue.create_from_handle( + input_shm_handle, self.worker.rank + ) + + # Initializes a message queue for sending the model output + self.worker_response_mq: MessageQueue = MessageQueue(1, 1) + self.peer_response_handles = [] + else: + # Initialize remote MessageQueue for receiving SchedulerOutput across nodes + self.rpc_broadcast_mq = get_inner_dp_world_group().create_mq_broadcaster( + 
external_writer_handle=input_shm_handle, + # Since there is external_writer_handle from executor proc, + # where the ready signal from actual writer is sent out of the + # create_mq_broadcaster method and after this setup, we make it + # non blocking. The handshake will be triggered when + # worker.rpc_broadcast_mq.wait_until_ready() is called + blocking=False, + ) + # Initializes remote message queue for sending the model output to the + # driver worker, exposing peer_response_handles for driver worker + # that include handles for all ranks + self.worker_response_mq, self.peer_response_handles = ( + get_inner_dp_world_group().create_single_reader_mq_broadcasters( + reader_rank_in_group=0 + ) + ) + def __init__( self, vllm_config: VllmConfig, @@ -420,13 +512,15 @@ class WorkerProc: shared_worker_lock: LockType, ): self.rank = rank - wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank) + wrapper = WorkerWrapperBase( + vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank + ) # TODO: move `init_worker` to executor level as a collective rpc call all_kwargs: list[dict] = [ {} for _ in range(vllm_config.parallel_config.world_size) ] is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0 - all_kwargs[rank] = { + all_kwargs[local_rank] = { "vllm_config": vllm_config, "local_rank": local_rank, "rank": rank, @@ -437,14 +531,6 @@ class WorkerProc: wrapper.init_worker(all_kwargs) self.worker = wrapper - # Initialize MessageQueue for receiving SchedulerOutput - self.rpc_broadcast_mq = MessageQueue.create_from_handle( - input_shm_handle, self.worker.rank - ) - - # Initializes a message queue for sending the model output - self.worker_response_mq = MessageQueue(1, 1) - scheduler_config = vllm_config.scheduler_config self.use_async_scheduling = scheduler_config.async_scheduling if self.use_async_scheduling: @@ -465,6 +551,7 @@ class WorkerProc: ) # Load model + self._init_message_queues(input_shm_handle, vllm_config) 
self.worker.load_model() # Enable environment variable cache (e.g. assume no more @@ -511,6 +598,27 @@ class WorkerProc: # death_reader in child will get EOFError return UnreadyWorkerProcHandle(proc, rank, reader, death_writer) + @staticmethod + def wait_for_response_handle_ready( + handles: dict[str, Any], proc_handle: UnreadyWorkerProcHandle + ) -> WorkerProcHandle: + response_handle = handles["handle"] + worker_response_mq: MessageQueue | None = None + if len(response_handle.local_reader_ranks) > 0: + worker_response_mq = MessageQueue.create_from_handle(response_handle, 0) + peer_response_handles = handles["peer_response_handles"] + peer_worker_response_mqs = [ + MessageQueue.create_from_handle(handle, -1) + if handle.remote_subscribe_addr is not None + else None + for handle in peer_response_handles + ] + return WorkerProcHandle.from_unready_handle( + proc_handle, + worker_response_mq, + peer_worker_response_mqs=peer_worker_response_mqs, + ) + @staticmethod def wait_for_ready( unready_proc_handles: list[UnreadyWorkerProcHandle], @@ -536,16 +644,10 @@ class WorkerProc: if response["status"] != "READY": raise e - # Extract the message queue handle. - worker_response_mq = MessageQueue.create_from_handle( - response["handle"], 0 + idx = unready_proc_handle.rank % len(ready_proc_handles) + ready_proc_handles[idx] = WorkerProc.wait_for_response_handle_ready( + response, unready_proc_handle ) - ready_proc_handles[unready_proc_handle.rank] = ( - WorkerProcHandle.from_unready_handle( - unready_proc_handle, worker_response_mq - ) - ) - except EOFError: e.__suppress_context__ = True raise e from None @@ -617,12 +719,14 @@ class WorkerProc: { "status": WorkerProc.READY_STR, "handle": worker.worker_response_mq.export_handle(), + "peer_response_handles": worker.peer_response_handles, } ) # Ensure message queues are ready. Will deadlock if re-ordered. 
# Must be kept consistent with the Executor - worker.rpc_broadcast_mq.wait_until_ready() + if worker.rpc_broadcast_mq is not None: + worker.rpc_broadcast_mq.wait_until_ready() worker.worker_response_mq.wait_until_ready() ready_writer.close() ready_writer = None @@ -726,6 +830,8 @@ class WorkerProc: pp_rank = get_pp_group().rank_in_group tp_size = get_tp_group().world_size tp_rank = get_tp_group().rank_in_group + dcp_size = get_dcp_group().world_size + dcp_rank = get_dcp_group().rank_in_group process_name = "Worker" if dp_size > 1: process_name += f"_DP{dp_rank}" @@ -733,6 +839,8 @@ class WorkerProc: process_name += f"_PP{pp_rank}" if tp_size > 1: process_name += f"_TP{tp_rank}" + if dcp_size > 1: + process_name += f"_DCP{dcp_rank}" if enable_ep: ep_rank = get_ep_group().rank_in_group process_name += f"_EP{ep_rank}" diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 119e4c0818316..55db7445c9c74 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -99,6 +99,11 @@ class RayDistributedExecutor(Executor): # KV connector setup self.has_connector = self.vllm_config.kv_transfer_config is not None + self.ec_producer = ( + self.vllm_config.ec_transfer_config is not None + and self.vllm_config.ec_transfer_config.is_ec_producer + ) + self.scheduler_output: SchedulerOutput | None = None @property @@ -395,6 +400,12 @@ class RayDistributedExecutor(Executor): "State error: sample_tokens() must be called " "after execute_model() returns None." ) + + if self.ec_producer or not scheduler_output.total_num_scheduled_tokens: + # Model will not execute, call model runner immediately. + return self._execute_dag(scheduler_output, None, non_block) + + # Model will execute, defer to sample_tokens() call. 
self.scheduler_output = scheduler_output return COMPLETED_NONE_FUTURE if non_block else None @@ -417,10 +428,18 @@ class RayDistributedExecutor(Executor): """ scheduler_output = self.scheduler_output if scheduler_output is None: - return None # noqa + return COMPLETED_NONE_FUTURE if non_block else None # noqa self.scheduler_output = None + return self._execute_dag(scheduler_output, grammar_output, non_block) + + def _execute_dag( + self, + scheduler_output: SchedulerOutput, + grammar_output: "GrammarOutput | None", + non_block: bool = False, + ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: # Build the compiled DAG for the first time. if self.forward_dag is None: # type: ignore self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 1a175e9e110bd..cb36e7973650e 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -118,12 +118,14 @@ class LoggingStatLogger(StatLoggerBase): self.num_prompt_tokens: int = 0 self.num_generation_tokens: int = 0 self.num_corrupted_reqs: int = 0 + self.num_preemptions: int = 0 def _track_iteration_stats(self, iteration_stats: IterationStats): # Save tracked stats for token counters. 
self.num_prompt_tokens += iteration_stats.num_prompt_tokens self.num_generation_tokens += iteration_stats.num_generation_tokens self.num_corrupted_reqs += iteration_stats.num_corrupted_reqs + self.num_preemptions += iteration_stats.num_preempted_reqs def _get_throughput(self, tracked_stats: int, now: float) -> float: # Compute summary metrics for tracked stats @@ -196,18 +198,31 @@ class LoggingStatLogger(StatLoggerBase): "Avg generation throughput: %.1f tokens/s", "Running: %d reqs", "Waiting: %d reqs", - "GPU KV cache usage: %.1f%%", - "Prefix cache hit rate: %.1f%%", ] log_args = [ self.last_prompt_throughput, self.last_generation_throughput, self.last_scheduler_stats.num_running_reqs, self.last_scheduler_stats.num_waiting_reqs, - self.last_scheduler_stats.kv_cache_usage * 100, - self.prefix_caching_metrics.hit_rate * 100, ] + if self.num_preemptions > 0: + log_parts.append("Preemptions: %d") + log_args.append(self.num_preemptions) + + log_parts.extend( + [ + "GPU KV cache usage: %.1f%%", + "Prefix cache hit rate: %.1f%%", + ] + ) + log_args.extend( + [ + self.last_scheduler_stats.kv_cache_usage * 100, + self.prefix_caching_metrics.hit_rate * 100, + ] + ) + if envs.VLLM_COMPUTE_NANS_IN_LOGITS: log_parts.append("Corrupted: %d reqs") log_args.append(self.num_corrupted_reqs) @@ -479,6 +494,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase): gauge_kv_cache_usage = self._gauge_cls( name="vllm:kv_cache_usage_perc", documentation="KV-cache usage. 1 means 100 percent usage.", + multiprocess_mode="mostrecent", labelnames=labelnames, ) self.gauge_kv_cache_usage = make_per_engine( diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index e32d5bb608b1d..c0b2835c3124c 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -158,7 +158,7 @@ class ModelRunnerOutput: # num_generated_tokens is the number of tokens # generated in the current step. It can be different for # each request due to speculative/jump decoding. 
- sampled_token_ids: list[list[int]] + sampled_token_ids: list[np.ndarray] # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] @@ -220,7 +220,7 @@ def make_empty_encoder_model_runner_output( req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)} # No tokens generated yet ⇒ one empty list per request - sampled_token_ids: list[list[int]] = [[0] for _ in req_ids] + sampled_token_ids: list[list[int]] = [np.array([0]) for _ in req_ids] # Pooler outputs are not available yet ⇒ use None placeholders pooler_output: list[torch.Tensor | None] = [None for _ in req_ids] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 7a5f1183ed48e..3d92906fbf4b1 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -127,6 +127,8 @@ class Request: self.get_hash_new_full_blocks = partial(block_hasher, self) self.block_hashes = self.get_hash_new_full_blocks() + self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache() + @classmethod def from_engine_core_request( cls, @@ -180,6 +182,19 @@ class Request: def num_output_tokens(self) -> int: return len(self._output_token_ids) + def get_skip_reading_prefix_cache(self) -> bool: + if ( + self.sampling_params is not None + and self.sampling_params.skip_reading_prefix_cache is not None + ): + return self.sampling_params.skip_reading_prefix_cache + elif ( + self.pooling_params is not None + and self.pooling_params.skip_reading_prefix_cache is not None + ): + return self.pooling_params.skip_reading_prefix_cache + return False + def is_finished(self) -> bool: return RequestStatus.is_finished(self.status) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 5992c4066c9cb..8b174af4c7794 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -41,7 +41,7 @@ STR_POOLING_REJECTS_LOGITSPROCS = ( # Error message when the user tries to initialize vLLM with a speculative # decoding 
enabled and custom logitsproces STR_SPEC_DEC_REJECTS_LOGITSPROCS = ( - "Custom logits processors are not supportedwhen speculative decoding is enabled." + "Custom logits processors are not supported when speculative decoding is enabled." ) LOGITSPROCS_GROUP = "vllm.logits_processors" diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 926305d25f56b..f31a0cddda9ae 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -3,6 +3,7 @@ from dataclasses import replace +import numpy as np import torch import torch.nn as nn @@ -204,7 +205,7 @@ class RejectionSampler(nn.Module): def parse_output( output_token_ids: torch.Tensor, vocab_size: int, - ) -> list[list[int]]: + ) -> list[np.ndarray]: """Parse the output of the rejection sampler. Args: output_token_ids: The sampled token IDs in shape @@ -220,10 +221,7 @@ class RejectionSampler(nn.Module): valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & ( output_token_ids_np < vocab_size ) - outputs = [ - row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np) - ] - return outputs + return [row[valid_mask[i]] for i, row in enumerate(output_token_ids_np)] def apply_logits_processors( self, diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 102357ca7c642..0a6806390451d 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence from functools import partial from inspect import isclass from types import FunctionType -from typing import Any, TypeAlias +from typing import Any, TypeAlias, get_type_hints import cloudpickle import msgspec @@ -16,6 +16,8 @@ import numpy as np import torch import zmq from msgspec import msgpack +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema from vllm import envs from vllm.logger import init_logger @@ -31,7 +33,7 @@ from vllm.multimodal.inputs import ( MultiModalSharedField, 
NestedTensors, ) -from vllm.v1.engine import UtilityResult +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.utils import tensor_data logger = init_logger(__name__) @@ -103,6 +105,13 @@ def _decode_type_info_recursive( return convert_fn(type_info, data) +class UtilityResult: + """Wrapper for special handling when serializing/deserializing.""" + + def __init__(self, r: Any = None): + self.result = r + + class MsgpackEncoder: """Encoder with custom torch tensor and numpy array serialization. @@ -282,7 +291,9 @@ class MsgpackDecoder: not thread-safe when encoding tensors / numpy arrays. """ - def __init__(self, t: Any | None = None): + def __init__(self, t: Any | None = None, share_mem: bool = True): + self.share_mem = share_mem + self.pin_tensors = is_pin_memory_available() args = () if t is None else (t,) self.decoder = msgpack.Decoder( *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook @@ -347,21 +358,30 @@ class MsgpackDecoder: # zero-copy decode. We assume the ndarray will not be kept around, # as it now locks the whole received message buffer in memory. buffer = self.aux_buffers[data] if isinstance(data, int) else data - return np.frombuffer(buffer, dtype=dtype).reshape(shape) + arr = np.frombuffer(buffer, dtype=dtype) + if not self.share_mem: + arr = arr.copy() + return arr.reshape(shape) def _decode_tensor(self, arr: Any) -> torch.Tensor: dtype, shape, data = arr - # Copy from inline representation, to decouple the memory storage - # of the message from the original buffer. And also make Torch - # not complain about a readonly memoryview. 
- buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data) + is_aux = isinstance(data, int) + buffer = self.aux_buffers[data] if is_aux else data + buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer) torch_dtype = getattr(torch, dtype) assert isinstance(torch_dtype, torch.dtype) - if not buffer: # torch.frombuffer doesn't like empty buffers + if not buffer.nbytes: # torch.frombuffer doesn't like empty buffers assert 0 in shape return torch.empty(shape, dtype=torch_dtype) # Create uint8 array arr = torch.frombuffer(buffer, dtype=torch.uint8) + # Clone ensures tensor is backed by pytorch-owned memory for safe + # future async CPU->GPU transfer. + # Pin larger tensors for more efficient CPU->GPU transfer. + if not is_aux: + arr = arr.clone() + elif not self.share_mem: + arr = arr.pin_memory() if self.pin_tensors else arr.clone() # Convert back to proper shape & type return arr.view(torch_dtype).view(shape) @@ -457,3 +477,56 @@ def run_method( else: func = partial(method, obj) # type: ignore return func(*args, **kwargs) + + +class PydanticMsgspecMixin: + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + """ + Make msgspec.Struct compatible with Pydantic, respecting defaults. + Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the + API as input or in `/docs`. Note this is cached by Pydantic and not + called on every validation. + """ + msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)} + type_hints = get_type_hints(source_type) + + # Build the Pydantic typed_dict_field for each msgspec field + fields = {} + for name, hint in type_hints.items(): + msgspec_field = msgspec_fields[name] + + # typed_dict_field using the handler to get the schema + field_schema = handler(hint) + + # Add default value to the schema. 
+ if msgspec_field.default_factory is not msgspec.NODEFAULT: + wrapped_schema = core_schema.with_default_schema( + schema=field_schema, + default_factory=msgspec_field.default_factory, + ) + fields[name] = core_schema.typed_dict_field(wrapped_schema) + elif msgspec_field.default is not msgspec.NODEFAULT: + wrapped_schema = core_schema.with_default_schema( + schema=field_schema, + default=msgspec_field.default, + ) + fields[name] = core_schema.typed_dict_field(wrapped_schema) + else: + # No default, so Pydantic will treat it as required + fields[name] = core_schema.typed_dict_field(field_schema) + return core_schema.no_info_after_validator_function( + cls._validate_msgspec, + core_schema.typed_dict_schema(fields), + ) + + @classmethod + def _validate_msgspec(cls, value: Any) -> Any: + """Validate and convert input to msgspec.Struct instance.""" + if isinstance(value, cls): + return value + if isinstance(value, dict): + return cls(**value) + return msgspec.convert(value, type=cls) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index beef5203e0394..5bf2503c3027d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -397,10 +397,13 @@ class EagleProposer: positions += 1 exceeds_max_model_len = positions >= self.max_model_len clamped_positions = torch.where(exceeds_max_model_len, 0, positions) - + # For data integrity when async scheduling, we shouldn't use in place + # operations in case they are modified in next step's `prepare_input` + # of main model. # Increment the sequence lengths. common_attn_metadata.seq_lens += 1 - common_attn_metadata.seq_lens_cpu += 1 + # This is an out-of-place operation to avoid modifying the original tensor. + common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. 
@@ -484,7 +487,7 @@ class EagleProposer: def prepare_next_token_ids_cpu( self, - sampled_token_ids: list[list[int]], + sampled_token_ids: list[np.ndarray], requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, num_scheduled_tokens: dict[str, int], @@ -499,7 +502,7 @@ class EagleProposer: req_ids = gpu_input_batch.req_ids next_token_ids: list[int] = [] for i, token_ids in enumerate(sampled_token_ids): - if token_ids: + if token_ids.shape[0] > 0: # Common case. next_token_id = token_ids[-1] else: @@ -510,10 +513,9 @@ class EagleProposer: seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id] next_token_id = req_state.get_token_id(seq_len) next_token_ids.append(next_token_id) - next_token_ids = torch.tensor( + return torch.tensor( next_token_ids, dtype=torch.int32, device=self.input_ids.device ) - return next_token_ids def prepare_next_token_ids_padded( self, @@ -992,6 +994,7 @@ class EagleProposer: target_language_model = target_model.get_language_model() else: target_language_model = target_model + # share embed_tokens with the target model if needed if get_pp_group().world_size == 1: if hasattr(target_language_model.model, "embed_tokens"): @@ -1003,52 +1006,92 @@ class EagleProposer: "Target model does not have 'embed_tokens' or 'embedding' attribute" ) - # Check if shapes match and we found the embedding - eagle_shape = self.model.model.embed_tokens.weight.shape - target_shape = target_embed_tokens.weight.shape - if eagle_shape == target_shape: - logger.info( - "Assuming the EAGLE head shares the same vocab embedding" - " with the target model." - ) - del self.model.model.embed_tokens - self.model.model.embed_tokens = target_embed_tokens + share_embeddings = False + if hasattr(self.model, "has_own_embed_tokens"): + # EAGLE model + if not self.model.has_own_embed_tokens: + share_embeddings = True + logger.info( + "Detected EAGLE model without its own embed_tokens in the" + " checkpoint. 
Sharing target model embedding weights with the" + " draft model." + ) + elif ( + isinstance(target_embed_tokens.weight, torch.Tensor) + and isinstance(self.model.model.embed_tokens.weight, torch.Tensor) + and torch.equal( + target_embed_tokens.weight, self.model.model.embed_tokens.weight + ) + ): + share_embeddings = True + logger.info( + "Detected EAGLE model with embed_tokens identical to the target" + " model. Sharing target model embedding weights with the draft" + " model." + ) + else: + logger.info( + "Detected EAGLE model with distinct embed_tokens weights. " + "Keeping separate embedding weights from the target model." + ) else: + # MTP model + share_embeddings = True logger.info( - "The EAGLE head's vocab embedding will be loaded separately" - " from the target model." + "Detected MTP model. " + "Sharing target model embedding weights with the draft model." ) + + if share_embeddings: + if hasattr(self.model.model, "embed_tokens"): + del self.model.model.embed_tokens + self.model.model.embed_tokens = target_embed_tokens else: logger.info( - "The EAGLE head's vocab embedding will be loaded separately" + "The draft model's vocab embedding will be loaded separately" " from the target model." 
) # share lm_head with the target model if needed - # some model definition do not define lm_head explicitly - # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM - if self.vllm_config.speculative_config.method != "eagle3": - if hasattr(target_language_model, "lm_head"): - logger.info("Loading EAGLE LM head weights from the target model.") - self.model.lm_head = target_language_model.lm_head - else: - if ( - hasattr(self.model, "lm_head") - and hasattr(target_language_model, "lm_head") - and self.model.lm_head.weight.shape - == target_language_model.lm_head.weight.shape - ): + share_lm_head = False + if hasattr(self.model, "has_own_lm_head"): + # EAGLE model + if not self.model.has_own_lm_head: + share_lm_head = True logger.info( - "Assuming the EAGLE head shares the same lm_head" - " with the target model." + "Detected EAGLE model without its own lm_head in the checkpoint. " + "Sharing target model lm_head weights with the draft model." + ) + elif ( + hasattr(target_language_model, "lm_head") + and isinstance(target_language_model.lm_head.weight, torch.Tensor) + and isinstance(self.model.lm_head.weight, torch.Tensor) + and torch.equal( + target_language_model.lm_head.weight, self.model.lm_head.weight + ) + ): + share_lm_head = True + logger.info( + "Detected EAGLE model with lm_head identical to the target model. " + "Sharing target model lm_head weights with the draft model." ) - del self.model.lm_head - self.model.lm_head = target_language_model.lm_head else: logger.info( - "The EAGLE head's lm_head will be loaded separately" - " from the target model." + "Detected EAGLE model with distinct lm_head weights. " + "Keeping separate lm_head weights from the target model." ) + else: + # MTP model + share_lm_head = True + logger.info( + "Detected MTP model. " + "Sharing target model lm_head weights with the draft model." 
+ ) + + if share_lm_head and hasattr(target_language_model, "lm_head"): + if hasattr(self.model, "lm_head"): + del self.model.lm_head + self.model.lm_head = target_language_model.lm_head @torch.inference_mode() def dummy_run( diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index e2f83cb24aa90..378937dba9882 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -54,7 +54,7 @@ class NgramProposer: # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose( - [[]] * 1024, + [np.array([])] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32), np.zeros((1024, self.max_model_len), dtype=np.int32), @@ -131,7 +131,7 @@ class NgramProposer: def propose( self, - sampled_token_ids: list[list[int]], + sampled_token_ids: list[np.ndarray], req_ids: list[str], num_tokens_no_spec: np.ndarray, token_ids_cpu: np.ndarray, @@ -140,7 +140,7 @@ class NgramProposer: # find which requests need ngram proposals valid_ngram_requests = [] for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) + num_sampled_ids = sampled_ids.shape[0] if not num_sampled_ids: # Skip speculative decoding. continue diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index 049e335db3254..d76e0ffe778d4 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np + from vllm.config import VllmConfig from vllm.v1.worker.gpu_input_batch import InputBatch @@ -32,16 +34,16 @@ class SuffixDecodingProposer: def propose( self, input_batch: InputBatch, - sampled_token_ids: list[list[int]], + sampled_token_ids: list[np.ndarray], ) -> list[list[int]]: """ Propose speculative tokens for each request in the input batch. 
Suffix Decoding will speculate a dynamic number of tokens for each request every decoding step, so each entry in the returned list may have different lengths. """ - draft_token_ids: list[list[int]] = [] + draft_token_ids: list[np.ndarray] = [] for i, sampled_ids in enumerate(sampled_token_ids): - if not sampled_ids: + if sampled_ids.shape[0] == 0: # Skip speculative decoding for partial prefills. draft_token_ids.append([]) continue @@ -70,7 +72,7 @@ class SuffixDecodingProposer: self.suffix_cache.start_request(req_id, prompt_token_ids) # Append the newly sampled ids to the suffix cache for this request. - self.suffix_cache.add_active_response(req_id, sampled_ids) + self.suffix_cache.add_active_response(req_id, sampled_ids.tolist()) # Suffix decoding only uses the most recent tokens up to max_tree_depth, so # we extract the pattern from the end of the input. diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index a401f6d74cdd5..29099d1e9b17e 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -97,6 +97,9 @@ class ConstantList(Generic[T], Sequence): def __repr__(self): return f"ConstantList({self._x})" + def copy(self) -> list[T]: + return self._x.copy() + class CpuGpuBuffer: """Buffer to easily copy tensors between CPU and GPU.""" diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index c28bf542f85c5..9f6c19e464308 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -98,7 +98,9 @@ class BlockTable: return if self.use_hybrid_blocks: - block_ids = self._map_to_kernel_blocks(np.array(block_ids)) + block_ids = self.map_to_kernel_blocks( + np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange + ) num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] @@ -188,7 +190,12 @@ class BlockTable: self.block_table.gpu.fill_(0) self.block_table.cpu.fill_(0) - def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray: + @staticmethod + def map_to_kernel_blocks( + 
kv_manager_block_ids: np.ndarray, + blocks_per_kv_block: int, + kernel_block_arange: np.ndarray, + ) -> np.ndarray: """Convert kv_manager_block_id IDs to kernel block IDs. Example: @@ -203,12 +210,12 @@ class BlockTable: # kv_manager_block_id 1 → kernel block id [2, 3] # kv_manager_block_id 2 → kernel block id [4, 5] """ - if not self.use_hybrid_blocks: + if blocks_per_kv_block == 1: return kv_manager_block_ids kernel_block_ids = ( - kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block - + self._kernel_block_arange + kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block + + kernel_block_arange ) return kernel_block_ids.reshape(-1) diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index ceb1cf64b5889..40f011fed1ada 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -80,9 +80,6 @@ class CPUModelRunner(GPUModelRunner): def _sync_device(self) -> None: pass - def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: - return sampled_token_ids.tolist() - def get_dp_padding(self, num_tokens: int) -> tuple[int, torch.Tensor | None]: # Note: For CPU backend, dp padding is not required for now. return 0, None diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 393181f543d2e..7cf6afa3fc371 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -46,6 +46,9 @@ class CachedRequestState: lora_request: LoRARequest | None = None prompt_embeds: torch.Tensor | None = None + # Used when both async_scheduling and spec_decode are enabled. 
+ prev_num_draft_len: int = 0 + def __post_init__(self): self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.prompt_token_ids, self.prompt_embeds diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c9c64137ca04b..0102ca4739ad5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -7,7 +7,7 @@ import time from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager -from copy import deepcopy +from copy import copy, deepcopy from functools import reduce from itertools import product from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast @@ -179,6 +179,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): logprobs_tensors: torch.Tensor | None, invalid_req_indices: list[int], async_output_copy_stream: torch.cuda.Stream, + vocab_size: int, ): self._model_runner_output = model_runner_output self._invalid_req_indices = invalid_req_indices @@ -189,6 +190,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): # Keep a reference to the device tensor to avoid it being # deallocated until we finish copying it to the host. self._sampled_token_ids = sampled_token_ids + self.vocab_size = vocab_size self._logprobs_tensors = logprobs_tensors # Initiate the copy on a separate stream, but do not synchronize it. @@ -215,10 +217,18 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): # Release the device tensors once the copy has completed. 
del self._logprobs_tensors del self._sampled_token_ids - - valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() + max_gen_len = self.sampled_token_ids_cpu.shape[-1] + if max_gen_len == 1: + valid_sampled_token_ids: list[np.ndarray] = [ + row for row in self.sampled_token_ids_cpu.numpy() + ] + else: + valid_sampled_token_ids = RejectionSampler.parse_output( + self.sampled_token_ids_cpu, + self.vocab_size, + ) for i in self._invalid_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids[i] = np.array([]) output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids @@ -240,7 +250,6 @@ class ExecuteModelState(NamedTuple): hidden_states: torch.Tensor sample_hidden_states: torch.Tensor aux_hidden_states: list[torch.Tensor] | None - kv_connector_output: KVConnectorOutput | None ec_connector_output: ECConnectorOutput | None @@ -375,6 +384,10 @@ class GPUModelRunner( ) self.rejection_sampler = RejectionSampler(self.sampler) + self.num_spec_tokens = 0 + if self.speculative_config: + self.num_spec_tokens = self.speculative_config.num_speculative_tokens + # Request states. self.requests: dict[str, CachedRequestState] = {} self.comm_stream = torch.cuda.Stream() @@ -511,11 +524,7 @@ class GPUModelRunner( self.max_num_tokens, dtype=torch.int32, device=self.device ) - self.uniform_decode_query_len = ( - 1 - if not self.speculative_config - else 1 + self.speculative_config.num_speculative_tokens - ) + self.uniform_decode_query_len = 1 + self.num_spec_tokens # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) @@ -547,8 +556,23 @@ class GPUModelRunner( pin_memory=self.pin_memory, ) + # Pre-allocated tensor for copying valid sampled token counts to CPU, + # with dedicated stream for overlapping and event for coordination. 
+ self.valid_sampled_token_count_event: torch.cuda.Event | None = None + self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None + if self.use_async_scheduling and self.num_spec_tokens: + self.valid_sampled_token_count_event = torch.cuda.Event() + self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() + self.valid_sampled_token_count_cpu = torch.empty( + self.max_num_reqs, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + # Ephemeral state transferred between execute_model() and sample_tokens(). self.execute_model_state: ExecuteModelState | None = None + self.kv_connector_output: KVConnectorOutput | None = None def reset_mm_cache(self) -> None: if self.mm_budget: @@ -628,16 +652,6 @@ class GPUModelRunner( return if self.reorder_batch_threshold is not None: - # NOTE(lucas): currently no backend supports the custom masking - # required for DCP with q_len > 1, so we assert here. Remove this - # assert once the custom mask is support is added to FA3. - if ( - self.dcp_world_size > 1 - and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA" - ): - assert self.reorder_batch_threshold == 1, ( - "DCP not support reorder_batch_threshold > 1 now." - ) reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, @@ -744,17 +758,45 @@ class GPUModelRunner( # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs + + # Wait until valid_sampled_tokens_count is copied to cpu, + # then use it to update actual num_computed_tokens of each request. 
+ valid_sampled_token_count = self._get_valid_sampled_token_count() + for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_id in req_data.resumed_req_ids num_output_tokens = req_data.num_output_tokens[i] + req_index = self.input_batch.req_id_to_index.get(req_id) + + # prev_num_draft_len is used in async scheduling mode with + # spec decode. it indicates if need to update num_computed_tokens + # of the request. for example: + # fist step: num_computed_tokens = 0, spec_tokens = [], + # prev_num_draft_len = 0. + # second step: num_computed_tokens = 100(prompt lenth), + # spec_tokens = [a,b], prev_num_draft_len = 0. + # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], + # prev_num_draft_len = 2. + # num_computed_tokens in first step and second step does't contain + # the spec tokens length, but in third step it contains the + # spec tokens length. we only need to update num_computed_tokens + # when prev_num_draft_len > 0. + if req_state.prev_num_draft_len: + if req_index is None: + req_state.prev_num_draft_len = 0 + else: + assert self.input_batch.prev_req_id_to_index is not None + prev_req_index = self.input_batch.prev_req_id_to_index[req_id] + num_accepted = valid_sampled_token_count[prev_req_index] - 1 + num_rejected = req_state.prev_num_draft_len - num_accepted + num_computed_tokens -= num_rejected + req_state.output_token_ids.extend([-1] * num_accepted) # Update the cached states. 
- req_state.num_computed_tokens = num_computed_tokens - req_index = self.input_batch.req_id_to_index.get(req_id) if not is_last_rank: # When using PP, the scheduler sends the sampled tokens back, @@ -831,8 +873,11 @@ class GPUModelRunner( spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( req_id, [] ) - if spec_token_ids: - num_spec_tokens = len(spec_token_ids) + num_spec_tokens = len(spec_token_ids) + # For async scheduling, token_ids_cpu assigned from + # spec_token_ids are placeholders and will be overwritten in + # _prepare_input_ids. + if num_spec_tokens: start_index = self.input_batch.num_tokens_no_spec[req_index] end_token_index = start_index + num_spec_tokens self.input_batch.token_ids_cpu[ @@ -848,6 +893,15 @@ class GPUModelRunner( # even when speculative decoding is enabled. self.input_batch.spec_token_ids[req_index] = spec_token_ids + # there are no draft tokens with async scheduling, + # we clear the spec_decoding info in scheduler_output and + # use normal sampling but rejection_sampling. + if self.use_async_scheduling: + req_state.prev_num_draft_len = num_spec_tokens + if num_spec_tokens and self._draft_token_ids is None: + scheduler_output.total_num_scheduled_tokens -= num_spec_tokens + scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens + scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. for request in reqs_to_add: @@ -967,7 +1021,10 @@ class GPUModelRunner( return cu_num_tokens, arange def _prepare_input_ids( - self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray + self, + scheduler_output: "SchedulerOutput", + total_num_scheduled_tokens: int, + cu_num_tokens: np.ndarray, ) -> None: """Prepare the input IDs for the current batch. @@ -988,21 +1045,43 @@ class GPUModelRunner( # on the GPU from prev_sampled_token_ids. 
prev_req_id_to_index = self.input_batch.prev_req_id_to_index assert prev_req_id_to_index is not None - flattened_indices = [] - prev_common_req_indices = [] + sample_flattened_indices: list[int] = [] + spec_flattened_indices: list[int] = [] + prev_common_req_indices: list[int] = [] + prev_draft_token_indices: list[int] = [] indices_match = True max_flattened_index = -1 + total_num_spec_tokens = 0 + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens + for req_id, cur_index in self.input_batch.req_id_to_index.items(): if (prev_index := prev_req_id_to_index.get(req_id)) is not None: prev_common_req_indices.append(prev_index) # We need to compute the flattened input_ids index of the # last token in each common request. + draft_len = len(scheduled_spec_tokens.get(req_id, ())) + total_num_spec_tokens += draft_len flattened_index = cu_num_tokens[cur_index].item() - 1 - flattened_indices.append(flattened_index) + # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2] + # sample_flattened_indices = [0, 2, 5] + # spec_flattened_indices = [1, 3, 4, 6, 7] + sample_flattened_indices.append(flattened_index - draft_len) + spec_flattened_indices.extend( + range(flattened_index - draft_len + 1, flattened_index + 1) + ) + start = prev_index * self.num_spec_tokens + # prev_draft_token_indices is used to find which draft_tokens_id + # should be copied to input_ids + # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] + # flatten draft_tokens_id [1,2,3,4,5,6] + # draft_len of each request [1, 2, 1] + # then prev_draft_token_indices is [0, 2, 3, 4] + prev_draft_token_indices.extend(range(start, start + draft_len)) indices_match &= prev_index == flattened_index max_flattened_index = max(max_flattened_index, flattened_index) - num_commmon_tokens = len(flattened_indices) - if num_commmon_tokens < total_num_scheduled_tokens: + num_commmon_tokens = len(sample_flattened_indices) + total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens + if 
num_commmon_tokens < total_without_spec: # If not all requests are decodes from the last iteration, # We need to copy the input_ids_cpu to the GPU first. self.input_ids.copy_to_gpu(total_num_scheduled_tokens) @@ -1026,20 +1105,43 @@ class GPUModelRunner( self.is_token_ids.gpu[:num_commmon_tokens] = True return # Upload the index tensors asynchronously so the scatter can be non-blocking. - input_ids_index_tensor = torch.tensor( - flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + sampled_tokens_index_tensor = torch.tensor( + sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory ).to(self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory ).to(self.device, non_blocking=True) self.input_ids.gpu.scatter_( dim=0, - index=input_ids_index_tensor, + index=sampled_tokens_index_tensor, src=self.input_batch.prev_sampled_token_ids[ prev_common_req_indices_tensor, 0 ], ) + # Scatter the draft tokens after the sampled tokens are scattered. + if self._draft_token_ids is None or not spec_flattened_indices: + return + + assert isinstance(self._draft_token_ids, torch.Tensor) + draft_tokens_index_tensor = torch.tensor( + spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + prev_draft_token_indices_tensor = torch.tensor( + prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + + # because input_ids dtype is torch.int32, + # so convert draft_token_ids to torch.int32 here. 
+ draft_token_ids = self._draft_token_ids.to(dtype=torch.int32) + self._draft_token_ids = None + + self.input_ids.gpu.scatter_( + dim=0, + index=draft_tokens_index_tensor, + src=draft_token_ids.flatten()[prev_draft_token_indices_tensor], + ) + def _get_encoder_seq_lens( self, scheduled_encoder_inputs: dict[str, list[int]], @@ -1226,7 +1328,11 @@ class GPUModelRunner( self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) # Copy the tensors to the GPU. - self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + self._prepare_input_ids( + scheduler_output, + total_num_scheduled_tokens, + cu_num_tokens, + ) if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -1314,7 +1420,7 @@ class GPUModelRunner( :return: tuple[attn_metadata, spec_decode_common_attn_metadata] """ logits_indices_padded = None - num_logits_indices = 0 + num_logits_indices = None if logits_indices is not None: num_logits_indices = logits_indices.size(0) if self.cache_config.kv_sharing_fast_prefill: @@ -2031,7 +2137,7 @@ class GPUModelRunner( supported_tasks = list(model.pooler.get_supported_tasks()) - if self.scheduler_config.chunked_prefill_enabled: + if self.scheduler_config.enable_chunked_prefill: if "token_embed" in supported_tasks: supported_tasks.remove("token_embed") if "token_classify" in supported_tasks: @@ -2339,7 +2445,7 @@ class GPUModelRunner( ) -> tuple[ dict[str, int], LogprobsLists | None, - list[list[int]], + list[np.ndarray], dict[str, LogprobsTensors | None], list[str], dict[str, int], @@ -2365,6 +2471,7 @@ class GPUModelRunner( num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids invalid_req_indices = [] + valid_sampled_token_ids: list[np.ndarray] if not self.use_async_scheduling: # Get the valid generated tokens. max_gen_len = sampled_token_ids.shape[-1] @@ -2379,17 +2486,19 @@ class GPUModelRunner( ) # Mask out the sampled tokens that should not be sampled. 
for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)].clear() + valid_sampled_token_ids[int(i)] = np.array([]) else: valid_sampled_token_ids = [] invalid_req_indices = discard_sampled_tokens_req_indices.tolist() invalid_req_indices_set = set(invalid_req_indices) - assert sampled_token_ids.shape[-1] == 1 # Cache the sampled tokens on the GPU and avoid CPU sync. # These will be copied into input_ids in the next step # when preparing inputs. - self.input_batch.prev_sampled_token_ids = sampled_token_ids + # With spec decoding, this is done in propose_draft_token_ids(). + if self.input_batch.prev_sampled_token_ids is None: + assert sampled_token_ids.shape[-1] == 1 + self.input_batch.prev_sampled_token_ids = sampled_token_ids self.input_batch.prev_req_id_to_index = { req_id: i for i, req_id in enumerate(self.input_batch.req_ids) @@ -2407,19 +2516,24 @@ class GPUModelRunner( [0] if spec_decode_metadata and logprobs_tensors else None ) for req_idx in range(num_sampled_tokens): + sampled_ids: np.ndarray | None if self.use_async_scheduling: - sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None + sampled_ids = ( + np.array([-1]) if req_idx not in invalid_req_indices_set else None + ) else: sampled_ids = valid_sampled_token_ids[req_idx] - num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0 + num_sampled_ids: int = ( + sampled_ids.shape[0] if sampled_ids is not None else 0 + ) if cu_num_accepted_tokens is not None: cu_num_accepted_tokens.append( cu_num_accepted_tokens[-1] + num_sampled_ids ) - if not sampled_ids: + if sampled_ids is None or num_sampled_ids == 0: continue start_idx = self.input_batch.num_tokens_no_spec[req_idx] @@ -2519,6 +2633,21 @@ class GPUModelRunner( "State error: sample_tokens() must be called " "after execute_model() returns None." ) + + # self._draft_token_ids is None when `input_fits_in_drafter=False` + # and there is no draft tokens scheduled. 
so it need to update the + # spec_decoding info in scheduler_output with async_scheduling. + # use deepcopy to avoid the modification has influence on the + # scheduler_output in engine core process. + # TODO(Ronald1995): deepcopy is expensive when there is a large + # number of requests, optimize it later. + if ( + self.use_async_scheduling + and self.num_spec_tokens + and self._draft_token_ids is None + ): + scheduler_output = deepcopy(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with record_function_or_nullcontext("gpu_model_runner: preprocess"): with self.synchronize_input_prep(): @@ -2534,6 +2663,18 @@ class GPUModelRunner( return make_empty_encoder_model_runner_output(scheduler_output) if not num_scheduled_tokens: + if ( + self.parallel_config.distributed_executor_backend + == "external_launcher" + and self.parallel_config.data_parallel_size > 1 + ): + # this is a corner case when both external launcher + # and DP are enabled, num_scheduled_tokens could be + # 0, and has_unfinished_requests in the outer loop + # returns True. before returning early here we call + # dummy run to ensure coordinate_batch_across_dp + # is called into to avoid out of sync issues. + self._dummy_run(1) if not has_kv_transfer_group(): # Return empty ModelRunnerOutput if no work to do. 
return EMPTY_MODEL_RUNNER_OUTPUT @@ -2590,28 +2731,28 @@ class GPUModelRunner( ) ) - dp_rank = self.parallel_config.data_parallel_rank - if ubatch_slices: - assert num_tokens_across_dp is not None - num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) - self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) - elif num_tokens_across_dp is not None: - num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) - else: - num_input_tokens = self._get_num_input_tokens( - scheduler_output.total_num_scheduled_tokens - ) + dp_rank = self.parallel_config.data_parallel_rank + if ubatch_slices: + assert num_tokens_across_dp is not None + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) + elif num_tokens_across_dp is not None: + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + else: + num_input_tokens = self._get_num_input_tokens( + scheduler_output.total_num_scheduled_tokens + ) - ( - input_ids, - inputs_embeds, - positions, - intermediate_tensors, - model_kwargs, - ec_connector_output, - ) = self._preprocess( - scheduler_output, num_input_tokens, intermediate_tensors - ) + ( + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ec_connector_output, + ) = self._preprocess( + scheduler_output, num_input_tokens, intermediate_tensors + ) uniform_decode = ( max_num_scheduled_tokens == self.uniform_decode_query_len @@ -2674,6 +2815,7 @@ class GPUModelRunner( # Return the intermediate tensors. 
assert isinstance(hidden_states, IntermediateTensors) hidden_states.kv_connector_output = kv_connector_output + self.kv_connector_output = kv_connector_output return hidden_states if self.is_pooling_model: @@ -2724,18 +2866,31 @@ class GPUModelRunner( hidden_states, sample_hidden_states, aux_hidden_states, - kv_connector_output, ec_connector_output, ) + self.kv_connector_output = kv_connector_output return None @torch.inference_mode def sample_tokens( self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + if self.execute_model_state is None: # Nothing to do (PP non-final rank case), output isn't used. - return None # noqa + if not kv_connector_output: + return None # noqa + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output # Unpack ephemeral state. ( @@ -2746,7 +2901,6 @@ class GPUModelRunner( hidden_states, sample_hidden_states, aux_hidden_states, - kv_connector_output, ec_connector_output, ) = self.execute_model_state # Clear ephemeral state. 
@@ -2761,7 +2915,11 @@ class GPUModelRunner( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) - def propose_draft_token_ids(sampled_token_ids): + self.input_batch.prev_sampled_token_ids = None + + def propose_draft_token_ids( + sampled_token_ids: torch.Tensor | list[np.ndarray], + ) -> None: assert spec_decode_common_attn_metadata is not None with record_function_or_nullcontext("gpu_model_runner: draft"): self._draft_token_ids = self.propose_draft_token_ids( @@ -2792,14 +2950,29 @@ class GPUModelRunner( self.speculative_config.draft_model_config.max_model_len ) input_fits_in_drafter = spec_decode_common_attn_metadata and ( - spec_decode_common_attn_metadata.max_seq_len - + self.speculative_config.num_speculative_tokens + spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens <= effective_drafter_max_model_len ) - if use_padded_batch_for_eagle and input_fits_in_drafter: - # EAGLE speculative decoding can use the GPU sampled tokens - # as inputs, and does not need to wait for bookkeeping to finish. - propose_draft_token_ids(sampler_output.sampled_token_ids) + if use_padded_batch_for_eagle: + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. 
+ propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( @@ -2856,12 +3029,13 @@ class GPUModelRunner( logprobs_tensors=sampler_output.logprobs_tensors, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, ) with record_function_or_nullcontext( "gpu_model_runner: set_async_sampled_token_ids" ): # Save ref of sampled_token_ids CPU tensor if the batch contains - # any requests with sampling params that that require output ids. + # any requests with sampling params that require output ids. self.input_batch.set_async_sampled_token_ids( async_output.sampled_token_ids_cpu, async_output.async_copy_ready_event, @@ -2880,17 +3054,48 @@ class GPUModelRunner( self._draft_token_ids = None return DraftTokenIds(req_ids, draft_token_ids) + def _copy_valid_sampled_token_count( + self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor + ) -> None: + if self.valid_sampled_token_count_event is None: + return + + default_stream = torch.cuda.current_stream() + # Initialize a new stream to overlap the copy operation with + # prepare_input of draft model. 
+ with torch.cuda.stream(self.valid_sampled_token_count_copy_stream): + self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore + counts = valid_sampled_tokens_count + counts_cpu = self.valid_sampled_token_count_cpu + counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True) + self.valid_sampled_token_count_event.record() + + self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1) + + def _get_valid_sampled_token_count(self) -> list[int]: + # Wait until valid_sampled_tokens_count is copied to cpu, + prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids + if ( + self.valid_sampled_token_count_event is None + or prev_sampled_token_ids is None + ): + return [] + + counts_cpu = self.valid_sampled_token_count_cpu + self.valid_sampled_token_count_event.synchronize() + return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist() + def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: torch.Tensor | list[list[int]], + sampled_token_ids: torch.Tensor | list[np.ndarray], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, - ) -> list[list[int]] | torch.Tensor: + ) -> torch.Tensor | list[list[int]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(sampled_token_ids, list) @@ -2922,7 +3127,7 @@ class GPUModelRunner( for num_draft, tokens in zip( spec_decode_metadata.num_draft_tokens, sampled_token_ids ): - indices.append(offset + len(tokens) - 1) + indices.append(offset + tokens.shape[0] - 1) offset += num_draft + 1 indices = torch.tensor(indices, device=self.device) hidden_states = sample_hidden_states[indices] @@ -2967,6 +3172,9 @@ class GPUModelRunner( self.num_discarded_requests, ) ) + 
self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) if spec_decode_metadata is None: token_indices_to_sample = None @@ -3532,7 +3740,7 @@ class GPUModelRunner( # TODO(luka) better system for describing dummy batches seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] else: - seq_lens = max_query_len + seq_lens = max_query_len # type: ignore[assignment] self.seq_lens.np[:num_reqs] = seq_lens self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() @@ -3825,7 +4033,7 @@ class GPUModelRunner( supported_pooling_tasks = self.get_supported_pooling_tasks() if not supported_pooling_tasks: - if self.scheduler_config.chunked_prefill_enabled: + if self.scheduler_config.enable_chunked_prefill: raise RuntimeError( f"Model {self.model_config.model} does not support " "any pooling tasks with chunked prefill enabled. " @@ -4332,6 +4540,22 @@ class GPUModelRunner( "and make sure compilation mode is VLLM_COMPILE" ) + # if we have dedicated decode cudagraphs, and spec-decode is enabled, + # we need to adjust the cudagraph sizes to be a multiple of the uniform + # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207 + # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536 + # Will be removed in the near future when we have seperate cudagraph capture + # sizes for decode and mixed prefill-decode. + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + and self.uniform_decode_query_len > 1 + ): + self.compilation_config.adjust_cudagraph_sizes_for_spec_decode( + self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size + ) + self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes + # Trigger cudagraph dispatching keys initialization after # resolved cudagraph mode. 
self.cudagraph_dispatcher.initialize_cudagraph_keys( @@ -4469,11 +4693,7 @@ class GPUModelRunner( logitsprocs=self.input_batch.logitsprocs, logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, - num_speculative_tokens=( - self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config - else 0 - ), + num_speculative_tokens=self.num_spec_tokens, ) def _allocate_kv_cache_tensors( @@ -4862,7 +5082,7 @@ class GPUModelRunner( return kv_cache_spec - def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]: # This is a short term mitigation for issue mentioned in # https://github.com/vllm-project/vllm/issues/22754. # `tolist` would trigger a cuda wise stream sync, which @@ -4875,4 +5095,4 @@ class GPUModelRunner( pinned.copy_(sampled_token_ids, non_blocking=True) self.transfer_event.record() self.transfer_event.synchronize() - return pinned.tolist() + return [row for row in pinned.numpy()] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 2b9d8bb2f25e6..315f01b68499a 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A GPU worker class.""" -import copy import gc import os from contextlib import AbstractContextManager, nullcontext @@ -45,7 +44,6 @@ from vllm.v1.core.sched.output import GrammarOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ( - EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput, @@ -189,6 +187,7 @@ class Worker(WorkerBase): and self.parallel_config.distributed_executor_backend not in ["ray", "external_launcher"] and 
self.vllm_config.parallel_config.data_parallel_backend != "ray" + and self.vllm_config.parallel_config.nnodes_within_dp == 1 ): # Use local DP rank if available, otherwise use global DP rank. dp_local_rank = self.parallel_config.data_parallel_rank_local @@ -205,7 +204,14 @@ class Worker(WorkerBase): assert self.local_rank < torch.cuda.device_count(), ( f"DP adjusted local rank {self.local_rank} is out of bounds. " ) - + visible_device_count = ( + torch.cuda.device_count() if torch.cuda.is_available() else 0 + ) + assert self.parallel_config.local_world_size <= visible_device_count, ( + f"local_world_size ({self.parallel_config.local_world_size}) must be " + f"less than or equal to the number of visible devices " + f"({visible_device_count})." + ) self.device = torch.device(f"cuda:{self.local_rank}") current_platform.set_device(self.device) @@ -573,18 +579,7 @@ class Worker(WorkerBase): all_gather_tensors=all_gather_tensors, ) - kv_connector_output = output.kv_connector_output - if not kv_connector_output: - return None - - # In case of PP with kv transfer, we need to pass through the - # kv_connector_output - if kv_connector_output.is_empty(): - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.kv_connector_output = kv_connector_output - return output + return None def take_draft_token_ids(self) -> DraftTokenIds | None: return self.model_runner.take_draft_token_ids() @@ -596,14 +591,19 @@ class Worker(WorkerBase): self.profiler.start() else: self.profiler.stop() - # only print profiler results on rank 0 - if ( - isinstance(self.profiler, torch.profiler.profile) - and self.local_rank == 0 - ): - print( - self.profiler.key_averages().table(sort_by="self_cuda_time_total") - ) + if isinstance(self.profiler, torch.profiler.profile): + rank = self.local_rank + profiler_dir = envs.VLLM_TORCH_PROFILER_DIR + profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt" + sort_key = "self_cuda_time_total" + table = 
self.profiler.key_averages().table(sort_by=sort_key) + + with open(profiler_out_file, "w") as f: + print(table, file=f) + + # only print profiler results on rank 0 + if rank == 0: + print(table) def execute_dummy_batch(self) -> None: self.model_runner._dummy_run(1, uniform_decode=True) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 0f90578671db5..e9eb7cad38f88 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -21,7 +21,7 @@ from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import MLAAttention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper from vllm.config import ( ParallelConfig, VllmConfig, @@ -1254,13 +1254,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_gen_len = selected_token_ids.shape[-1] if max_gen_len == 1: - valid_sampled_token_ids = selected_token_ids.tolist() + valid_sampled_token_ids: list[np.ndarray] = [ + row for row in selected_token_ids.numpy() + ] # Mask out the sampled tokens that should not be sampled. 
# TODO: Keep in sync with gpu_model_runner.py, in particular # the "else" case here for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids[i] = np.array([]) # Append sampled tokens for i, req_state, seq_len in request_seq_lens: @@ -1273,7 +1275,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): valid_mask = selected_token_ids != INVALID_TOKEN_ID gen_lens = valid_mask.sum(dim=1).tolist() valid_sampled_token_ids = [ - seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens) + seq.numpy() for seq in selected_token_ids[valid_mask].split(gen_lens) ] self.input_batch.num_tokens[:num_reqs] += gen_lens for i, req_state, seq_len in request_seq_lens: @@ -1895,12 +1897,14 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): compiled_model = self.model.get_language_model().model else: compiled_model = self.model.model - if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher): + if isinstance(compiled_model, TorchCompileWithNoGuardsWrapper): logger.info("Clear dynamo cache and cached dynamo bytecode.") torch._dynamo.eval_frame.remove_from_cache( - compiled_model.original_code_object + compiled_model.original_code_object() ) - compiled_model.compiled_codes.clear() + # Reset the wrapper to re-initialize. 
+ compiled_model.compiled = False + TorchCompileWithNoGuardsWrapper.__init__(compiled_model) @torch.compile(backend="openxla", fullgraph=True, dynamic=False) def select_hidden_states(self, hidden_states, indices_do_sample): diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 30ea0ab77bd9e..16f321c080779 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -13,7 +13,6 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import worker_receiver_cache_from_config -from vllm.utils import warn_for_unimplemented_methods from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.system_utils import update_environment_variables from vllm.v1.kv_cache_interface import KVCacheSpec @@ -33,7 +32,6 @@ logger = init_logger(__name__) _R = TypeVar("_R") -@warn_for_unimplemented_methods class WorkerBase: """Worker interface that allows vLLM to cleanly separate implementations for different hardware. Also abstracts control plane communication, e.g., to @@ -182,6 +180,7 @@ class WorkerWrapperBase: self, vllm_config: VllmConfig, rpc_rank: int = 0, + global_rank: int | None = None, ) -> None: """ Initialize the worker wrapper with the given vllm_config and rpc_rank. @@ -194,6 +193,7 @@ class WorkerWrapperBase: group. 
""" self.rpc_rank = rpc_rank + self.global_rank = self.rpc_rank if global_rank is None else global_rank self.worker: WorkerBase | None = None # do not store this `vllm_config`, `init_worker` will set the final @@ -314,7 +314,7 @@ class WorkerWrapperBase: assert self.worker is not None def initialize_from_config(self, kv_cache_configs: list[Any]) -> None: - kv_cache_config = kv_cache_configs[self.rpc_rank] + kv_cache_config = kv_cache_configs[self.global_rank] with set_current_vllm_config(self.vllm_config): self.worker.initialize_from_config(kv_cache_config) # type: ignore diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 31fa3f3bd6acc..26c6f8d06bdcd 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -159,12 +159,10 @@ class XPUWorker(Worker): else: raise RuntimeError(f"Not support device type: {self.device_config.device}") - ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "pidfd") ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi") ENV_LOCAL_WORLD_SIZE = os.getenv( "LOCAL_WORLD_SIZE", str(self.parallel_config.world_size) ) - os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE os.environ["LOCAL_RANK"] = str(self.local_rank)