diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml
index 597dfbf99022d..a1de41652c9a6 100644
--- a/.buildkite/release-pipeline.yaml
+++ b/.buildkite/release-pipeline.yaml
@@ -149,3 +149,25 @@ steps:
       - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
     env:
       DOCKER_BUILDKIT: "1"
+
+  - label: "Build and publish nightly multi-arch image to DockerHub"
+    depends_on:
+      - create-multi-arch-manifest
+    if: build.env("NIGHTLY") == "1"
+    agents:
+      queue: cpu_queue_postmerge
+    commands:
+      - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
+      - "docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
+      - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly"
+      - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
+      - "docker push vllm/vllm-openai:nightly"
+      - "docker push vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
+      # Clean up old nightly builds (keep only last 14)
+      - "bash .buildkite/scripts/cleanup-nightly-builds.sh"
+    plugins:
+      - docker-login#v3.0.0:
+          username: vllmbot
+          password-env: DOCKERHUB_TOKEN
+    env:
+      DOCKER_BUILDKIT: "1"
diff --git a/.buildkite/scripts/cleanup-nightly-builds.sh b/.buildkite/scripts/cleanup-nightly-builds.sh
new file mode 100755
index 0000000000000..1a82f7d085233
--- /dev/null
+++ b/.buildkite/scripts/cleanup-nightly-builds.sh
@@ -0,0 +1,97 @@
+#!/bin/bash
+
+set -ex
+
+# Clean up old nightly builds from DockerHub, keeping only the last 14 builds
+# This script uses DockerHub API to list and delete old tags with "nightly-" prefix
+
+# DockerHub API endpoint for vllm/vllm-openai repository
+REPO_API_URL="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags"
+
+# Get DockerHub token from environment
+if [ -z "$DOCKERHUB_TOKEN" ]; then
+    echo "Error: DOCKERHUB_TOKEN environment variable is not set"
+    exit 1
+fi
+
+# Function to get all tags from DockerHub
+get_all_tags() {
+    local page=1
+    local all_tags=""
+
+    while true; do
+        local response=$(curl -s -H "Authorization: Bearer $DOCKERHUB_TOKEN" \
+            "$REPO_API_URL?page=$page&page_size=100")
+
+        # Get both last_updated timestamp and tag name, separated by |
+        local tags=$(echo "$response" | jq -r '.results[] | select(.name | startswith("nightly-")) | "\(.last_updated)|\(.name)"')
+
+        if [ -z "$tags" ]; then
+            break
+        fi
+
+        all_tags="$all_tags$tags"$'\n'
+        page=$((page + 1))
+    done
+
+    # Sort by timestamp (newest first) and extract just the tag names
+    echo "$all_tags" | sort -r | cut -d'|' -f2
+}
+
+delete_tag() {
+    local tag_name="$1"
+    echo "Deleting tag: $tag_name"
+
+    local delete_url="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags/$tag_name"
+    local response=$(curl -s -X DELETE -H "Authorization: Bearer $DOCKERHUB_TOKEN" "$delete_url")
+
+    if echo "$response" | jq -e '.detail' > /dev/null 2>&1; then
+        echo "Warning: Failed to delete tag $tag_name: $(echo "$response" | jq -r '.detail')"
+    else
+        echo "Successfully deleted tag: $tag_name"
+    fi
+}
+
+# Get all nightly- prefixed tags, sorted by last_updated timestamp (newest first)
+echo "Fetching all tags from DockerHub..."
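+# NOTE: get_all_tags() above emits "last_updated|name" pairs; last_updated is
+# an ISO-8601 timestamp, so the lexicographic `sort -r` inside the function is
+# enough to order tags newest-first before the names are cut out.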
+all_tags=$(get_all_tags) + +if [ -z "$all_tags" ]; then + echo "No tags found to clean up" + exit 0 +fi + +# Count total tags +total_tags=$(echo "$all_tags" | wc -l) +echo "Found $total_tags tags" + +# Keep only the last 14 builds (including the current one) +tags_to_keep=14 +tags_to_delete=$((total_tags - tags_to_keep)) + +if [ $tags_to_delete -le 0 ]; then + echo "No tags need to be deleted (only $total_tags tags found, keeping $tags_to_keep)" + exit 0 +fi + +echo "Will delete $tags_to_delete old tags, keeping the newest $tags_to_keep" + +# Get tags to delete (skip the first $tags_to_keep tags) +tags_to_delete_list=$(echo "$all_tags" | tail -n +$((tags_to_keep + 1))) + +if [ -z "$tags_to_delete_list" ]; then + echo "No tags to delete" + exit 0 +fi + +# Delete old tags +echo "Deleting old tags..." +while IFS= read -r tag; do + if [ -n "$tag" ]; then + delete_tag "$tag" + # Add a small delay to avoid rate limiting + sleep 1 + fi +done <<< "$tags_to_delete_list" + +echo "Cleanup completed successfully" diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index efcd10acf0b93..8c9b00990e995 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -30,6 +30,7 @@ docker run \ bash -c ' set -e echo $ZE_AFFINITY_MASK + pip install tblib==3.1.0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b0d4c4456d339..adb5c862eecd9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -54,6 +54,7 @@ steps: - tests/utils_ - tests/worker - tests/standalone_tests/lazy_imports.py + - tests/transformers_utils commands: - python3 standalone_tests/lazy_imports.py - pytest -v -s mq_llm_engine # MQLLMEngine @@ -63,6 +64,7 @@ steps: - pytest -v -s multimodal - pytest -v -s utils_ # Utils - pytest -v -s worker # Worker + - pytest -v -s transformers_utils # transformers_utils - label: Python-only Installation Test # 10min timeout_in_minutes: 20 @@ -102,7 +104,18 @@ steps: commands: - pytest -v -s core -- label: Entrypoints Test (LLM) # 30min +- label: Entrypoints Unit Tests # 5min + timeout_in_minutes: 10 + working_dir: "/vllm-workspace/tests" + fast_check: true + source_file_dependencies: + - vllm/entrypoints + - tests/entrypoints/ + commands: + - pytest -v -s entrypoints/openai/tool_parsers + - pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling + +- label: Entrypoints Integration Test (LLM) # 30min timeout_in_minutes: 40 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" @@ -119,7 +132,7 @@ steps: - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests -- label: Entrypoints Test (API Server) # 100min +- label: Entrypoints Integration Test (API Server) # 100min timeout_in_minutes: 130 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" @@ -132,9 +145,22 @@ steps: commands: - 
export VLLM_WORKER_MULTIPROC_METHOD=spawn - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/ - pytest -v -s entrypoints/test_chat_utils.py +- label: Entrypoints Integration Test (Pooling) + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + working_dir: "/vllm-workspace/tests" + fast_check: true + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/entrypoints/pooling + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s entrypoints/pooling + - label: Distributed Tests (4 GPUs) # 35min timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] @@ -205,7 +231,7 @@ steps: source_file_dependencies: - vllm/ - tests/metrics - - tests/tracing + - tests/v1/tracing commands: - pytest -v -s metrics - "pip install \ @@ -310,7 +336,6 @@ steps: - python3 offline_inference/vision_language_pooling.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0 - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - - python3 offline_inference/encoder_decoder.py - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py @@ -379,11 +404,7 @@ steps: - tests/compile commands: - pytest -v -s compile/test_basic_correctness.py - # these tests need to be separated, cannot combine - - pytest -v -s compile/piecewise/test_simple.py - - pytest -v -s compile/piecewise/test_toy_llama.py - - pytest -v -s compile/piecewise/test_full_cudagraph.py - - pytest -v -s compile/piecewise/test_multiple_graphs.py + - pytest -v -s compile/piecewise/ - label: PyTorch Fullgraph Test # 20min timeout_in_minutes: 30 @@ -501,6 +522,10 @@ steps: commands: # temporary install here since we need nightly, will move to requirements/test.in # after torchao 0.12 release, and pin a working version of torchao nightly here + + # since torchao nightly is only compatible with torch nightly currently + # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now + # we can only upgrade after this is resolved - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization @@ -546,36 +571,85 @@ steps: ##### models test ##### -- label: Basic Models Test # 57min - timeout_in_minutes: 75 +- label: Basic Models Tests (Initialization) + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ - - tests/models + - tests/models/test_initialization.py 
commands: - - pytest -v -s models/test_transformers.py - - pytest -v -s models/test_registry.py - - pytest -v -s models/test_utils.py - - pytest -v -s models/test_vision.py - - pytest -v -s models/test_initialization.py + # Run a subset of model initialization tests + - pytest -v -s models/test_initialization.py::test_can_initialize_small_subset -- label: Language Models Test (Standard) # 35min +- label: Basic Models Tests (Extra Initialization) %N timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/test_initialization.py + commands: + # Only when vLLM model source is modified - test initialization of a large + # subset of supported models (the complement of the small subset in the above + # test.) Also run if model initialization test file is modified + - pytest -v -s models/test_initialization.py \ + -k 'not test_can_initialize_small_subset' \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Basic Models Tests (Other) + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/test_transformers.py + - tests/models/test_registry.py + - tests/models/test_utils.py + - tests/models/test_vision.py + commands: + - pytest -v -s models/test_transformers.py \ + models/test_registry.py \ + models/test_utils.py \ + models/test_vision.py + +- label: Language Models Tests (Standard) + timeout_in_minutes: 25 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: - vllm/ - tests/models/language commands: + # Test standard language models, excluding a subset of slow tests - pip freeze | grep -E 'torch' - - pytest -v -s models/language -m core_model + - pytest -v -s models/language -m 'core_model and (not slow_test)' -- label: Language Models Test (Hybrid) # 35 min +- label: Language Models Tests (Extra Standard) %N timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/language/pooling/test_embedding.py + - tests/models/language/generation/test_common.py + - tests/models/language/pooling/test_classification.py + commands: + # Shard slow subset of standard language models tests. 
Only run when model + # source is modified, or when specified test files are modified + - pip freeze | grep -E 'torch' + - pytest -v -s models/language -m 'core_model and slow_test' \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Language Models Tests (Hybrid) %N + timeout_in_minutes: 75 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: - vllm/ - tests/models/language/generation commands: @@ -583,7 +657,12 @@ steps: # Note: also needed to run plamo2 model in vLLM - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5' - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' - - pytest -v -s models/language/generation -m hybrid_model + # Shard hybrid language model tests + - pytest -v -s models/language/generation \ + -m hybrid_model \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 - label: Language Models Test (Extended Generation) # 80min timeout_in_minutes: 110 @@ -597,6 +676,16 @@ steps: - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' +- label: Language Models Test (PPL) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation_ppl_test + commands: + - pytest -v -s models/language/generation_ppl_test + - label: Language Models Test (Extended Pooling) # 36min timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] @@ -607,6 +696,16 @@ steps: commands: - pytest -v -s models/language/pooling -m 'not core_model' +- label: Language Models Test (MTEB) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/pooling_mteb_test + commands: + - pytest -v -s models/language/pooling_mteb_test + - label: Multi-Modal Processor Test # 44min timeout_in_minutes: 60 source_file_dependencies: @@ -627,7 +726,7 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip freeze | grep -E 'torch' - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work + - cd .. 
&& VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 mirror_hardwares: [amdexperimental] @@ -713,7 +812,8 @@ steps: # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - - pytest -v -s tests/kernels/test_cutlass_mla_decode.py + - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py + - pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py # Quantization - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py @@ -743,6 +843,8 @@ steps: commands: - pytest -v -s distributed/test_comm_ops.py - pytest -v -s distributed/test_shm_broadcast.py + - pytest -v -s distributed/test_shm_buffer.py + - pytest -v -s distributed/test_shm_storage.py - label: 2 Node Tests (4 GPUs in total) # 16min timeout_in_minutes: 30 @@ -801,7 +903,8 @@ steps: # Avoid importing model tests that cause CUDA reinitialization error - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' - pytest models/language -v -s -m 'distributed(num_gpus=2)' - - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' + - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py + - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)' # test sequence parallel - pytest -v -s distributed/test_sequence_parallel.py # this test fails consistently. @@ -827,7 +930,7 @@ steps: # begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin - pip install -e ./plugins/prithvi_io_processor_plugin - pytest -v -s plugins_tests/test_io_processor_plugins.py - - pip uninstall prithvi_io_processor_plugin -y + - pip uninstall prithvi_io_processor_plugin -y # end io_processor plugins test # other tests continue here: - pytest -v -s plugins_tests/test_scheduler_plugins.py @@ -875,7 +978,7 @@ steps: timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" - num_gpus: 2 + num_gpus: 2 optional: true source_file_dependencies: - vllm/ diff --git a/.github/.bc-linter.yml b/.github/.bc-linter.yml new file mode 100644 index 0000000000000..443dfa45af22c --- /dev/null +++ b/.github/.bc-linter.yml @@ -0,0 +1,24 @@ +# doc: https://github.com/pytorch/test-infra/blob/main/tools/stronghold/docs/bc_linter_config.md +version: 1 +paths: +# We temporarily disable globally, and will only enable with `annotations.include` +# include: +# - "vllm/v1/attetion/*.py" +# - "vllm/v1/core/*.py" +exclude: + - "**/*.py" + +scan: + functions: true # check free functions and methods + classes: true # check classes/dataclasses + public_only: true # ignore names starting with "_" at any level + +annotations: + include: # decorators that force‑include a symbol + - name: "bc_linter_include" # matched by simple name or dotted suffix + propagate_to_members: false # for classes, include methods/inner classes + exclude: # decorators that force‑exclude a symbol + - name: "bc_linter_skip" # matched by simple name or dotted suffix + propagate_to_members: true # for classes, exclude methods/inner classes + +excluded_violations: [] # e.g. 
["ParameterRenamed", "FieldTypeChanged"] diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 682b27ac8986e..846b68054c0a1 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -8,17 +8,18 @@ /vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn /vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn /vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 /vllm/model_executor/layers/mamba @tdoublep /vllm/model_executor/model_loader @22quinn -/vllm/multimodal @DarkLight1337 @ywang96 +/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche /vllm/v1/sample @22quinn @houseroad /vllm/vllm_flash_attn @LucasWilkinson /vllm/lora @jeejeelee -/vllm/reasoning @aarnphm -/vllm/entrypoints @aarnphm +/vllm/reasoning @aarnphm @chaunceyjiang +/vllm/entrypoints @aarnphm @chaunceyjiang /vllm/compilation @zou3519 @youkaichao @ProExpertProg +/vllm/distributed/kv_transfer @NickLucche CMakeLists.txt @tlrmchlsmth @LucasWilkinson # Any change to the VllmConfig changes can have a large user-facing impact, @@ -28,8 +29,10 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat /vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett -/vllm/v1/spec_decode @benchislett +/vllm/v1/spec_decode @benchislett @luccafong /vllm/v1/attention/backends/triton_attn.py @tdoublep +/vllm/v1/core @heheda12345 +/vllm/v1/kv_cache_interface.py @heheda12345 # Test ownership /.buildkite/lm-eval-harness @mgoin @simon-mo @@ -37,18 +40,20 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/distributed/test_multi_node_assignment.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_same_node.py @youkaichao -/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm +/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm @NickLucche /tests/kernels @tlrmchlsmth @WoosukKwon @yewentao256 /tests/models @DarkLight1337 @ywang96 -/tests/multimodal @DarkLight1337 @ywang96 +/tests/multimodal @DarkLight1337 @ywang96 @NickLucche /tests/prefix_caching @comaniac @KuntaiDu /tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 /tests/test_inputs.py @DarkLight1337 @ywang96 /tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm /tests/v1/structured_output @mgoin @russellb @aarnphm +/tests/v1/core @heheda12345 /tests/weight_loading @mgoin @youkaichao @yewentao256 /tests/lora @jeejeelee /tests/models/language/generation/test_hybrid.py @tdoublep +/tests/v1/kv_connector/nixl_integration @NickLucche # Docs /docs @hmellor @@ -70,6 +75,9 @@ mkdocs.yaml @hmellor /vllm/attention/backends/dual_chunk_flash_attn.py @sighingnow /vllm/model_executor/models/qwen* @sighingnow +# MTP-specific files +/vllm/model_executor/models/deepseek_mtp.py @luccafong + # Mistral-specific files /vllm/model_executor/models/mistral*.py @patrickvonplaten /vllm/model_executor/models/mixtral*.py @patrickvonplaten @@ -88,3 +96,9 @@ mkdocs.yaml @hmellor /vllm/v1/attention/backends/mla/rocm*.py @gshtras /vllm/attention/ops/rocm*.py @gshtras /vllm/model_executor/layers/fused_moe/rocm*.py 
@gshtras + +# TPU +/vllm/v1/worker/tpu* @NickLucche +/vllm/platforms/tpu.py @NickLucche +/vllm/v1/sample/tpu @NickLucche +/vllm/tests/v1/tpu @NickLucche \ No newline at end of file diff --git a/.github/mergify.yml b/.github/mergify.yml index 495d207d44260..f2dd2e06214ae 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -124,9 +124,16 @@ pull_request_rules: - or: - files~=^examples/.*gpt[-_]?oss.*\.py - files~=^tests/.*gpt[-_]?oss.*\.py + - files~=^tests/entrypoints/openai/test_response_api_with_harmony.py + - files~=^tests/entrypoints/test_context.py - files~=^vllm/model_executor/models/.*gpt[-_]?oss.*\.py - files~=^vllm/model_executor/layers/.*gpt[-_]?oss.*\.py + - files~=^vllm/entrypoints/harmony_utils.py + - files~=^vllm/entrypoints/tool_server.py + - files~=^vllm/entrypoints/tool.py + - files~=^vllm/entrypoints/context.py - title~=(?i)gpt[-_]?oss + - title~=(?i)harmony actions: label: add: @@ -273,6 +280,20 @@ pull_request_rules: users: - "sangstar" +- name: assign reviewer for modelopt changes + conditions: + - or: + - files~=^vllm/model_executor/layers/quantization/modelopt\.py$ + - files~=^vllm/model_executor/layers/quantization/__init__\.py$ + - files~=^tests/models/quantization/test_modelopt\.py$ + - files~=^tests/quantization/test_modelopt\.py$ + - files~=^tests/models/quantization/test_nvfp4\.py$ + - files~=^docs/features/quantization/modelopt\.md$ + actions: + assign: + users: + - "Edwardf0t1" + - name: remove 'needs-rebase' label when conflict is resolved conditions: - -conflict diff --git a/.github/workflows/add_label_automerge.yml b/.github/workflows/add_label_automerge.yml index 315042fbf5cf4..d8bbedef3174b 100644 --- a/.github/workflows/add_label_automerge.yml +++ b/.github/workflows/add_label_automerge.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Add label - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | github.rest.issues.addLabels({ diff --git a/.github/workflows/bc-lint.yml b/.github/workflows/bc-lint.yml new file mode 100644 index 0000000000000..823695a921321 --- /dev/null +++ b/.github/workflows/bc-lint.yml @@ -0,0 +1,29 @@ +name: BC Lint + +on: + pull_request: + types: + - opened + - synchronize + - reopened + - labeled + - unlabeled + +jobs: + bc_lint: + if: github.repository_owner == 'vllm-project' + runs-on: ubuntu-latest + steps: + - name: Run BC Lint Action + uses: pytorch/test-infra/.github/actions/bc-lint@main + with: + repo: ${{ github.event.pull_request.head.repo.full_name }} + base_sha: ${{ github.event.pull_request.base.sha }} + head_sha: ${{ github.event.pull_request.head.sha }} + suppression: ${{ contains(github.event.pull_request.labels.*.name, 'suppress-bc-linter') }} + docs_link: 'https://github.com/pytorch/test-infra/wiki/BC-Linter' + config_dir: .github + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml index d5c6b8d43a6ef..c3e132a536a42 100644 --- a/.github/workflows/cleanup_pr_body.yml +++ b/.github/workflows/cleanup_pr_body.yml @@ -16,7 +16,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: 
'3.12' diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml index e0ab3872d8fa3..c2b17abe811cd 100644 --- a/.github/workflows/issue_autolabel.yml +++ b/.github/workflows/issue_autolabel.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Label issues based on keywords - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | // Configuration: Add new labels and keywords here diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 195579f206a2f..e21d13b8161f3 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.12" - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json" diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index 1ee605dc7bb0d..8884359fa0ce4 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Remind to run full CI on PR - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | try { diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 656f3d3fa7bc4..82844810a633a 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: actions: write runs-on: ubuntu-latest steps: - - uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0 + - uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0 with: # Increasing this value ensures that changes to this workflow # propagate to all issues and PRs in days rather than months diff --git a/.gitignore b/.gitignore index 465935d488f84..b1df673e83ca8 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ # vllm-flash-attn built from source vllm/vllm_flash_attn/* -# triton jit +# triton jit .triton # Byte-compiled / optimized / DLL files @@ -177,6 +177,14 @@ cython_debug/ # VSCode .vscode/ +# Claude +CLAUDE.md +.claude/ + +# Codex +AGENTS.md +.codex/ + # DS Store .DS_Store @@ -209,4 +217,4 @@ shellcheck*/ csrc/moe/marlin_moe_wna16/kernel_* # Ignore ep_kernels_workspace folder -ep_kernels_workspace/ \ No newline at end of file +ep_kernels_workspace/ diff --git a/.yapfignore b/.yapfignore index 2d6dcf8380cac..38158259032a6 100644 --- a/.yapfignore +++ b/.yapfignore @@ -1 +1,2 @@ collect_env.py +vllm/model_executor/layers/fla/ops/*.py diff --git a/README.md b/README.md index 4e03df758c261..0c6e5aa6b31d2 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,9 @@ Easy, fast, and cheap LLM serving for everyone | Documentation | Blog | Paper | Twitter/X | User Forum | Developer Slack |

+--- +Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) and [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco for our latest updates on vLLM and to meet the vLLM team! Register now for the largest vLLM community events of the year! + --- *Latest News* 🔥 @@ -78,7 +81,7 @@ vLLM is flexible and easy to use with: - Tensor, pipeline, data and expert parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron +- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend. - Prefix caching support - Multi-LoRA support diff --git a/benchmarks/README.md b/benchmarks/README.md index 98b3600d13635..ee172642033de 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -95,6 +95,24 @@ become available. ✅ lmms-lab/LLaVA-OneVision-Data, Aeala/ShareGPT_Vicuna_unfiltered + + HuggingFace-MTBench + ✅ + ✅ + philschmid/mt-bench + + + HuggingFace-Blazedit + ✅ + ✅ + vdaita/edit_5k_char, vdaita/edit_10k_char + + + Spec Bench + ✅ + ✅ + wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl + Custom ✅ @@ -239,6 +257,43 @@ vllm bench serve \ --num-prompts 2048 ``` +### Spec Bench Benchmark with Speculative Decoding + +``` bash +VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ + --speculative-config $'{"method": "ngram", + "num_speculative_tokens": 5, "prompt_lookup_max": 5, + "prompt_lookup_min": 2}' +``` + +[SpecBench dataset](https://github.com/hemingkx/Spec-Bench) + +Run all categories: + +``` bash +# Download the dataset using: +# wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl + +vllm bench serve \ + --model meta-llama/Meta-Llama-3-8B-Instruct \ + --dataset-name spec_bench \ + --dataset-path "/data/spec_bench/question.jsonl" \ + --num-prompts -1 +``` + +Available categories include `[writing, roleplay, reasoning, math, coding, extraction, stem, humanities, translation, summarization, qa, math_reasoning, rag]`. 
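+
+As a quick sanity check that the download worked, you can count the questions
+per category (a minimal sketch: it assumes each JSONL record carries a
+`category` field, as in the MT-Bench question format):
+
+``` bash
+jq -r '.category' /data/spec_bench/question.jsonl | sort | uniq -c
+```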
+
+Run only a specific category like "summarization":
+
+``` bash
+vllm bench serve \
+    --model meta-llama/Meta-Llama-3-8B-Instruct \
+    --dataset-name spec_bench \
+    --dataset-path "/data/spec_bench/question.jsonl" \
+    --num-prompts -1 \
+    --spec-bench-category "summarization"
+```
+
 ### Other HuggingFaceDataset Examples
 
 ```bash
 vllm bench serve \
@@ -295,6 +350,18 @@ vllm bench serve \
     --num-prompts 80
 ```
 
+`vdaita/edit_5k_char` or `vdaita/edit_10k_char`:
+
+``` bash
+vllm bench serve \
+    --model Qwen/QwQ-32B \
+    --dataset-name hf \
+    --dataset-path vdaita/edit_5k_char \
+    --num-prompts 90 \
+    --blazedit-min-distance 0.01 \
+    --blazedit-max-distance 0.99
+```
+
 ### Running With Sampling Parameters
 
 When using OpenAI-compatible backends such as `vllm`, optional sampling
@@ -694,7 +761,7 @@ python -m vllm.entrypoints.openai.api_server \
 Send requests with images:
 
 ```bash
-python benchmarks/benchmark_serving.py \
+vllm bench serve \
     --backend openai-chat \
     --model Qwen/Qwen2.5-VL-7B-Instruct \
     --dataset-name sharegpt \
@@ -721,7 +788,7 @@ python -m vllm.entrypoints.openai.api_server \
 Send requests with videos:
 
 ```bash
-python benchmarks/benchmark_serving.py \
+vllm bench serve \
     --backend openai-chat \
     --model Qwen/Qwen2.5-VL-7B-Instruct \
     --dataset-name sharegpt \
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index d8b960edaa468..a7892f3f71243 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -1,191 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Benchmark the latency of processing a single batch of requests."""
-
-import argparse
-import dataclasses
-import json
-import os
-import time
-from typing import Any, Optional
-
-import numpy as np
-from tqdm import tqdm
-from typing_extensions import deprecated
-
-import vllm.envs as envs
-from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
-from vllm import LLM, SamplingParams
-from vllm.engine.arg_utils import EngineArgs
-from vllm.inputs import PromptType
-from vllm.sampling_params import BeamSearchParams
-from vllm.utils import FlexibleArgumentParser
-
-
-def save_to_pytorch_benchmark_format(
-    args: argparse.Namespace, results: dict[str, Any]
-) -> None:
-    pt_records = convert_to_pytorch_benchmark_format(
-        args=args,
-        metrics={"latency": results["latencies"]},
-        extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
-    )
-    if pt_records:
-        pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
-        write_to_json(pt_file, pt_records)
-
-
-@deprecated(
-    "benchmark_latency.py is deprecated and will be removed in a "
-    "future version. Please use 'vllm bench latency' instead.",
-)
-def main(args: argparse.Namespace):
-    print(args)
-
-    engine_args = EngineArgs.from_cli_args(args)
-
-    # NOTE(woosuk): If the request cannot be processed in a single batch,
-    # the engine will automatically process the request in multiple batches.
-    llm = LLM(**dataclasses.asdict(engine_args))
-    assert llm.llm_engine.model_config.max_model_len >= (
-        args.input_len + args.output_len
-    ), (
-        "Please ensure that max_model_len is greater than"
-        " the sum of input_len and output_len."
- ) - - sampling_params = SamplingParams( - n=args.n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=args.output_len, - detokenize=not args.disable_detokenize, - ) - print(sampling_params) - dummy_prompt_token_ids = np.random.randint( - 10000, size=(args.batch_size, args.input_len) - ) - dummy_prompts: list[PromptType] = [ - {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist() - ] - - def llm_generate(): - if not args.use_beam_search: - llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) - else: - llm.beam_search( - dummy_prompts, - BeamSearchParams( - beam_width=args.n, - max_tokens=args.output_len, - ignore_eos=True, - ), - ) - - def run_to_completion(profile_dir: Optional[str] = None): - if profile_dir: - llm.start_profile() - llm_generate() - llm.stop_profile() - else: - start_time = time.perf_counter() - llm_generate() - end_time = time.perf_counter() - latency = end_time - start_time - return latency - - print("Warming up...") - for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): - run_to_completion(profile_dir=None) - - if args.profile: - profile_dir = envs.VLLM_TORCH_PROFILER_DIR - print(f"Profiling (results will be saved to '{profile_dir}')...") - run_to_completion(profile_dir=profile_dir) - return - - # Benchmark. - latencies = [] - for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): - latencies.append(run_to_completion(profile_dir=None)) - latencies = np.array(latencies) - percentages = [10, 25, 50, 75, 90, 99] - percentiles = np.percentile(latencies, percentages) - print(f"Avg latency: {np.mean(latencies)} seconds") - for percentage, percentile in zip(percentages, percentiles): - print(f"{percentage}% percentile latency: {percentile} seconds") - - # Output JSON results if specified - if args.output_json: - results = { - "avg_latency": np.mean(latencies), - "latencies": latencies.tolist(), - "percentiles": dict(zip(percentages, percentiles.tolist())), - } - with open(args.output_json, "w") as f: - json.dump(results, f, indent=4) - save_to_pytorch_benchmark_format(args, results) - - -def create_argument_parser(): - parser = FlexibleArgumentParser( - description="Benchmark the latency of processing a single batch of " - "requests till completion." - ) - parser.add_argument("--input-len", type=int, default=32) - parser.add_argument("--output-len", type=int, default=128) - parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument( - "--n", - type=int, - default=1, - help="Number of generated sequences per prompt.", - ) - parser.add_argument("--use-beam-search", action="store_true") - parser.add_argument( - "--num-iters-warmup", - type=int, - default=10, - help="Number of iterations to run for warmup.", - ) - parser.add_argument( - "--num-iters", type=int, default=30, help="Number of iterations to run." - ) - parser.add_argument( - "--profile", - action="store_true", - help="profile the generation process of a single batch", - ) - parser.add_argument( - "--output-json", - type=str, - default=None, - help="Path to save the latency results in JSON format.", - ) - parser.add_argument( - "--disable-detokenize", - action="store_true", - help=( - "Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)" - ), - ) - - parser = EngineArgs.add_cli_args(parser) - # V1 enables prefix caching by default which skews the latency - # numbers. We need to disable prefix caching by default. 
- parser.set_defaults(enable_prefix_caching=False) - - return parser - +import sys if __name__ == "__main__": - parser = create_argument_parser() - args = parser.parse_args() - if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: - raise OSError( - "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. " - "Please set it to a valid path to use torch profiler." - ) - main(args) + print("""DEPRECATED: This script has been moved to the vLLM CLI. + +Please use the following command instead: + vllm bench latency + +For help with the new command, run: + vllm bench latency --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench latency --help +""") + sys.exit(1) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 934df05efac17..76cf51498020b 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -1,1324 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -r"""Benchmark online serving throughput. - -On the server side, run one of the following commands: - vLLM OpenAI API server - vllm serve \ - --swap-space 16 - -On the client side, run: - python benchmarks/benchmark_serving.py \ - --backend \ - --model \ - --dataset-name sharegpt \ - --dataset-path \ - --request-rate \ # By default is inf - --num-prompts # By default is 1000 - - when using tgi backend, add - --endpoint /generate_stream - to the end of the command above. -""" - -import argparse -import asyncio -import gc -import json -import os -import random -import time -import warnings -from collections.abc import Iterable -from dataclasses import dataclass -from datetime import datetime -from typing import Any, Literal, Optional - -import numpy as np -from tqdm.asyncio import tqdm -from transformers import PreTrainedTokenizerBase -from typing_extensions import deprecated - -from backend_request_func import ( - ASYNC_REQUEST_FUNCS, - OPENAI_COMPATIBLE_BACKENDS, - RequestFuncInput, - RequestFuncOutput, -) - -try: - from vllm.transformers_utils.tokenizer import get_tokenizer -except ImportError: - from backend_request_func import get_tokenizer - -try: - from vllm.utils import FlexibleArgumentParser -except ImportError: - from argparse import ArgumentParser as FlexibleArgumentParser - -from benchmark_dataset import ( - AIMODataset, - ASRDataset, - BurstGPTDataset, - ConversationDataset, - CustomDataset, - HuggingFaceDataset, - InstructCoderDataset, - MTBenchDataset, - NextEditPredictionDataset, - RandomDataset, - SampleRequest, - ShareGPTDataset, - SonnetDataset, - VisionArenaDataset, -) -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json -from vllm.benchmarks.serve import get_request - -MILLISECONDS_TO_SECONDS_CONVERSION = 1000 - - -@dataclass -class BenchmarkMetrics: - completed: int - total_input: int - total_output: int - request_throughput: float - request_goodput: float - output_throughput: float - total_token_throughput: float - mean_ttft_ms: float - median_ttft_ms: float - std_ttft_ms: float - percentiles_ttft_ms: list[tuple[float, float]] - mean_tpot_ms: float - median_tpot_ms: float - std_tpot_ms: float - percentiles_tpot_ms: list[tuple[float, float]] - mean_itl_ms: float - median_itl_ms: float - std_itl_ms: float - percentiles_itl_ms: list[tuple[float, float]] - # E2EL stands for end-to-end latency per request. - # It is the time taken on the client side from sending - # a request to receiving a complete response. 
- mean_e2el_ms: float - median_e2el_ms: float - std_e2el_ms: float - percentiles_e2el_ms: list[tuple[float, float]] - - -def calculate_metrics( - input_requests: list[SampleRequest], - outputs: list[RequestFuncOutput], - dur_s: float, - tokenizer: PreTrainedTokenizerBase, - selected_percentile_metrics: list[str], - selected_percentiles: list[float], - goodput_config_dict: dict[str, float], -) -> tuple[BenchmarkMetrics, list[int]]: - actual_output_lens: list[int] = [] - total_input = 0 - completed = 0 - good_completed = 0 - itls: list[float] = [] - tpots: list[float] = [] - all_tpots: list[float] = [] - ttfts: list[float] = [] - e2els: list[float] = [] - for i in range(len(outputs)): - if outputs[i].success: - output_len = outputs[i].output_tokens - - if not output_len: - # We use the tokenizer to count the number of output tokens - # for some serving backends instead of looking at - # len(outputs[i].itl) since multiple output tokens may be - # bundled together - # Note : this may inflate the output token count slightly - output_len = len( - tokenizer( - outputs[i].generated_text, add_special_tokens=False - ).input_ids - ) - actual_output_lens.append(output_len) - total_input += input_requests[i].prompt_len - tpot = 0 - if output_len > 1: - latency_minus_ttft = outputs[i].latency - outputs[i].ttft - tpot = latency_minus_ttft / (output_len - 1) - tpots.append(tpot) - # Note: if output_len <= 1, we regard tpot as 0 for goodput - all_tpots.append(tpot) - itls += outputs[i].itl - ttfts.append(outputs[i].ttft) - e2els.append(outputs[i].latency) - completed += 1 - else: - actual_output_lens.append(0) - - if goodput_config_dict: - valid_metrics = [] - slo_values = [] - - if "ttft" in goodput_config_dict: - valid_metrics.append(ttfts) - slo_values.append( - goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION - ) - if "tpot" in goodput_config_dict: - valid_metrics.append(all_tpots) - slo_values.append( - goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION - ) - if "e2el" in goodput_config_dict: - valid_metrics.append(e2els) - slo_values.append( - goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION - ) - - for req_metric in zip(*valid_metrics): - is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) - if is_good_req: - good_completed += 1 - - if completed == 0: - warnings.warn( - "All requests failed. 
This is likely due to a misconfiguration " - "on the benchmark arguments.", - stacklevel=2, - ) - metrics = BenchmarkMetrics( - completed=completed, - total_input=total_input, - total_output=sum(actual_output_lens), - request_throughput=completed / dur_s, - request_goodput=good_completed / dur_s, - output_throughput=sum(actual_output_lens) / dur_s, - total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, - mean_ttft_ms=np.mean(ttfts or 0) - * 1000, # ttfts is empty if streaming is not supported by backend - std_ttft_ms=np.std(ttfts or 0) * 1000, - median_ttft_ms=np.median(ttfts or 0) * 1000, - percentiles_ttft_ms=[ - (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles - ], - mean_tpot_ms=np.mean(tpots or 0) * 1000, - std_tpot_ms=np.std(tpots or 0) * 1000, - median_tpot_ms=np.median(tpots or 0) * 1000, - percentiles_tpot_ms=[ - (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles - ], - mean_itl_ms=np.mean(itls or 0) * 1000, - std_itl_ms=np.std(itls or 0) * 1000, - median_itl_ms=np.median(itls or 0) * 1000, - percentiles_itl_ms=[ - (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles - ], - mean_e2el_ms=np.mean(e2els or 0) * 1000, - std_e2el_ms=np.std(e2els or 0) * 1000, - median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[ - (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles - ], - ) - - return metrics, actual_output_lens - - -async def benchmark( - backend: str, - api_url: str, - base_url: str, - model_id: str, - model_name: str, - tokenizer: PreTrainedTokenizerBase, - input_requests: list[SampleRequest], - logprobs: Optional[int], - request_rate: float, - burstiness: float, - disable_tqdm: bool, - profile: bool, - selected_percentile_metrics: list[str], - selected_percentiles: list[float], - ignore_eos: bool, - goodput_config_dict: dict[str, float], - max_concurrency: Optional[int], - lora_modules: Optional[Iterable[str]], - extra_body: Optional[dict], - ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, - ramp_up_start_rps: Optional[int] = None, - ramp_up_end_rps: Optional[int] = None, -): - if backend in ASYNC_REQUEST_FUNCS: - request_func = ASYNC_REQUEST_FUNCS[backend] - else: - raise ValueError(f"Unknown backend: {backend}") - - print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len, test_mm_content = ( - input_requests[0].prompt, - input_requests[0].prompt_len, - input_requests[0].expected_output_len, - input_requests[0].multi_modal_data, - ) - - assert ( - test_mm_content is None - or isinstance(test_mm_content, dict) - or ( - isinstance(test_mm_content, list) - and all(isinstance(item, dict) for item in test_mm_content) - ) - ), "multi_modal_data must be a dict or list[dict]" - test_input = RequestFuncInput( - model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=api_url, - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body, - ) - - test_output = await request_func(request_func_input=test_input) - if not test_output.success: - raise ValueError( - "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}" - ) - else: - print("Initial test run completed. Starting main benchmark run...") - - if lora_modules: - # For each input request, choose a LoRA module at random. 
- lora_modules = iter( - [random.choice(lora_modules) for _ in range(len(input_requests))] - ) - - if profile: - print("Starting profiler...") - profile_input = RequestFuncInput( - model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body, - ) - profile_output = await request_func(request_func_input=profile_input) - if profile_output.success: - print("Profiler started") - - distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" - - if ramp_up_strategy is not None: - print( - f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase " - f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over " - "the duration of the benchmark." - ) - else: - print(f"Traffic request rate: {request_rate} RPS.") - - print(f"Burstiness factor: {burstiness} ({distribution})") - print(f"Maximum request concurrency: {max_concurrency}") - - pbar = None if disable_tqdm else tqdm(total=len(input_requests)) - - # This can be used once the minimum Python version is 3.10 or higher, - # and it will simplify the code in limited_request_func. - # semaphore = (asyncio.Semaphore(max_concurrency) - # if max_concurrency else contextlib.nullcontext()) - semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None - - async def limited_request_func(request_func_input, pbar): - if semaphore is None: - return await request_func(request_func_input=request_func_input, pbar=pbar) - async with semaphore: - return await request_func(request_func_input=request_func_input, pbar=pbar) - - benchmark_start_time = time.perf_counter() - tasks: list[asyncio.Task] = [] - - rps_change_events = [] - last_int_rps = -1 - if ramp_up_strategy is not None and ramp_up_start_rps is not None: - last_int_rps = ramp_up_start_rps - rps_change_events.append( - { - "rps": last_int_rps, - "timestamp": datetime.now().isoformat(), - } - ) - - async for request, current_request_rate in get_request( - input_requests, - request_rate, - burstiness, - ramp_up_strategy, - ramp_up_start_rps, - ramp_up_end_rps, - ): - if ramp_up_strategy is not None: - current_int_rps = int(current_request_rate) - if current_int_rps > last_int_rps: - timestamp = datetime.now().isoformat() - for rps_val in range(last_int_rps + 1, current_int_rps + 1): - rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) - last_int_rps = current_int_rps - - prompt, prompt_len, output_len, mm_content, request_id = ( - request.prompt, - request.prompt_len, - request.expected_output_len, - request.multi_modal_data, - request.request_id, - ) - req_model_id, req_model_name = model_id, model_name - if lora_modules: - req_lora_module = next(lora_modules) - req_model_id, req_model_name = req_lora_module, req_lora_module - - request_func_input = RequestFuncInput( - model=req_model_id, - model_name=req_model_name, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - multi_modal_content=mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body, - request_id=request_id, - ) - task = limited_request_func(request_func_input=request_func_input, pbar=pbar) - tasks.append(asyncio.create_task(task)) - outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) - - if pbar is not None: - pbar.close() - - benchmark_duration = time.perf_counter() - benchmark_start_time - - metrics, 
actual_output_lens = calculate_metrics( - input_requests=input_requests, - outputs=outputs, - dur_s=benchmark_duration, - tokenizer=tokenizer, - selected_percentile_metrics=selected_percentile_metrics, - selected_percentiles=selected_percentiles, - goodput_config_dict=goodput_config_dict, - ) - - print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) - print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) - if max_concurrency is not None: - print("{:<40} {:<10}".format("Maximum request concurrency:", max_concurrency)) - if request_rate != float("inf"): - print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", request_rate)) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) - print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) - print( - "{:<40} {:<10.2f}".format( - "Request throughput (req/s):", metrics.request_throughput - ) - ) - if goodput_config_dict: - print( - "{:<40} {:<10.2f}".format( - "Request goodput (req/s):", metrics.request_goodput - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Output token throughput (tok/s):", metrics.output_throughput - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Total Token throughput (tok/s):", metrics.total_token_throughput - ) - ) - - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "request_throughput": metrics.request_throughput, - "request_goodput": metrics.request_goodput if goodput_config_dict else None, - "output_throughput": metrics.output_throughput, - "total_token_throughput": metrics.total_token_throughput, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": actual_output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - } - - if rps_change_events: - result["rps_change_events"] = rps_change_events - - def process_one_metric( - # E.g., "ttft" - metric_attribute_name: str, - # E.g., "TTFT" - metric_name: str, - # E.g., "Time to First Token" - metric_header: str, - ): - # This function prints and adds statistics of the specified - # metric. - if metric_attribute_name not in selected_percentile_metrics: - return - print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) - print( - "{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_attribute_name}_ms"), - ) - ) - print( - "{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_attribute_name}_ms"), - ) - ) - result[f"mean_{metric_attribute_name}_ms"] = getattr( - metrics, f"mean_{metric_attribute_name}_ms" - ) - result[f"median_{metric_attribute_name}_ms"] = getattr( - metrics, f"median_{metric_attribute_name}_ms" - ) - result[f"std_{metric_attribute_name}_ms"] = getattr( - metrics, f"std_{metric_attribute_name}_ms" - ) - for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"): - p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) - result[f"p{p_word}_{metric_attribute_name}_ms"] = value - - process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 
1st token)") - process_one_metric("itl", "ITL", "Inter-token Latency") - process_one_metric("e2el", "E2EL", "End-to-end Latency") - - print("=" * 50) - - if profile: - print("Stopping profiler...") - profile_input = RequestFuncInput( - model=model_id, - prompt=test_prompt, - api_url=base_url + "/stop_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - ) - profile_output = await request_func(request_func_input=profile_input) - if profile_output.success: - print("Profiler stopped") - - return result - - -def check_goodput_args(args): - # Check and parse goodput arguments - goodput_config_dict = {} - VALID_NAMES = ["ttft", "tpot", "e2el"] - if args.goodput: - goodput_config_dict = parse_goodput(args.goodput) - for slo_name, slo_val in goodput_config_dict.items(): - if slo_name not in VALID_NAMES: - raise ValueError( - f"Invalid metric name found, {slo_name}: {slo_val}. " - "The service level objective name should be one of " - f"{str(VALID_NAMES)}. " - ) - if slo_val < 0: - raise ValueError( - f"Invalid value found, {slo_name}: {slo_val}. " - "The service level objective value should be " - "non-negative." - ) - return goodput_config_dict - - -def parse_goodput(slo_pairs): - goodput_config_dict = {} - try: - for slo_pair in slo_pairs: - slo_name, slo_val = slo_pair.split(":") - goodput_config_dict[slo_name] = float(slo_val) - except ValueError as err: - raise argparse.ArgumentTypeError( - "Invalid format found for service level objectives. " - 'Specify service level objectives for goodput as "KEY:VALUE" ' - "pairs, where the key is a metric name, and the value is a " - "number in milliseconds." - ) from err - return goodput_config_dict - - -def save_to_pytorch_benchmark_format( - args: argparse.Namespace, results: dict[str, Any], file_name: str -) -> None: - metrics = [ - "median_ttft_ms", - "mean_ttft_ms", - "std_ttft_ms", - "p99_ttft_ms", - "mean_tpot_ms", - "median_tpot_ms", - "std_tpot_ms", - "p99_tpot_ms", - "median_itl_ms", - "mean_itl_ms", - "std_itl_ms", - "p99_itl_ms", - ] - # These raw data might be useful, but they are rather big. They can be added - # later if needed - ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] - pt_records = convert_to_pytorch_benchmark_format( - args=args, - metrics={k: [results[k]] for k in metrics}, - extra_info={ - k: results[k] - for k in results - if k not in metrics and k not in ignored_metrics - }, - ) - if pt_records: - # Don't use json suffix here as we don't want CI to pick it up - pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" - write_to_json(pt_file, pt_records) - - -@deprecated( - "benchmark_serving.py is deprecated and will be removed in a future " - "version. Please use 'vllm bench serve' instead.", -) -def main(args: argparse.Namespace): - print(args) - random.seed(args.seed) - np.random.seed(args.seed) - - backend = args.backend - model_id = args.model - model_name = args.served_model_name - tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model - tokenizer_mode = args.tokenizer_mode - - # Validate ramp-up arguments - if args.ramp_up_strategy is not None: - if args.request_rate != float("inf"): - raise ValueError( - "When using ramp-up, do not specify --request-rate. " - "The request rate will be controlled by ramp-up parameters. " - "Please remove the --request-rate argument." 
- ) - if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: - raise ValueError( - "When using --ramp-up-strategy, both --ramp-up-start-rps and " - "--ramp-up-end-rps must be specified" - ) - if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: - raise ValueError("Ramp-up start and end RPS must be non-negative") - if args.ramp_up_start_rps > args.ramp_up_end_rps: - raise ValueError("Ramp-up start RPS must be less than end RPS") - if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0: - raise ValueError("For exponential ramp-up, the start RPS cannot be 0.") - - if args.base_url is not None: - api_url = f"{args.base_url}{args.endpoint}" - base_url = f"{args.base_url}" - else: - api_url = f"http://{args.host}:{args.port}{args.endpoint}" - base_url = f"http://{args.host}:{args.port}" - - tokenizer = get_tokenizer( - tokenizer_id, - tokenizer_mode=tokenizer_mode, - trust_remote_code=args.trust_remote_code, - ) - - if args.dataset_name is None: - raise ValueError( - "Please specify '--dataset-name' and the corresponding " - "'--dataset-path' if required." - ) - - if args.dataset_name == "custom": - dataset = CustomDataset(dataset_path=args.dataset_path) - input_requests = dataset.sample( - num_requests=args.num_prompts, - tokenizer=tokenizer, - output_len=args.custom_output_len, - skip_chat_template=args.custom_skip_chat_template, - request_id_prefix=args.request_id_prefix, - ) - - elif args.dataset_name == "sonnet": - dataset = SonnetDataset(dataset_path=args.dataset_path) - # For the "sonnet" dataset, formatting depends on the backend. - if args.backend == "openai-chat": - input_requests = dataset.sample( - num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=False, - request_id_prefix=args.request_id_prefix, - ) - else: - assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset." 
- ) - input_requests = dataset.sample( - num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=True, - request_id_prefix=args.request_id_prefix, - ) - - elif args.dataset_name == "hf": - # all following datasets are implemented from the - # HuggingFaceDataset base class - if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: - dataset_class = VisionArenaDataset - args.hf_split = "train" - args.hf_subset = None - elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: - dataset_class = InstructCoderDataset - args.hf_split = "train" - elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: - dataset_class = MTBenchDataset - args.hf_split = "train" - elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: - dataset_class = ConversationDataset - elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: - dataset_class = AIMODataset - args.hf_split = "train" - elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 - dataset_class = NextEditPredictionDataset - args.hf_split = "train" - elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: - dataset_class = ASRDataset - args.hf_split = "train" - else: - supported_datasets = set( - [ - dataset_name - for cls in HuggingFaceDataset.__subclasses__() - for dataset_name in cls.SUPPORTED_DATASET_PATHS - ] - ) - raise ValueError( - f"Unsupported dataset path: {args.dataset_path}. " - "Huggingface dataset only supports dataset_path" - f" from one of following: {supported_datasets}. " - "Please consider contributing if you would " - "like to add support for additional dataset formats." - ) - - if dataset_class.IS_MULTIMODAL and backend not in [ - "openai-chat", - "openai-audio", - ]: - # multi-modal benchmark is only available on OpenAI Chat backend. - raise ValueError( - "Multi-modal content is only supported on 'openai-chat' and " - "'openai-audio' backend." - ) - input_requests = dataset_class( - dataset_path=args.dataset_path, - dataset_subset=args.hf_subset, - dataset_split=args.hf_split, - random_seed=args.seed, - no_stream=args.no_stream, - ).sample( - num_requests=args.num_prompts, - tokenizer=tokenizer, - output_len=args.hf_output_len, - request_id_prefix=args.request_id_prefix, - ) - - else: - # For datasets that follow a similar structure, use a mapping. - dataset_mapping = { - "sharegpt": lambda: ShareGPTDataset( - random_seed=args.seed, dataset_path=args.dataset_path - ).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - request_id_prefix=args.request_id_prefix, - ), - "burstgpt": lambda: BurstGPTDataset( - random_seed=args.seed, dataset_path=args.dataset_path - ).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - request_id_prefix=args.request_id_prefix, - ), - "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - prefix_len=args.random_prefix_len, - input_len=args.random_input_len, - output_len=args.random_output_len, - range_ratio=args.random_range_ratio, - request_id_prefix=args.request_id_prefix, - ), - } - - try: - input_requests = dataset_mapping[args.dataset_name]() - except KeyError as err: - raise ValueError(f"Unknown dataset: {args.dataset_name}") from err - goodput_config_dict = check_goodput_args(args) - - # Collect the sampling parameters. 
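-    # Only flags the user explicitly set end up in the request body; for
-    # example (illustrative values, not defaults), `--top-p 0.9
-    # --temperature 0.5` yields {"top_p": 0.9, "temperature": 0.5}, and
-    # unset keys are dropped so the server's own defaults still apply.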
- sampling_params = { - k: v - for k, v in { - "top_p": args.top_p, - "top_k": args.top_k, - "min_p": args.min_p, - "temperature": args.temperature, - }.items() - if v is not None - } - - # Sampling parameters are only supported by openai-compatible backend. - if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: - raise ValueError( - "Sampling parameters are only supported by openai-compatible backends." - ) - - if "temperature" not in sampling_params: - sampling_params["temperature"] = 0.0 # Default to greedy decoding. - - if args.backend == "llama.cpp": - # Disable prompt caching in llama.cpp backend - sampling_params["cache_prompt"] = False - - # Avoid GC processing "static" data - reduce pause times. - gc.collect() - gc.freeze() - - benchmark_result = asyncio.run( - benchmark( - backend=backend, - api_url=api_url, - base_url=base_url, - model_id=model_id, - model_name=model_name, - tokenizer=tokenizer, - input_requests=input_requests, - logprobs=args.logprobs, - request_rate=args.request_rate, - burstiness=args.burstiness, - disable_tqdm=args.disable_tqdm, - profile=args.profile, - selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], - ignore_eos=args.ignore_eos, - goodput_config_dict=goodput_config_dict, - max_concurrency=args.max_concurrency, - lora_modules=args.lora_modules, - extra_body=sampling_params, - ramp_up_strategy=args.ramp_up_strategy, - ramp_up_start_rps=args.ramp_up_start_rps, - ramp_up_end_rps=args.ramp_up_end_rps, - ) - ) - - # Save config and results to json - if args.save_result or args.append_result: - result_json: dict[str, Any] = {} - - # Setup - current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") - result_json["date"] = current_dt - result_json["backend"] = backend - result_json["model_id"] = model_id - result_json["tokenizer_id"] = tokenizer_id - result_json["num_prompts"] = args.num_prompts - - # Metadata - if args.metadata: - for item in args.metadata: - if "=" in item: - kvstring = item.split("=") - result_json[kvstring[0].strip()] = kvstring[1].strip() - else: - raise ValueError( - "Invalid metadata format. Please use KEY=VALUE format." 
- ) - # Traffic - result_json["request_rate"] = ( - args.request_rate if args.request_rate < float("inf") else "inf" - ) - result_json["burstiness"] = args.burstiness - result_json["max_concurrency"] = args.max_concurrency - - if args.ramp_up_strategy is not None: - result_json["ramp_up_strategy"] = args.ramp_up_strategy - result_json["ramp_up_start_rps"] = args.ramp_up_start_rps - result_json["ramp_up_end_rps"] = args.ramp_up_end_rps - - # Merge with benchmark result - result_json = {**result_json, **benchmark_result} - - if not args.save_detailed: - # Remove fields with too many data points - for field in [ - "input_lens", - "output_lens", - "ttfts", - "itls", - "generated_texts", - "errors", - ]: - if field in result_json: - del result_json[field] - if field in benchmark_result: - del benchmark_result[field] - - # Save to file - base_model_id = model_id.split("/")[-1] - max_concurrency_str = ( - f"-concurrency{args.max_concurrency}" - if args.max_concurrency is not None - else "" - ) - if args.ramp_up_strategy is not None: - file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa - else: - file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa - if args.result_filename: - file_name = args.result_filename - if args.result_dir: - os.makedirs(args.result_dir, exist_ok=True) - file_name = os.path.join(args.result_dir, file_name) - with open( - file_name, mode="a+" if args.append_result else "w", encoding="utf-8" - ) as outfile: - # Append a newline. - if args.append_result and outfile.tell() != 0: - outfile.write("\n") - json.dump(result_json, outfile) - save_to_pytorch_benchmark_format(args, result_json, file_name) - - -def create_argument_parser(): - parser = FlexibleArgumentParser( - description="Benchmark the online serving throughput." - ) - parser.add_argument( - "--backend", - type=str, - default="vllm", - choices=list(ASYNC_REQUEST_FUNCS.keys()), - ) - parser.add_argument( - "--base-url", - type=str, - default=None, - help="Server or API base url if not using http host and port.", - ) - # Use 127.0.0.1 here instead of localhost to force the use of ipv4 - parser.add_argument("--host", type=str, default="127.0.0.1") - parser.add_argument("--port", type=int, default=8000) - parser.add_argument( - "--endpoint", - type=str, - default="/v1/completions", - help="API endpoint.", - ) - parser.add_argument( - "--dataset-name", - type=str, - default="sharegpt", - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], - help="Name of the dataset to benchmark on.", - ) - parser.add_argument( - "--dataset-path", - type=str, - default=None, - help="Path to the sharegpt/sonnet dataset. " - "Or the huggingface dataset ID if using HF dataset.", - ) - parser.add_argument( - "--no-stream", - action="store_true", - help="Do not load the dataset in streaming mode.", - ) - parser.add_argument( - "--max-concurrency", - type=int, - default=None, - help="Maximum number of concurrent requests. This can be used " - "to help simulate an environment where a higher level component " - "is enforcing a maximum number of concurrent requests. While the " - "--request-rate argument controls the rate at which requests are " - "initiated, this argument will control how many are actually allowed " - "to execute at a time. 
This means that when used in combination, the " - "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.", - ) - - parser.add_argument( - "--model", - type=str, - required=True, - help="Name of the model.", - ) - parser.add_argument( - "--tokenizer", - type=str, - help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 - ) - parser.add_argument("--use-beam-search", action="store_true") - parser.add_argument( - "--num-prompts", - type=int, - default=1000, - help="Number of prompts to process.", - ) - parser.add_argument( - "--logprobs", - type=int, - default=None, - help=( - "Number of logprobs-per-token to compute & return as part of " - "the request. If unspecified, then either (1) if beam search " - "is disabled, no logprobs are computed & a single dummy " - "logprob is returned for each token; or (2) if beam search " - "is enabled 1 logprob per token is computed" - ), - ) - parser.add_argument( - "--request-rate", - type=float, - default=float("inf"), - help="Number of requests per second. If this is inf, " - "then all the requests are sent at time 0. " - "Otherwise, we use Poisson process or gamma distribution " - "to synthesize the request arrival times.", - ) - parser.add_argument( - "--burstiness", - type=float, - default=1.0, - help="Burstiness factor of the request generation. " - "Only take effect when request_rate is not inf. " - "Default value is 1, which follows Poisson process. " - "Otherwise, the request intervals follow a gamma distribution. " - "A lower burstiness value (0 < burstiness < 1) results in more " - "bursty requests. A higher burstiness value (burstiness > 1) " - "results in a more uniform arrival of requests.", - ) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument( - "--trust-remote-code", - action="store_true", - help="Trust remote code from huggingface", - ) - parser.add_argument( - "--disable-tqdm", - action="store_true", - help="Specify to disable tqdm progress bar.", - ) - parser.add_argument( - "--profile", - action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "VLLM_TORCH_PROFILER_DIR to enable profiler.", - ) - parser.add_argument( - "--save-result", - action="store_true", - help="Specify to save benchmark results to a json file", - ) - parser.add_argument( - "--save-detailed", - action="store_true", - help="When saving the results, whether to include per request " - "information such as response, error, ttfs, tpots, etc.", - ) - parser.add_argument( - "--append-result", - action="store_true", - help="Append the benchmark result to the existing json file.", - ) - parser.add_argument( - "--metadata", - metavar="KEY=VALUE", - nargs="*", - help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) " - "for metadata of this run to be saved in the result JSON file " - "for record keeping purposes.", - ) - parser.add_argument( - "--result-dir", - type=str, - default=None, - help="Specify directory to save benchmark json results." - "If not specified, results are saved in the current directory.", - ) - parser.add_argument( - "--result-filename", - type=str, - default=None, - help="Specify the filename to save benchmark json results." 
- "If not specified, results will be saved in " - "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" - " format.", - ) - parser.add_argument( - "--ignore-eos", - action="store_true", - help="Set ignore_eos flag when sending the benchmark request." - "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", - ) - parser.add_argument( - "--percentile-metrics", - type=str, - default="ttft,tpot,itl", - help="Comma-separated list of selected metrics to report percentiles. " - "This argument specifies the metrics to report percentiles. " - 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' - 'Default value is "ttft,tpot,itl".', - ) - parser.add_argument( - "--metric-percentiles", - type=str, - default="99", - help="Comma-separated list of percentiles for selected metrics. " - 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' - 'Default value is "99". ' - 'Use "--percentile-metrics" to select metrics.', - ) - parser.add_argument( - "--goodput", - nargs="+", - required=False, - help='Specify service level objectives for goodput as "KEY:VALUE" ' - "pairs, where the key is a metric name, and the value is in " - 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' - "separated by spaces. Allowed request level metric names are " - '"ttft", "tpot", "e2el". For more context on the definition of ' - "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " - "and the blog: https://hao-ai-lab.github.io/blogs/distserve", - ) - parser.add_argument( - "--request-id-prefix", - type=str, - required=False, - default="benchmark-serving", - help="Specify the prefix of request id.", - ) - - # group for dataset specific arguments - custom_group = parser.add_argument_group("custom dataset options") - custom_group.add_argument( - "--custom-output-len", - type=int, - default=256, - help="Number of output tokens per request, used only for custom dataset.", - ) - custom_group.add_argument( - "--custom-skip-chat-template", - action="store_true", - help="Skip applying chat template to prompt, used only for custom dataset.", - ) - - sonnet_group = parser.add_argument_group("sonnet dataset options") - sonnet_group.add_argument( - "--sonnet-input-len", - type=int, - default=550, - help="Number of input tokens per request, used only for sonnet dataset.", - ) - sonnet_group.add_argument( - "--sonnet-output-len", - type=int, - default=150, - help="Number of output tokens per request, used only for sonnet dataset.", - ) - sonnet_group.add_argument( - "--sonnet-prefix-len", - type=int, - default=200, - help="Number of prefix tokens per request, used only for sonnet dataset.", - ) - - sharegpt_group = parser.add_argument_group("sharegpt dataset options") - sharegpt_group.add_argument( - "--sharegpt-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output length " - "from the ShareGPT dataset.", - ) - - random_group = parser.add_argument_group("random dataset options") - random_group.add_argument( - "--random-input-len", - type=int, - default=1024, - help="Number of input tokens per request, used only for random sampling.", - ) - random_group.add_argument( - "--random-output-len", - type=int, - default=128, - help="Number of output tokens per request, used only for random sampling.", - ) - random_group.add_argument( - "--random-range-ratio", - type=float, - default=0.0, - help="Range ratio for sampling input/output length, " - "used only for random sampling. 
Must be in the range [0, 1) to define " - "a symmetric sampling range" - "[length * (1 - range_ratio), length * (1 + range_ratio)].", - ) - random_group.add_argument( - "--random-prefix-len", - type=int, - default=0, - help=( - "Number of fixed prefix tokens before the random context " - "in a request. " - "The total input length is the sum of `random-prefix-len` and " - "a random " - "context length sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]." - ), - ) - - hf_group = parser.add_argument_group("hf dataset options") - hf_group.add_argument( - "--hf-subset", type=str, default=None, help="Subset of the HF dataset." - ) - hf_group.add_argument( - "--hf-split", type=str, default=None, help="Split of the HF dataset." - ) - hf_group.add_argument( - "--hf-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output lengths " - "from the sampled HF dataset.", - ) - - sampling_group = parser.add_argument_group("sampling parameters") - sampling_group.add_argument( - "--top-p", - type=float, - default=None, - help="Top-p sampling parameter. Only has effect on openai-compatible backends.", - ) - sampling_group.add_argument( - "--top-k", - type=int, - default=None, - help="Top-k sampling parameter. Only has effect on openai-compatible backends.", - ) - sampling_group.add_argument( - "--min-p", - type=float, - default=None, - help="Min-p sampling parameter. Only has effect on openai-compatible backends.", - ) - sampling_group.add_argument( - "--temperature", - type=float, - default=None, - help="Temperature sampling parameter. Only has effect on " - "openai-compatible backends. If not specified, default to greedy " - "decoding (i.e. temperature==0.0).", - ) - - parser.add_argument( - "--tokenizer-mode", - type=str, - default="auto", - choices=["auto", "slow", "mistral", "custom"], - help='The tokenizer mode.\n\n* "auto" will use the ' - 'fast tokenizer if available.\n* "slow" will ' - "always use the slow tokenizer. \n* " - '"mistral" will always use the `mistral_common` tokenizer. \n*' - '"custom" will use --tokenizer to select the preregistered tokenizer.', - ) - - parser.add_argument( - "--served-model-name", - type=str, - default=None, - help="The model name used in the API. " - "If not specified, the model name will be the " - "same as the ``--model`` argument. ", - ) - - parser.add_argument( - "--lora-modules", - nargs="+", - default=None, - help="A subset of LoRA module names passed in when " - "launching the server. For each request, the " - "script chooses a LoRA module at random.", - ) - - parser.add_argument( - "--ramp-up-strategy", - type=str, - default=None, - choices=["linear", "exponential"], - help="The ramp-up strategy. This would be used to " - "ramp up the request rate from initial RPS to final " - "RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps). " - "over the duration of the benchmark.", - ) - parser.add_argument( - "--ramp-up-start-rps", - type=int, - default=None, - help="The starting request rate for ramp-up (RPS). " - "Needs to be specified when --ramp-up-strategy is used.", - ) - parser.add_argument( - "--ramp-up-end-rps", - type=int, - default=None, - help="The ending request rate for ramp-up (RPS). " - "Needs to be specified when --ramp-up-strategy is used.", - ) - - return parser - +import sys if __name__ == "__main__": - parser = create_argument_parser() - args = parser.parse_args() - main(args) + print("""DEPRECATED: This script has been moved to the vLLM CLI. 
+ +Please use the following command instead: + vllm bench serve + +For help with the new command, run: + vllm bench serve --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench serve --help +""") + sys.exit(1) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 34a525f00d910..b6dc0918fd4d1 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -1,741 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Benchmark offline inference throughput.""" - -import argparse -import dataclasses -import json -import os -import random -import time -import warnings -from typing import Any, Optional, Union - -import torch -import uvloop -from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase -from typing_extensions import deprecated - -from benchmark_dataset import ( - AIMODataset, - BurstGPTDataset, - ConversationDataset, - InstructCoderDataset, - RandomDataset, - SampleRequest, - ShareGPTDataset, - SonnetDataset, - VisionArenaDataset, -) -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args, -) -from vllm.inputs import TextPrompt, TokensPrompt -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput -from vllm.sampling_params import BeamSearchParams -from vllm.utils import FlexibleArgumentParser, merge_async_iterators - - -def run_vllm( - requests: list[SampleRequest], - n: int, - engine_args: EngineArgs, - disable_detokenize: bool = False, -) -> tuple[float, Optional[list[RequestOutput]]]: - from vllm import LLM, SamplingParams - - llm = LLM(**dataclasses.asdict(engine_args)) - assert all( - llm.llm_engine.model_config.max_model_len - >= (request.prompt_len + request.expected_output_len) - for request in requests - ), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests." - ) - # Add the requests to the engine. - prompts: list[Union[TextPrompt, TokensPrompt]] = [] - sampling_params: list[SamplingParams] = [] - for request in requests: - prompts.append( - TokensPrompt( - prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data, - ) - if "prompt_token_ids" in request.prompt - else TextPrompt( - prompt=request.prompt, multi_modal_data=request.multi_modal_data - ) - ) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - detokenize=not disable_detokenize, - ) - ) - lora_requests: Optional[list[LoRARequest]] = None - if engine_args.enable_lora: - lora_requests = [request.lora_request for request in requests] - - use_beam_search = False - - outputs = None - if not use_beam_search: - start = time.perf_counter() - outputs = llm.generate( - prompts, sampling_params, lora_request=lora_requests, use_tqdm=True - ) - end = time.perf_counter() - else: - assert lora_requests is None, "BeamSearch API does not support LoRA" - # output_len should be the same for all requests. 
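-        # llm.beam_search takes a single BeamSearchParams for the whole
-        # batch, so a per-request max_tokens cannot be expressed; the loop
-        # below asserts that every request agrees on one output length.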
- output_len = requests[0].expected_output_len - for request in requests: - assert request.expected_output_len == output_len - start = time.perf_counter() - llm.beam_search( - prompts, - BeamSearchParams( - beam_width=n, - max_tokens=output_len, - ignore_eos=True, - ), - ) - end = time.perf_counter() - return end - start, outputs - - -def run_vllm_chat( - requests: list[SampleRequest], - n: int, - engine_args: EngineArgs, - disable_detokenize: bool = False, -) -> tuple[float, list[RequestOutput]]: - """ - Run vLLM chat benchmark. This function is recommended ONLY for benchmarking - multimodal models as it properly handles multimodal inputs and chat - formatting. For non-multimodal models, use run_vllm() instead. - """ - from vllm import LLM, SamplingParams - - llm = LLM(**dataclasses.asdict(engine_args)) - - assert all( - llm.llm_engine.model_config.max_model_len - >= (request.prompt_len + request.expected_output_len) - for request in requests - ), ( - "Please ensure that max_model_len is greater than the sum of " - "prompt_len and expected_output_len for all requests." - ) - - prompts = [] - sampling_params: list[SamplingParams] = [] - for request in requests: - prompts.append(request.prompt) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - detokenize=not disable_detokenize, - ) - ) - start = time.perf_counter() - outputs = llm.chat(prompts, sampling_params, use_tqdm=True) - end = time.perf_counter() - return end - start, outputs - - -async def run_vllm_async( - requests: list[SampleRequest], - n: int, - engine_args: AsyncEngineArgs, - disable_frontend_multiprocessing: bool = False, - disable_detokenize: bool = False, -) -> float: - from vllm import SamplingParams - - async with build_async_engine_client_from_engine_args( - engine_args, - disable_frontend_multiprocessing=disable_frontend_multiprocessing, - ) as llm: - model_config = await llm.get_model_config() - assert all( - model_config.max_model_len - >= (request.prompt_len + request.expected_output_len) - for request in requests - ), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests." - ) - - # Add the requests to the engine. 
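-        # Each request becomes its own generator and the generators are
-        # merged via merge_async_iterators below, so the async engine
-        # schedules all requests concurrently rather than one at a time.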
- prompts: list[Union[TextPrompt, TokensPrompt]] = [] - sampling_params: list[SamplingParams] = [] - lora_requests: list[Optional[LoRARequest]] = [] - for request in requests: - prompts.append( - TokensPrompt( - prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data, - ) - if "prompt_token_ids" in request.prompt - else TextPrompt( - prompt=request.prompt, multi_modal_data=request.multi_modal_data - ) - ) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - detokenize=not disable_detokenize, - ) - ) - lora_requests.append(request.lora_request) - - generators = [] - start = time.perf_counter() - for i, (prompt, sp, lr) in enumerate( - zip(prompts, sampling_params, lora_requests) - ): - generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}") - generators.append(generator) - all_gens = merge_async_iterators(*generators) - async for i, res in all_gens: - pass - end = time.perf_counter() - return end - start - - -def run_hf( - requests: list[SampleRequest], - model: str, - tokenizer: PreTrainedTokenizerBase, - n: int, - max_batch_size: int, - trust_remote_code: bool, - disable_detokenize: bool = False, -) -> float: - llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code - ) - if llm.config.model_type == "llama": - # To enable padding in the HF backend. - tokenizer.pad_token = tokenizer.eos_token - llm = llm.cuda() - - pbar = tqdm(total=len(requests)) - start = time.perf_counter() - batch: list[str] = [] - max_prompt_len = 0 - max_output_len = 0 - for i in range(len(requests)): - prompt = requests[i].prompt - prompt_len = requests[i].prompt_len - output_len = requests[i].expected_output_len - # Add the prompt to the batch. - batch.append(prompt) - max_prompt_len = max(max_prompt_len, prompt_len) - max_output_len = max(max_output_len, output_len) - if len(batch) < max_batch_size and i != len(requests) - 1: - # Check if we can add more requests to the batch. - next_prompt_len = requests[i + 1].prompt_len - next_output_len = requests[i + 1].expected_output_len - if ( - max(max_prompt_len, next_prompt_len) - + max(max_output_len, next_output_len) - ) <= 2048: - # We can add more requests to the batch. - continue - - # Generate the sequences. - input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids - llm_outputs = llm.generate( - input_ids=input_ids.cuda(), - do_sample=True, - num_return_sequences=n, - temperature=1.0, - top_p=1.0, - use_cache=True, - max_new_tokens=max_output_len, - ) - if not disable_detokenize: - # Include the decoding time. - tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) - pbar.update(len(batch)) - - # Clear the batch. 
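-        # The batch is flushed when max_batch_size is reached or when the
-        # padded budget above would overflow: e.g. max(900, 700) +
-        # max(600, 500) = 1500 <= 2048 lets the next request join, while
-        # max(900, 1600) + 600 = 2200 forces generation first
-        # (illustrative lengths, not measured values).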
- batch = [] - max_prompt_len = 0 - max_output_len = 0 - end = time.perf_counter() - return end - start - - -def run_mii( - requests: list[SampleRequest], - model: str, - tensor_parallel_size: int, - output_len: int, -) -> float: - from mii import client, serve - - llm = serve(model, tensor_parallel=tensor_parallel_size) - prompts = [request.prompt for request in requests] - - start = time.perf_counter() - llm.generate(prompts, max_new_tokens=output_len) - end = time.perf_counter() - client = client(model) - client.terminate_server() - return end - start - - -def save_to_pytorch_benchmark_format( - args: argparse.Namespace, results: dict[str, Any] -) -> None: - pt_records = convert_to_pytorch_benchmark_format( - args=args, - metrics={ - "requests_per_second": [results["requests_per_second"]], - "tokens_per_second": [results["tokens_per_second"]], - }, - extra_info={ - k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"] - }, - ) - if pt_records: - # Don't use json suffix here as we don't want CI to pick it up - pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" - write_to_json(pt_file, pt_records) - - -def get_requests(args, tokenizer): - # Common parameters for all dataset types. - common_kwargs = { - "dataset_path": args.dataset_path, - "random_seed": args.seed, - } - sample_kwargs = { - "tokenizer": tokenizer, - "lora_path": args.lora_path, - "max_loras": args.max_loras, - "num_requests": args.num_prompts, - "input_len": args.input_len, - "output_len": args.output_len, - } - - if args.dataset_path is None or args.dataset_name == "random": - sample_kwargs["range_ratio"] = args.random_range_ratio - sample_kwargs["prefix_len"] = args.prefix_len - dataset_cls = RandomDataset - elif args.dataset_name == "sharegpt": - dataset_cls = ShareGPTDataset - if args.backend == "vllm-chat": - sample_kwargs["enable_multimodal_chat"] = True - elif args.dataset_name == "sonnet": - assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset." - ) - dataset_cls = SonnetDataset - sample_kwargs["prefix_len"] = args.prefix_len - sample_kwargs["return_prompt_formatted"] = True - elif args.dataset_name == "burstgpt": - dataset_cls = BurstGPTDataset - elif args.dataset_name == "hf": - common_kwargs["no_stream"] = args.no_stream - if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: - dataset_cls = VisionArenaDataset - common_kwargs["dataset_subset"] = None - common_kwargs["dataset_split"] = "train" - sample_kwargs["enable_multimodal_chat"] = True - elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: - dataset_cls = InstructCoderDataset - common_kwargs["dataset_split"] = "train" - elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: - dataset_cls = ConversationDataset - common_kwargs["dataset_subset"] = args.hf_subset - common_kwargs["dataset_split"] = args.hf_split - sample_kwargs["enable_multimodal_chat"] = True - elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: - dataset_cls = AIMODataset - common_kwargs["dataset_subset"] = None - common_kwargs["dataset_split"] = "train" - else: - raise ValueError(f"Unknown dataset name: {args.dataset_name}") - # Remove None values - sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None} - return dataset_cls(**common_kwargs).sample(**sample_kwargs) - - -@deprecated( - "benchmark_throughput.py is deprecated and will be removed in a " - "future version. 
Please use 'vllm bench throughput' instead.", -) -def main(args: argparse.Namespace): - if args.seed is None: - args.seed = 0 - print(args) - random.seed(args.seed) - # Sample the requests. - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code - ) - requests = get_requests(args, tokenizer) - is_multi_modal = any(request.multi_modal_data is not None for request in requests) - request_outputs: Optional[list[RequestOutput]] = None - if args.backend == "vllm": - if args.async_engine: - elapsed_time = uvloop.run( - run_vllm_async( - requests, - args.n, - AsyncEngineArgs.from_cli_args(args), - args.disable_frontend_multiprocessing, - args.disable_detokenize, - ) - ) - else: - elapsed_time, request_outputs = run_vllm( - requests, - args.n, - EngineArgs.from_cli_args(args), - args.disable_detokenize, - ) - elif args.backend == "hf": - assert args.tensor_parallel_size == 1 - elapsed_time = run_hf( - requests, - args.model, - tokenizer, - args.n, - args.hf_max_batch_size, - args.trust_remote_code, - args.disable_detokenize, - ) - elif args.backend == "mii": - elapsed_time = run_mii( - requests, args.model, args.tensor_parallel_size, args.output_len - ) - elif args.backend == "vllm-chat": - elapsed_time, request_outputs = run_vllm_chat( - requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize - ) - else: - raise ValueError(f"Unknown backend: {args.backend}") - - if request_outputs: - # Note: with the vllm and vllm-chat backends, - # we have request_outputs, which we use to count tokens. - total_prompt_tokens = 0 - total_output_tokens = 0 - for ro in request_outputs: - if not isinstance(ro, RequestOutput): - continue - total_prompt_tokens += ( - len(ro.prompt_token_ids) if ro.prompt_token_ids else 0 - ) - total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o) - total_num_tokens = total_prompt_tokens + total_output_tokens - else: - total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests) - total_output_tokens = sum(r.expected_output_len for r in requests) - total_prompt_tokens = total_num_tokens - total_output_tokens - - if is_multi_modal and args.backend != "vllm-chat": - print( - "\033[91mWARNING\033[0m: Multi-modal request with " - f"{args.backend} backend detected. The " - "following metrics are not accurate because image tokens are not" - " counted. See vllm-project/vllm/issues/9778 for details." - ) - # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. - # vllm-chat backend counts the image tokens now - - print( - f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " - f"{total_output_tokens / elapsed_time:.2f} output tokens/s" - ) - print(f"Total num prompt tokens: {total_prompt_tokens}") - print(f"Total num output tokens: {total_output_tokens}") - - # Output JSON results if specified - if args.output_json: - results = { - "elapsed_time": elapsed_time, - "num_requests": len(requests), - "total_num_tokens": total_num_tokens, - "requests_per_second": len(requests) / elapsed_time, - "tokens_per_second": total_num_tokens / elapsed_time, - } - with open(args.output_json, "w") as f: - json.dump(results, f, indent=4) - save_to_pytorch_benchmark_format(args, results) - - -def validate_args(args): - """ - Validate command-line arguments. - """ - - # === Deprecation and Defaulting === - if args.dataset is not None: - warnings.warn( - "The '--dataset' argument will be deprecated in the next release. 
" - "Please use '--dataset-name' and '--dataset-path' instead.", - stacklevel=2, - ) - args.dataset_path = args.dataset - - if not getattr(args, "tokenizer", None): - args.tokenizer = args.model - - # === Backend Validation === - valid_backends = {"vllm", "hf", "mii", "vllm-chat"} - if args.backend not in valid_backends: - raise ValueError(f"Unsupported backend: {args.backend}") - - # === Dataset Configuration === - if not args.dataset and not args.dataset_path: - print("When dataset path is not set, it will default to random dataset") - args.dataset_name = "random" - if args.input_len is None: - raise ValueError("input_len must be provided for a random dataset") - - # === Dataset Name Specific Checks === - # --hf-subset and --hf-split: only used - # when dataset_name is 'hf' - if args.dataset_name != "hf" and ( - getattr(args, "hf_subset", None) is not None - or getattr(args, "hf_split", None) is not None - ): - warnings.warn( - "--hf-subset and --hf-split will be ignored \ - since --dataset-name is not 'hf'.", - stacklevel=2, - ) - elif args.dataset_name == "hf": - if args.dataset_path in ( - VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() - | ConversationDataset.SUPPORTED_DATASET_PATHS - ): - assert args.backend == "vllm-chat", ( - f"{args.dataset_path} needs to use vllm-chat as the backend." - ) # noqa: E501 - elif args.dataset_path in ( - InstructCoderDataset.SUPPORTED_DATASET_PATHS - | AIMODataset.SUPPORTED_DATASET_PATHS - ): - assert args.backend == "vllm", ( - f"{args.dataset_path} needs to use vllm as the backend." - ) # noqa: E501 - else: - raise ValueError(f"{args.dataset_path} is not supported by hf dataset.") - - # --random-range-ratio: only used when dataset_name is 'random' - if args.dataset_name != "random" and args.random_range_ratio is not None: - warnings.warn( - "--random-range-ratio will be ignored since \ - --dataset-name is not 'random'.", - stacklevel=2, - ) - - # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not - # set. - if ( - args.dataset_name not in {"random", "sonnet", None} - and args.prefix_len is not None - ): - warnings.warn( - "--prefix-len will be ignored since --dataset-name\ - is not 'random', 'sonnet', or not set.", - stacklevel=2, - ) - - # === LoRA Settings === - if getattr(args, "enable_lora", False) and args.backend != "vllm": - raise ValueError("LoRA benchmarking is only supported for vLLM backend") - if getattr(args, "enable_lora", False) and args.lora_path is None: - raise ValueError("LoRA path must be provided when enable_lora is True") - - # === Backend-specific Validations === - if args.backend == "hf" and args.hf_max_batch_size is None: - raise ValueError("HF max batch size is required for HF backend") - if args.backend != "hf" and args.hf_max_batch_size is not None: - raise ValueError("HF max batch size is only for HF backend.") - - if ( - args.backend in {"hf", "mii"} - and getattr(args, "quantization", None) is not None - ): - raise ValueError("Quantization is only for vLLM backend.") - - if args.backend == "mii" and args.dtype != "auto": - raise ValueError("dtype must be auto for MII backend.") - if args.backend == "mii" and args.n != 1: - raise ValueError("n must be 1 for MII backend.") - if args.backend == "mii" and args.tokenizer != args.model: - raise ValueError("Tokenizer must be the same as the model for MII backend.") - - # --data-parallel is not supported currently. 
- # https://github.com/vllm-project/vllm/issues/16222 - if args.data_parallel_size > 1: - raise ValueError( - "Data parallel is not supported in offline benchmark, " - "please use benchmark serving instead" - ) - - -def create_argument_parser(): - parser = FlexibleArgumentParser(description="Benchmark the throughput.") - parser.add_argument( - "--backend", - type=str, - choices=["vllm", "hf", "mii", "vllm-chat"], - default="vllm", - ) - parser.add_argument( - "--dataset-name", - type=str, - choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], - help="Name of the dataset to benchmark on.", - default="sharegpt", - ) - parser.add_argument( - "--no-stream", - action="store_true", - help="Do not load the dataset in streaming mode.", - ) - parser.add_argument( - "--dataset", - type=str, - default=None, - help="Path to the ShareGPT dataset, will be deprecated in\ - the next release. The dataset is expected to " - "be a json in form of list[dict[..., conversations: " - "list[dict[..., value: ]]]]", - ) - parser.add_argument( - "--dataset-path", type=str, default=None, help="Path to the dataset" - ) - parser.add_argument( - "--input-len", - type=int, - default=None, - help="Input prompt length for each request", - ) - parser.add_argument( - "--output-len", - type=int, - default=None, - help="Output length for each request. Overrides the " - "output length from the dataset.", - ) - parser.add_argument( - "--n", type=int, default=1, help="Number of generated sequences per prompt." - ) - parser.add_argument( - "--num-prompts", type=int, default=1000, help="Number of prompts to process." - ) - parser.add_argument( - "--hf-max-batch-size", - type=int, - default=None, - help="Maximum batch size for HF backend.", - ) - parser.add_argument( - "--output-json", - type=str, - default=None, - help="Path to save the throughput results in JSON format.", - ) - parser.add_argument( - "--async-engine", - action="store_true", - default=False, - help="Use vLLM async engine rather than LLM class.", - ) - parser.add_argument( - "--disable-frontend-multiprocessing", - action="store_true", - default=False, - help="Disable decoupled async engine frontend.", - ) - parser.add_argument( - "--disable-detokenize", - action="store_true", - help=( - "Do not detokenize the response (i.e. do not include " - "detokenization time in the measurement)" - ), - ) - # LoRA - parser.add_argument( - "--lora-path", - type=str, - default=None, - help="Path to the LoRA adapters to use. This can be an absolute path, " - "a relative path, or a Hugging Face model identifier.", - ) - parser.add_argument( - "--prefix-len", - type=int, - default=None, - help=f"Number of prefix tokens to be used in RandomDataset " - "and SonnetDataset. For RandomDataset, the total input " - "length is the sum of prefix-len (default: " - f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length " - "sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]. For SonnetDataset, " - f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) " - "controls how much of the input is fixed lines versus " - "random lines, but the total input length remains approximately " - "input_len tokens.", - ) - # random dataset - parser.add_argument( - "--random-range-ratio", - type=float, - default=None, - help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) " - "for sampling input/output length, " - "used only for RandomDataset. 
Must be in the range [0, 1) to " - "define a symmetric sampling range " - "[length * (1 - range_ratio), length * (1 + range_ratio)].", - ) - - # hf dataset - parser.add_argument( - "--hf-subset", type=str, default=None, help="Subset of the HF dataset." - ) - parser.add_argument( - "--hf-split", type=str, default=None, help="Split of the HF dataset." - ) - - parser = AsyncEngineArgs.add_cli_args(parser) - - return parser - +import sys if __name__ == "__main__": - parser = create_argument_parser() - args = parser.parse_args() - if args.tokenizer is None: - args.tokenizer = args.model - validate_args(args) - main(args) + print("""DEPRECATED: This script has been moved to the vLLM CLI. + +Please use the following command instead: + vllm bench throughput + +For help with the new command, run: + vllm bench throughput --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench throughput --help +""") + sys.exit(1) diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py index 9663503e9baa0..f1e504499eaf6 100644 --- a/benchmarks/kernels/bench_block_fp8_gemm.py +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -4,7 +4,10 @@ import torch from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - w8a8_block_fp8_matmul, + apply_w8a8_block_fp8_linear, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_BLOCK_FP8_SUPPORTED, ) from vllm.platforms import current_platform from vllm.triton_utils import triton as vllm_triton @@ -29,7 +32,7 @@ DEEPSEEK_V3_SHAPES = [ ] -def build_w8a8_block_fp8_runner(M, N, K, block_size, device): +def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): """Build runner function for w8a8 block fp8 matmul.""" factor_for_scale = 1e-2 @@ -37,37 +40,54 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device): fp8_max, fp8_min = fp8_info.max, fp8_info.min # Create random FP8 tensors - A_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max - B_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max + B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) # Create scales block_n, block_k = block_size[0], block_size[1] n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k - As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale Bs = ( torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device) * factor_for_scale ) + # SM90 CUTLASS requires row-major format for scales + if use_cutlass and current_platform.is_device_capability(90): + Bs = Bs.T.contiguous() + def run(): - return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, torch.bfloat16) + if use_cutlass: + return apply_w8a8_block_fp8_linear( + A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True + ) + else: + return apply_w8a8_block_fp8_linear( + A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False + ) return run +# Determine available providers +available_providers = ["torch-bf16", "w8a8-block-fp8-triton"] +plot_title = "BF16 vs W8A8 Block FP8 GEMMs" + +if CUTLASS_BLOCK_FP8_SUPPORTED: + 
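+    # Only advertise the CUTLASS provider where the platform supports
+    # block-FP8 CUTLASS kernels; elsewhere the sweep compares torch-bf16
+    # against the Triton kernel alone.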
available_providers.append("w8a8-block-fp8-cutlass") + + @vllm_triton.testing.perf_report( vllm_triton.testing.Benchmark( x_names=["batch_size"], x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], x_log=False, line_arg="provider", - line_vals=["torch-bf16", "w8a8-block-fp8"], - line_names=["torch-bf16", "w8a8-block-fp8"], + line_vals=available_providers, + line_names=available_providers, ylabel="TFLOP/s (larger is better)", plot_name="BF16 vs W8A8 Block FP8 GEMMs", args={}, @@ -85,11 +105,22 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)): ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( lambda: torch.nn.functional.linear(a, b), quantiles=quantiles ) - else: # w8a8-block-fp8 - run_w8a8 = build_w8a8_block_fp8_runner(M, N, K, block_size, device) - ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( - lambda: run_w8a8(), quantiles=quantiles + elif provider == "w8a8-block-fp8-triton": + run_w8a8_triton = build_w8a8_block_fp8_runner( + M, N, K, block_size, device, use_cutlass=False ) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8_triton(), quantiles=quantiles + ) + elif provider == "w8a8-block-fp8-cutlass": + run_w8a8_cutlass = build_w8a8_block_fp8_runner( + M, N, K, block_size, device, use_cutlass=True + ) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8_cutlass(), quantiles=quantiles + ) + else: + raise ValueError(f"Unknown provider: {provider}") to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py new file mode 100644 index 0000000000000..a61c17edc1e28 --- /dev/null +++ b/benchmarks/kernels/benchmark_device_communicators.py @@ -0,0 +1,486 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark script for device communicators: +CustomAllreduce (oneshot, twoshot), PyNcclCommunicator, +and SymmMemCommunicator (multimem, two-shot). 
+ +Usage: + torchrun --nproc_per_node= benchmark_device_communicators.py [options] + +Example: + torchrun --nproc_per_node=2 benchmark_device_communicators.py + --sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100 +""" + +import json +import os +import time +from contextlib import nullcontext +from typing import Callable, Optional + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator +from vllm.logger import init_logger +from vllm.utils import FlexibleArgumentParser + +logger = init_logger(__name__) + +# Default sequence lengths to benchmark +DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192] + +# Fixed hidden size and dtype for all benchmarks +HIDDEN_SIZE = 8192 +BENCHMARK_DTYPE = torch.bfloat16 + +# CUDA graph settings +CUDA_GRAPH_CAPTURE_CYCLES = 10 + + +class CommunicatorBenchmark: + """Benchmark class for testing device communicators.""" + + def __init__( + self, + rank: int, + world_size: int, + device: torch.device, + cpu_group: ProcessGroup, + sequence_lengths: list[int], + ): + self.rank = rank + self.world_size = world_size + self.device = device + self.cpu_group = cpu_group + + # Calculate max_size_override based on largest sequence length + max_seq_len = max(sequence_lengths) + max_tensor_elements = max_seq_len * HIDDEN_SIZE + self.max_size_override = max_tensor_elements * BENCHMARK_DTYPE.itemsize + 1 + + # Initialize communicators + self.custom_allreduce = None + self.pynccl_comm = None + self.symm_mem_comm = None + self.symm_mem_comm_multimem = None + self.symm_mem_comm_two_shot = None + + self._init_communicators() + + def _init_communicators(self): + """Initialize all available communicators.""" + try: + self.custom_allreduce = CustomAllreduce( + group=self.cpu_group, + device=self.device, + max_size=self.max_size_override, + ) + if not self.custom_allreduce.disabled: + logger.info("Rank %s: CustomAllreduce initialized", self.rank) + else: + logger.info("Rank %s: CustomAllreduce disabled", self.rank) + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize CustomAllreduce: %s", self.rank, e + ) + self.custom_allreduce = None + + try: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, device=self.device + ) + if not self.pynccl_comm.disabled: + logger.info("Rank %s: PyNcclCommunicator initialized", self.rank) + else: + logger.info("Rank %s: PyNcclCommunicator disabled", self.rank) + self.pynccl_comm = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize PyNcclCommunicator: %s", self.rank, e + ) + self.pynccl_comm = None + + # Initialize variants for SymmMemCommunicator + try: + self.symm_mem_comm_multimem = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + force_multimem=True, + max_size_override=self.max_size_override, + ) + if not self.symm_mem_comm_multimem.disabled: + logger.info( + "Rank %s: SymmMemCommunicator (multimem) initialized", self.rank + ) + else: + self.symm_mem_comm_multimem = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize SymmMemCommunicator (multimem): %s", + self.rank, + e, + ) + self.symm_mem_comm_multimem = None + + try: + self.symm_mem_comm_two_shot = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + 
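+                # force_multimem=False pins this instance to the two-shot
+                # algorithm so it can be timed separately from the multimem
+                # variant constructed above.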
force_multimem=False, + max_size_override=self.max_size_override, + ) + if not self.symm_mem_comm_two_shot.disabled: + logger.info( + "Rank %s: SymmMemCommunicator (two_shot) initialized", self.rank + ) + else: + self.symm_mem_comm_two_shot = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize SymmMemCommunicator (two_shot): %s", + self.rank, + e, + ) + self.symm_mem_comm_two_shot = None + + def benchmark_allreduce( + self, sequence_length: int, num_warmup: int, num_trials: int + ) -> dict[str, float]: + """Benchmark allreduce operations for all available communicators.""" + + results = {} + + # Define communicators with their benchmark functions + communicators = [] + + if self.custom_allreduce is not None: + comm = self.custom_allreduce + # CustomAllreduce one-shot + communicators.append( + ( + "ca_1stage", + lambda t, c=comm: c.custom_all_reduce(t), + lambda t, c=comm: c.should_custom_ar(t), + comm.capture(), + "1stage", # env variable value + ) + ) + # CustomAllreduce two-shot + communicators.append( + ( + "ca_2stage", + lambda t, c=comm: c.custom_all_reduce(t), + lambda t, c=comm: c.should_custom_ar(t), + comm.capture(), + "2stage", # env variable value + ) + ) + + if self.pynccl_comm is not None: + comm = self.pynccl_comm + communicators.append( + ( + "pynccl", + lambda t, c=comm: c.all_reduce(t), + lambda t: True, # Always available if initialized + nullcontext(), + None, # no env variable needed + ) + ) + + if self.symm_mem_comm_multimem is not None: + comm = self.symm_mem_comm_multimem + communicators.append( + ( + "symm_mem_multimem", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_symm_mem(t), + nullcontext(), + None, # no env variable needed + ) + ) + + if self.symm_mem_comm_two_shot is not None: + comm = self.symm_mem_comm_two_shot + communicators.append( + ( + "symm_mem_two_shot", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_symm_mem(t), + nullcontext(), + None, # no env variable needed + ) + ) + + # Benchmark each communicator + for name, allreduce_fn, should_use_fn, context, env_var in communicators: + # Set environment variable if needed + if env_var is not None: + os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var + else: + # Clear the environment variable to avoid interference + os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None) + + latency = self.benchmark_allreduce_single( + sequence_length, + allreduce_fn, + should_use_fn, + context, + num_warmup, + num_trials, + ) + if latency is not None: + results[name] = latency + + return results + + def benchmark_allreduce_single( + self, + sequence_length: int, + allreduce_fn: Callable[[torch.Tensor], Optional[torch.Tensor]], + should_use_fn: Callable[[torch.Tensor], bool], + context, + num_warmup: int, + num_trials: int, + ) -> Optional[float]: + """Benchmark method with CUDA graph optimization.""" + try: + # Create test tensor (2D: sequence_length x hidden_size) + tensor = torch.randn( + sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device + ) + if not should_use_fn(tensor): + return None + + torch.cuda.synchronize() + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + graph_input = tensor.clone() + + # Warmup before capture + for _ in range(3): + allreduce_fn(graph_input) + + # Capture the graph using context manager + with context: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(CUDA_GRAPH_CAPTURE_CYCLES): + allreduce_fn(graph_input) + + torch.cuda.synchronize() + for _ in 
range(num_warmup): + graph.replay() + torch.cuda.synchronize() + + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(num_trials): + graph.replay() + torch.cuda.synchronize() + + end_time = time.perf_counter() + + # Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES + return ( + (end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000 + ) + + except Exception as e: + logger.error("CUDA graph benchmark failed: %s", e) + raise RuntimeError( + f"CUDA graph benchmark failed for communicator: {e}" + ) from e + + +def _calculate_speedup_info(comm_results: dict[str, float]) -> str: + """Calculate speedup information for a single tensor size.""" + if not comm_results: + return "N/A" + + # Find the fastest communicator + fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k]) + fastest_time = comm_results[fastest_comm] + + # Calculate speedup vs PyNccl if available + if "pynccl" in comm_results: + pynccl_time = comm_results["pynccl"] + speedup = pynccl_time / fastest_time + return f"{fastest_comm} ({speedup:.2f}x)" + else: + return f"{fastest_comm} (N/A)" + + +def print_results( + results: dict[str, dict[str, float]], sequence_lengths: list[int], world_size: int +): + """Print benchmark results in a formatted table.""" + + print(f"\n{'=' * 130}") + print("Device Communicator Benchmark Results") + print( + f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, " + f"Hidden Size: {HIDDEN_SIZE}" + ) + print(f"{'=' * 130}") + + # Get all communicator names + all_comms = set() + for size_results in results.values(): + all_comms.update(size_results.keys()) + + all_comms = sorted(list(all_comms)) + + # Print header + header = f"{'Tensor Shape':<20}{'Tensor Size':<15}" + for comm in all_comms: + header += f"{comm:<20}" + header += f"{'Best (Speedup vs PyNccl)':<30}" + print(header) + print("-" * len(header)) + + # Print results for each sequence length + for seq_len in sequence_lengths: + if seq_len in results: + # Calculate tensor size in elements and bytes + tensor_elements = seq_len * HIDDEN_SIZE + tensor_bytes = tensor_elements * BENCHMARK_DTYPE.itemsize + + # Format tensor size (MB) + tensor_size_mb = tensor_bytes / (1024 * 1024) + tensor_size_str = f"{tensor_size_mb:.2f} MB" + + # Format tensor shape + tensor_shape = f"({seq_len}, {HIDDEN_SIZE})" + + row = f"{tensor_shape:<20}{tensor_size_str:<15}" + for comm in all_comms: + if comm in results[seq_len]: + row += f"{results[seq_len][comm]:<20.3f}" + else: + row += f"{'N/A':<20}" + + # Calculate speedup information + speedup_info = _calculate_speedup_info(results[seq_len]) + row += f"{speedup_info:<30}" + + print(row) + + print(f"{'=' * 130}") + print("All times are in milliseconds (ms) per allreduce operation") + print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)") + + +def main(): + parser = FlexibleArgumentParser(description="Benchmark device communicators") + + parser.add_argument( + "--sequence-lengths", + type=int, + nargs="+", + default=DEFAULT_SEQUENCE_LENGTHS, + help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)", + ) + + parser.add_argument( + "--num-warmup", type=int, default=5, help="Number of warmup iterations" + ) + + parser.add_argument( + "--num-trials", type=int, default=50, help="Number of benchmark trials" + ) + + parser.add_argument("--output-json", type=str, help="Output results to JSON file") + + args = parser.parse_args() + + # Initialize distributed + if not dist.is_initialized(): + dist.init_process_group(backend="gloo") + 
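+    # The default group only needs gloo (CPU) for coordination; the GPU
+    # collectives under test bring their own transports (PyNccl, custom
+    # CUDA kernels, symmetric memory), so no NCCL default group is set up.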
rank = dist.get_rank() + world_size = dist.get_world_size() + + # Set device + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + # Get CPU process group + cpu_group = dist.new_group(backend="gloo") + + # Disable USE_SYMM_MEM to avoid affecting the max_sizes + # in symm_mem and custom_all_reduce for benchmark + os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" + + # Initialize benchmark + benchmark = CommunicatorBenchmark( + rank, world_size, device, cpu_group, args.sequence_lengths + ) + + # Run benchmarks + all_results = {} + + for seq_len in args.sequence_lengths: + if rank == 0: + logger.info( + "Benchmarking sequence length: %s (tensor shape: %s x %s)", + seq_len, + seq_len, + HIDDEN_SIZE, + ) + + results = benchmark.benchmark_allreduce( + sequence_length=seq_len, + num_warmup=args.num_warmup, + num_trials=args.num_trials, + ) + + all_results[seq_len] = results + + # Synchronize between ranks + dist.barrier() + + # Print results (only rank 0) + if rank == 0: + print_results(all_results, args.sequence_lengths, world_size) + + # Save to JSON if requested + if args.output_json: + # Add speedup information to results + enhanced_results = {} + for seq_len, comm_results in all_results.items(): + enhanced_results[seq_len] = { + "timings": comm_results, + "speedup_info": _calculate_speedup_info(comm_results), + } + + output_data = { + "world_size": world_size, + "dtype": str(BENCHMARK_DTYPE), + "hidden_size": HIDDEN_SIZE, + "sequence_lengths": args.sequence_lengths, + "num_warmup": args.num_warmup, + "num_trials": args.num_trials, + "cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES, + "results": enhanced_results, + } + + with open(args.output_json, "w") as f: + json.dump(output_data, f, indent=2) + + logger.info("Results saved to %s", args.output_json) + + # Cleanup + if cpu_group != dist.group.WORLD: + dist.destroy_process_group(cpu_group) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 6259aa0dd6290..94f3f1ae11f27 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -594,7 +594,11 @@ def main(args: argparse.Namespace): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size - elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"): + elif config.architectures[0] in ( + "Qwen2MoeForCausalLM", + "Qwen3MoeForCausalLM", + "Qwen3NextForCausalLM", + ): E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size diff --git a/benchmarks/kernels/benchmark_polynorm.py b/benchmarks/kernels/benchmark_polynorm.py new file mode 100644 index 0000000000000..9ac8f5e6594e4 --- /dev/null +++ b/benchmarks/kernels/benchmark_polynorm.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools + +import torch + +from vllm import _custom_ops as vllm_ops +from vllm.triton_utils import triton + + +def polynorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + + def norm(x, eps: float): + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + x = x.float() + return ( + ( + weight[0] * norm(x**3, eps) + + weight[1] * norm(x**2, eps) + + weight[2] * norm(x, eps) + + bias + ) + .to(weight.dtype) + .view(orig_shape) + ) + + +def polynorm_vllm( + x: 
torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + + out = torch.empty_like(x) + vllm_ops.poly_norm(out, x, weight, bias, eps) + output = out + + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_dim): + dtype = torch.bfloat16 + x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda") + weight = torch.ones(3, dtype=dtype, device="cuda") + bias = torch.ones(1, dtype=dtype, device="cuda") + + output_naive = polynorm_naive(x, weight, bias) + output_vllm = polynorm_vllm(x, weight, bias) + + if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [2**i for i in range(0, 7, 2)] +seq_length_range = [2**i for i in range(6, 11, 1)] +dim_range = [2048, 4096] +configs = list(itertools.product(dim_range, batch_size_range, seq_length_range)) + + +def get_benchmark(): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["dim", "batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["naive", "vllm"], + line_names=["Naive", "vLLM"], + styles=[("blue", "-"), ("red", "-")], + ylabel="us", + plot_name="polynorm-perf", + args={}, + ) + ) + def benchmark(dim, batch_size, seq_len, provider): + dtype = torch.bfloat16 + hidden_dim = dim * 4 + + x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda") + weight = torch.ones(3, dtype=dtype, device="cuda") + bias = torch.ones(1, dtype=dtype, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + + if provider == "naive": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: polynorm_naive(x, weight, bias), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: polynorm_vllm(x, weight, bias), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-dim", + type=int, + default=8192, + help="Intermediate size of MLP", + ) + parser.add_argument( + "--save-path", + type=str, + default="./configs/polynorm/", + help="Path to save polynorm benchmark results", + ) + + args = parser.parse_args() + + # Run correctness test + calculate_diff( + batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_dim=args.hidden_dim, + ) + + benchmark = get_benchmark() + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py index 0650cbf3cc18e..c7a4066b39d70 100644 --- a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -1,77 +1,675 @@ -#!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time +from collections.abc import Callable +import matplotlib.pyplot as plt +import numpy as np import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm, + silu_mul_fp8_quant_deep_gemm_cuda, ) from vllm.platforms import
current_platform +from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used -def benchmark(E, T, H, G=128, runs=50): - current_platform.seed_everything(42) - y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") - tokens_per_expert = torch.randint( - T // 2, T, size=(E,), dtype=torch.int32, device="cuda" +@triton.jit +def _silu_mul_fp8_quant_deep_gemm( + # Pointers ------------------------------------------------------------ + input_ptr, # 16-bit activations (E, T, 2*H) + y_q_ptr, # fp8 quantized activations (E, T, H) + y_s_ptr, # float32 scales (E, T, G) + counts_ptr, # int32 num tokens per expert (E) + # Sizes --------------------------------------------------------------- + H: tl.constexpr, # hidden dimension (per output) + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + # Strides for input (elements) --------------------------------------- + stride_i_e, + stride_i_t, + stride_i_h, + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + # Stride for counts (elements) + stride_counts_e, + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + use_ue8m0: tl.constexpr, + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, + NUM_STAGES: tl.constexpr, +): + G = H // GROUP_SIZE + + # map program id -> (e, g) + pid = tl.program_id(0) + e = pid // G + g = pid % G + + e = e.to(tl.int64) + g = g.to(tl.int64) + + # number of valid tokens for this expert + n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) + + cols = tl.arange(0, BLOCK).to(tl.int64) + mask = cols < BLOCK + + base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h + base_gate_offset = base_input_offset + cols * stride_i_h + base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h + base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h + base_ys_offset = e * stride_ys_e + g * stride_ys_g + + for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): + gate = tl.load( + input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0 + ).to(tl.float32) + up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0) + + gate = gate * (1.0 / (1.0 + tl.exp(-gate))) + y = gate * up + + y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max + if use_ue8m0: + y_s = tl.exp2(tl.ceil(tl.log2(y_s))) + + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) + + +def silu_mul_fp8_quant_deep_gemm_triton( + y: torch.Tensor, # (E, T, 2*H) + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + num_parallel_tokens, + group_size: int = 128, + eps: float = 1e-10, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales + + y has shape (E, T, 2*H). The first half of the last dimension is + silu-activated, multiplied by the second half, then quantized into FP8.
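+ For one group of `group_size` values y_g, the scale is + s = max(max(|y_g|), eps) / fp8_max, rounded up to a power of two when + UE8M0 scales are in use, and the stored value is + clamp(y_g / s, fp8_min, fp8_max) cast to e4m3, mirroring the kernel above.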
+ + Returns `(y_q, y_s)` where + * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] + * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) + """ + assert y.ndim == 3, "y must be (E, T, 2*H)" + E, T, H2 = y.shape + assert H2 % 2 == 0, "last dim of y must be even (2*H)" + H = H2 // 2 + G = (H + group_size - 1) // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, ( + "tokens_per_expert must be shape (E,)" + ) + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) + + # allocate outputs + fp8_dtype = torch.float8_e4m3fn + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided( + (E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device, ) + stride_cnt_e = tokens_per_expert.stride()[0] + + # Static grid over experts and H-groups. + # A loop inside the kernel handles the token dim + grid = (E * G,) + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + is_deep_gemm_e8m0_used(), + BLOCK=group_size, + NUM_STAGES=4, + num_warps=1, + ) + + return y_q, y_s + + +# Parse generation strategies +strategies = ["uniform", "max_t", "first_t"] + + +def benchmark( + kernel: Callable, + E: int, + T: int, + H: int, + total_tokens: int, + num_parallel_tokens: int = 64, + G: int = 128, + runs: int = 200, + num_warmups: int = 20, + gen_strategy: str = "default", + iterations_per_run: int = 20, +): + def generate_data(seed_offset=0): + """Generate input data with given seed offset""" + current_platform.seed_everything(42 + seed_offset) + y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous() + + if gen_strategy == "uniform": + r = torch.rand(size=(E,), device="cuda") + r /= r.sum() + r *= total_tokens + tokens_per_expert = r.int() + tokens_per_expert = torch.minimum( + tokens_per_expert, + torch.ones((E,), device=r.device, dtype=torch.int) * T, + ) + elif gen_strategy == "max_t": + tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda") + tokens_per_expert.fill_(total_tokens / E) + elif gen_strategy == "first_t": + tokens_per_expert = torch.zeros(size=(E,), dtype=torch.int32, device="cuda") + tokens_per_expert[0] = min(T, total_tokens) + else: + raise ValueError(f"Unknown generation strategy: {gen_strategy}") + return y, tokens_per_expert + + dataset_count = 4 + # Pre-generate different input matrices for each iteration to avoid cache effects + data_sets = [generate_data(i) for i in range(dataset_count)] + # Warmup - for _ in range(10): - silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) - torch.cuda.synchronize() + y, tokens_per_expert = data_sets[0] + for _ in range(num_warmups): + kernel( + y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G + ) + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = 
torch.cuda.Event(enable_timing=True) # Benchmark - torch.cuda.synchronize() - start = time.perf_counter() + latencies: list[float] = [] for _ in range(runs): - silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) - torch.cuda.synchronize() + torch.cuda.synchronize() - avg_time = (time.perf_counter() - start) / runs * 1000 + start_event.record() + for i in range(iterations_per_run): + y, tokens_per_expert = data_sets[i % dataset_count] + kernel( + y, + tokens_per_expert, + num_parallel_tokens=num_parallel_tokens, + group_size=G, + ) + end_event.record() + end_event.synchronize() - # Calculate actual work done (only count valid tokens) + total_time_ms = start_event.elapsed_time(end_event) + per_iter_time_ms = total_time_ms / iterations_per_run + latencies.append(per_iter_time_ms) + + # Use median instead of average for better outlier handling + median_time_ms = np.median(latencies) + median_time_s = median_time_ms / 1000 + + # Calculate actual work done (using first dataset for consistency) + _, tokens_per_expert = data_sets[0] actual_tokens = tokens_per_expert.sum().item() actual_elements = actual_tokens * H # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops ops_per_element = 8 total_ops = actual_elements * ops_per_element - gflops = total_ops / (avg_time / 1000) / 1e9 + gflops = total_ops / median_time_s / 1e9 # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes) input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs output_bytes = actual_tokens * H * 1 # H fp8 outputs scale_bytes = actual_tokens * (H // G) * 4 # scales in float32 total_bytes = input_bytes + output_bytes + scale_bytes - memory_bw = total_bytes / (avg_time / 1000) / 1e9 + memory_bw = total_bytes / median_time_s / 1e9 - return avg_time, gflops, memory_bw + HOPPER_BANDWIDTH_TBPS = 3.35 + return ( + median_time_ms, + gflops, + memory_bw, + (memory_bw / (HOPPER_BANDWIDTH_TBPS * 1024)) * 100, + ) +def create_comparison_plot( + ratio, cuda_times, baseline_times, config_labels, strategy_name, id +): + """Create a comparison plot for a specific generation strategy""" + fig, ax = plt.subplots(1, 1, figsize=(16, 6)) + + # Configure x-axis positions + x = np.arange(len(config_labels)) + width = 0.35 + + # Execution Time plot (lower is better) + ax.bar( + x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue" + ) + ax.bar( + x + width / 2, + baseline_times, + width, + label="Baseline", + alpha=0.8, + color="orange", + ) + + # Add speedup labels over each bar pair + for i in range(len(x)): + speedup = ratio[i] + max_height = max(cuda_times[i], baseline_times[i]) + ax.text( + x[i], + max_height + max_height * 0.02, + f"{speedup:.2f}x", + ha="center", + va="bottom", + fontweight="bold", + fontsize=9, + ) + + ax.set_xlabel("Configuration") + ax.set_ylabel("% Utilization") + ax.set_title( + f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)" + ) + ax.set_xticks(x) + ax.set_xticklabels(config_labels, rotation=45, ha="right") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + return fig, ax + + +def create_combined_plot(all_results): + """Create a combined plot with all strategies in one PNG""" + num_strategies = len(all_results) + fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies)) + + if num_strategies == 1: + axes = [axes] + + for idx, ( + strategy_name, + ratio, + cuda_times, + baseline_times, + config_labels, + ) in enumerate(all_results): + ax = axes[idx] + + # Configure 
x-axis positions + x = np.arange(len(config_labels)) + width = 0.35 + + # Execution Time plot (lower is better) + ax.bar( + x - width / 2, + cuda_times, + width, + label="CUDA Kernel", + alpha=0.8, + color="blue", + ) + ax.bar( + x + width / 2, + baseline_times, + width, + label="Baseline", + alpha=0.8, + color="orange", + ) + + # Add speedup labels over each bar pair + for i in range(len(x)): + speedup = ratio[i] + max_height = max(cuda_times[i], baseline_times[i]) + ax.text( + x[i], + max_height + max_height * 0.02, + f"{speedup:.2f}x", + ha="center", + va="bottom", + fontweight="bold", + fontsize=9, + ) + + ax.set_xlabel("Configuration") + ax.set_ylabel("% Utilization") + ax.set_title( + f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)" + ) + ax.set_xticks(x) + ax.set_xticklabels(config_labels, rotation=45, ha="right") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + filename = "../../silu_bench/silu_benchmark_combined.png" + plt.savefig(filename, dpi=300, bbox_inches="tight") + plt.show() + + return filename + + +outer_dim = 7168 configs = [ - (8, 32, 1024), - (16, 64, 2048), - (32, 128, 4096), - # DeepSeekV3 Configs - (256, 16, 7168), - (256, 32, 7168), - (256, 64, 7168), - (256, 128, 7168), - (256, 256, 7168), - (256, 512, 7168), + (8, 1024, 7168), + # DeepSeekV3 Configs + (32, 1024, 7168), + # DeepSeekV3 Configs + (256, 1024, 7168), ] -print(f"GPU: {torch.cuda.get_device_name()}") -print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}") -print("-" * 50) +runs = 100 +num_warmups = 20 -for E, T, H in configs: - try: - time_ms, gflops, gbps = benchmark(E, T, H) - print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") - except Exception: - print(f"E={E:3d},T={T:4d},H={H:4d} FAILED") +strategy_descriptions = { + "uniform": "Uniform Random", + "max_t": "Even Assignment", + "first_t": "experts[0] = T, experts[1:] = 0", +} + +print(f"GPU: {torch.cuda.get_device_name()}") +print(f"Testing strategies: {', '.join(strategies)}") +print(f"Configurations: {len(configs)} configs") + +all_results = [] + +# Run benchmarks for each strategy +for id, strategy in enumerate(strategies): + print(f"\n{'=' * 60}") + print(f"Testing strategy: {strategy_descriptions[strategy]}") + print(f"{'=' * 60}") + + # Collect benchmark data for both algorithms + config_labels = [] + config_x_axis = [] + all_cuda_results = [] + all_baseline_results = [] + all_ratios = [] + + for E, T, H in configs: + total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E] + config_x_axis.append(total_tokens_config) + + cuda_results = [] + baseline_results = [] + ratios = [] + + for total_tokens in total_tokens_config: + config_label = f"E={E},T={T},H={H},TT={total_tokens}" + config_labels.append(config_label) + + # CUDA kernel results + time_ms_cuda, gflops, gbps, perc = benchmark( + silu_mul_fp8_quant_deep_gemm_cuda, + E, + T, + H, + total_tokens, + runs=runs, + num_warmups=num_warmups, + gen_strategy=strategy, + ) + cuda_results.append((time_ms_cuda, gflops, gbps, perc)) + + # Baseline results + time_ms_triton, gflops, gbps, perc = benchmark( + silu_mul_fp8_quant_deep_gemm_triton, + E, + T, + H, + total_tokens, + runs=runs, + num_warmups=num_warmups, + gen_strategy=strategy, + ) + baseline_results.append((time_ms_triton, gflops, gbps, perc)) + ratios.append(time_ms_triton / time_ms_cuda) + + print(f"Completed: {config_label}") + all_cuda_results.append(cuda_results) + all_baseline_results.append(baseline_results) + all_ratios.append(ratios) + + #
Store results for combined plotting + all_results.append( + ( + strategy_descriptions[strategy], + all_ratios, + all_cuda_results, + all_baseline_results, + config_labels, + config_x_axis, + ) + ) + + # Print summary table for this strategy + print(f"\nSummary Table - {strategy_descriptions[strategy]}:") + print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}") + print("-" * 60) + + for i, (E, T, H) in enumerate(configs): + speedup = baseline_results[i][0] / cuda_results[i][0] + config_label = f"E={E:3d},T={T:4d},H={H:4d}" + print( + f"{config_label:<20} {cuda_results[i][0]:8.5f} " + f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x" + ) + + +def create_total_tokens_plot(all_results): + num_strategies = len(all_results) + num_configs = len(configs) + + # Create side-by-side subplots: 2 columns for speedup and bandwidth percentage + fig, axs = plt.subplots( + num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies) + ) + + # Add main title to the entire figure + fig.suptitle( + "Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)", + fontsize=16, + fontweight="bold", + y=0.98, + ) + + # Handle single strategy case + if num_strategies == 1: + axs = axs.reshape(1, -1) + + # Handle single config case + if num_configs == 1: + axs = axs.reshape(-1, 2) + + for strategy_idx, result in enumerate(all_results): + ( + strategy_name, + all_ratios, + all_cuda_results, + all_baseline_results, + config_labels, + config_x_axis, + ) = result + + for config_idx in range(num_configs): + # Speedup plot (left column) + ax_speedup = axs[strategy_idx, config_idx * 2] + # Bandwidth plot (right column) + ax_bandwidth = axs[strategy_idx, config_idx * 2 + 1] + + E, T, H = configs[config_idx] + ratios = all_ratios[config_idx] + total_tokens_values = config_x_axis[config_idx] + + # Extract CUDA and Triton bandwidth percentages + cuda_bandwidth_percentages = [ + result[3] for result in all_cuda_results[config_idx] + ] + triton_bandwidth_percentages = [ + result[3] for result in all_baseline_results[config_idx] + ] + + # Plot speedup ratios vs total tokens (left plot) + ax_speedup.plot( + total_tokens_values, ratios, "bo-", linewidth=3, markersize=8 + ) + ax_speedup.set_title( + f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}", + fontsize=12, + fontweight="bold", + ) + ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) + ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11) + ax_speedup.grid(True, alpha=0.3) + + ax_bandwidth.plot( + total_tokens_values, + cuda_bandwidth_percentages, + "ro-", + linewidth=3, + markersize=8, + label="CUDA", + ) + ax_bandwidth.plot( + total_tokens_values, + triton_bandwidth_percentages, + "go-", + linewidth=3, + markersize=8, + label="Triton", + ) + ax_bandwidth.set_title( + f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}", + fontsize=12, + fontweight="bold", + ) + ax_bandwidth.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) + ax_bandwidth.set_ylabel( + "% of Peak Bandwidth", fontweight="bold", fontsize=11 + ) + ax_bandwidth.legend(prop={"weight": "bold"}) + ax_bandwidth.grid(True, alpha=0.3) + + # Format x-axis labels for both plots + for ax in [ax_speedup, ax_bandwidth]: + ax.set_xticks(total_tokens_values) + ax.set_xticklabels( + [ + f"{tt // 1000}K" if tt >= 1000 else str(tt) + for tt in total_tokens_values + ], + fontweight="bold", + ) + # Make tick labels bold + for label in ax.get_xticklabels() + ax.get_yticklabels(): + label.set_fontweight("bold") 
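For reference, the utilization percentages plotted here follow the byte accounting in benchmark() above: bf16 reads, fp8 writes, and one float32 scale per group, measured against the Hopper peak constant. A self-contained sketch of that arithmetic (the function name and signature are illustrative; 3.35 mirrors HOPPER_BANDWIDTH_TBPS defined with the benchmark):

def bandwidth_utilization(tokens: int, H: int, G: int, time_ms: float,
                          peak_tbps: float = 3.35) -> tuple[float, float]:
    # Bytes moved per call: 2*H bf16 inputs read, H fp8 outputs written,
    # and one float32 scale per G-sized group, per valid token.
    total_bytes = tokens * (2 * H * 2 + H * 1 + (H // G) * 4)
    gbps = total_bytes / (time_ms / 1e3) / 1e9
    return gbps, gbps / (peak_tbps * 1024) * 100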
+ + # Add value labels on speedup points + for x, y in zip(total_tokens_values, ratios): + ax_speedup.annotate( + f"{y:.2f}x", + (x, y), + textcoords="offset points", + xytext=(0, 12), + ha="center", + fontsize=10, + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7), + ) + + # Add value labels on CUDA bandwidth points + for x, y in zip(total_tokens_values, cuda_bandwidth_percentages): + ax_bandwidth.annotate( + f"{y:.1f}%", + (x, y), + textcoords="offset points", + xytext=(0, 12), + ha="center", + fontsize=9, + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3), + ) + + # Add value labels on Triton bandwidth points + for x, y in zip(total_tokens_values, triton_bandwidth_percentages): + ax_bandwidth.annotate( + f"{y:.1f}%", + (x, y), + textcoords="offset points", + xytext=(0, -15), + ha="center", + fontsize=9, + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.2", facecolor="green", alpha=0.3), + ) + + plt.tight_layout() + plt.subplots_adjust(top=0.93) # Make room for main title + filename = "silu_benchmark_total_tokens.png" + plt.savefig(filename, dpi=300, bbox_inches="tight") + plt.show() + + return filename + + +# Create combined plot with all strategies +combined_plot_filename = create_total_tokens_plot(all_results) + +print(f"\n{'=' * 60}") +print("Benchmark Complete!") +print(f"Generated combined plot: {combined_plot_filename}") +print(f"{'=' * 60}") diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 603ce5ecf0d2c..6ddab46214577 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -259,6 +259,7 @@ if __name__ == "__main__": # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 40903c6c3444f..131df74c7de1b 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -274,6 +274,7 @@ if __name__ == "__main__": quant_dtypes = [ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), + (FP8_DTYPE, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 98bde9d83c82d..df2b713e46dc4 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -56,7 +56,7 @@ def w8a8_block_matmul( Bs: The per-block quantization scale for `B`. block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. - output_dytpe: The dtype of the returned tensor. + output_dtype: The dtype of the returned tensor. Returns: torch.Tensor: The result of matmul. diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index 820bf81dd1a02..d1874515cc8fd 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -36,12 +36,14 @@ limitations under the License. 
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 void sm100_cutlass_mla_decode( torch::Tensor const& out, + torch::Tensor const& lse, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, torch::Tensor const& workspace, + double sm_scale, int64_t num_kv_splits) { TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode"); } @@ -99,6 +101,7 @@ struct MlaSm100 { template typename T::Fmha::Arguments args_from_options( at::Tensor const& out, + at::Tensor const& lse, at::Tensor const& q_nope, at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, @@ -162,7 +165,10 @@ typename T::Fmha::Arguments args_from_options( stride_PT, page_count_total, page_size}, - {static_cast(out.data_ptr()), stride_O, static_cast(nullptr), stride_LSE}, + {static_cast(out.data_ptr()), + stride_O, + static_cast(lse.defined() ? lse.data_ptr() : nullptr), + stride_LSE}, hw_info, // TODO(trevor-m): Change split_kv back to -1 when // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will @@ -181,6 +187,7 @@ typename T::Fmha::Arguments args_from_options( template void runMla( at::Tensor const& out, + at::Tensor const& lse, at::Tensor const& q_nope, at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, @@ -192,7 +199,7 @@ void runMla( cudaStream_t stream) { using MlaSm100Type = MlaSm100; typename MlaSm100Type::Fmha fmha; - auto arguments = args_from_options(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); + auto arguments = args_from_options(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); CUTLASS_CHECK(fmha.can_implement(arguments)); @@ -214,6 +221,7 @@ void runMla( void sm100_cutlass_mla_decode( torch::Tensor const& out, + torch::Tensor const& lse, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, @@ -234,13 +242,13 @@ void sm100_cutlass_mla_decode( DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { if (in_dtype == at::ScalarType::Half) { runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::BFloat16) { runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else { TORCH_CHECK(false, "Unsupported input data type of MLA"); } diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp index ab8cbbbf4ec4f..51bca37e699b9 100644 --- a/csrc/cpu/cpu_types_vxe.hpp +++ b/csrc/cpu/cpu_types_vxe.hpp @@ -12,7 +12,7 @@ namespace vec_op { #define vec_sub(a, b) ((a) - (b)) #define vec_mul(a, b) ((a) * (b)) #define vec_div(a, b) ((a) / (b)) -#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebaic +#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic #define vec_sl(a, b) ((a) << (b)) // Vector Shift Left // FIXME: FP16 is not fully supported in 
Torch-CPU diff --git a/csrc/cpu/sgl-kernels/moe.cpp b/csrc/cpu/sgl-kernels/moe.cpp index beeccff783ea0..94b24c2f13a06 100644 --- a/csrc/cpu/sgl-kernels/moe.cpp +++ b/csrc/cpu/sgl-kernels/moe.cpp @@ -215,7 +215,7 @@ int moe_align_block_size( offsets[mb + 1] = sorted_id_size(sorted_ids + mb * BLOCK_M); } }); - // TODO: do we need to vecterize this ? + // TODO: do we need to vectorize this ? for (int mb = 0; mb < num_token_blocks; ++mb) { offsets[mb + 1] += offsets[mb]; } diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 44709b4597765..58926f6429dd3 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -15,6 +15,8 @@ typedef __hip_bfloat16 nv_bfloat16; #include #include #include +#include +#include namespace vllm { #define CUDACHECK(cmd) \ @@ -555,22 +557,47 @@ class CustomAllreduce { size /= d; auto bytes = size * sizeof(typename packed_t::P); int blocks = std::min(block_limit, (size + threads - 1) / threads); + + // Check environment variable once + const char* env_algo = std::getenv("VLLM_CUSTOM_ALLREDUCE_ALGO"); + bool force_1stage = false; + bool force_2stage = false; + if (env_algo != nullptr) { + if (std::strcmp(env_algo, "1stage") == 0 || + std::strcmp(env_algo, "oneshot") == 0) { + force_1stage = true; + } else if (std::strcmp(env_algo, "2stage") == 0 || + std::strcmp(env_algo, "twoshot") == 0) { + force_2stage = true; + } else { + throw std::runtime_error( + "Invalid VLLM_CUSTOM_ALLREDUCE_ALGO: " + std::string(env_algo) + + ". Valid values: 1stage, oneshot, 2stage, twoshot"); + } + } + #define KL(ngpus, name) \ name<<>>(ptrs, sg_, self_sg_, output, \ rank_, size); -#define REDUCE_CASE(ngpus) \ - case ngpus: { \ - if (world_size_ == 2) { \ - KL(ngpus, cross_device_reduce_1stage); \ - } else if (fully_connected_) { \ - if ((world_size_ <= 4 && bytes < 512 * 1024) || \ - (world_size_ <= 8 && bytes < 256 * 1024)) { \ - KL(ngpus, cross_device_reduce_1stage); \ - } else { \ - KL(ngpus, cross_device_reduce_2stage); \ - } \ - } \ - break; \ +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (force_1stage) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (force_2stage) { \ + KL(ngpus, cross_device_reduce_2stage); \ + } else { \ + if (world_size_ == 2) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (fully_connected_) { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || \ + (world_size_ <= 8 && bytes < 256 * 1024)) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else { \ + KL(ngpus, cross_device_reduce_2stage); \ + } \ + } \ + } \ + break; \ } switch (world_size_) { diff --git a/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp b/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp deleted file mode 100644 index ec75c29e54f4d..0000000000000 --- a/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp +++ /dev/null @@ -1,123 +0,0 @@ -// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl -// clang-format off -#pragma once - -#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" - -#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA_TMA_WS_SS (BlockScaled Builders) -template < - class ElementA, - class GmemLayoutATag, - int AlignmentA, - class ElementB, - 
class GmemLayoutBTag, - int AlignmentB, - class ElementAccumulator, - class TileShape_MNK, - class ClusterShape_MNK, - class StageCountType, - int ScaleGranularityM -> -struct CollectiveBuilder< - arch::Sm90, - arch::OpClassTensorOp, - ElementA, - GmemLayoutATag, - AlignmentA, - ElementB, - GmemLayoutBTag, - AlignmentB, - ElementAccumulator, - TileShape_MNK, - ClusterShape_MNK, - StageCountType, - KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum, - cute::enable_if_t< - not detail::is_use_rmem_A()> -> { - using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; - - static_assert(is_static::value); - static_assert(is_static::value); -#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); -#endif - static_assert(detail::is_aligned(), - "Should meet TMA alignment requirement\n"); - - static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); - static constexpr bool IsFP8Input = detail::is_input_fp8(); - static_assert((!IsFP8Input || !IsArrayOfPointersGemm), - "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now."); - - // For fp32 types, map to tf32 MMA value type - using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; - using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; - - static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); - static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - - static constexpr bool IsCooperative = cute::is_any_of_v>; - using AtomLayoutMNK = cute::conditional_t>, Layout>>; - - using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< - ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); - - using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); - using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); - - using SmemLayoutAtomA = decltype(detail::ss_smem_selector< - GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomB = decltype(detail::ss_smem_selector< - GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - - static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? 
sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; - static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); - - static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; - - using SmemCopyAtomA = void; - using SmemCopyAtomB = void; - - using CollectiveOp = CollectiveMma< - DispatchPolicy, - TileShape_MNK, - ElementA, - TagToStrideA_t, - ElementB, - TagToStrideB_t, - TiledMma, - GmemTiledCopyA, - SmemLayoutAtomA, - SmemCopyAtomA, - cute::identity, - GmemTiledCopyB, - SmemLayoutAtomB, - SmemCopyAtomB, - cute::identity - >; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp b/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp deleted file mode 100644 index 13b90e998625e..0000000000000 --- a/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp +++ /dev/null @@ -1,183 +0,0 @@ -// clang-format off -// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp - -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ - -#pragma once - -#include "cute/algorithm/clear.hpp" -#include "cute/tensor.hpp" - -////////////////////////////////////////////////////////////////////////////// -///////////////////////////////////FP8 Accumulation/////////////////////////// -////////////////////////////////////////////////////////////////////////////// -/// This class provides API to promote (add) or scale (multiply_add) the results -/// from the tensor core accumulators to the main accumulators when the number -/// of MMAs reaches the max number of MMA interval specified by user, after that -/// the tensor core accumulators are zeroed. -////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { - -template < - class EngineAccum, - class LayoutAccum> -struct GmmaFP8AccumulationWithScale { - using TensorAccum = cute::Tensor; - using ElementAccumulator = typename EngineAccum::value_type; - - static_assert(is_static::value, "Accumulator Layout should be static"); - static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); - -private: - TensorAccum& accum_; - TensorAccum accum_temp_; - - uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. - uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop - uint32_t mma_count_; // current executed MMAs - uint32_t reset_accum_flag_; // accum needs to be zeroed or not. - - // promote or `add` the partial accumulators to main accumulator (FADD). - CUTLASS_DEVICE - void promote_core() { - warpgroup_wait<0>(); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accum_); ++i) { - accum_(i) += accum_temp_(i); - } - } - - // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_core(const cute::Tensor &scale) { - using TensorScale = cute::Tensor; - - static_assert(is_static::value, "Scale Layout should be static"); - static_assert(is_rmem::value , "Scale tensor must be rmem resident."); - - static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); - - warpgroup_wait<0>(); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accum_); ++i) { - accum_(i) += accum_temp_(i) * scale(i); - } - } - -public: - CUTLASS_DEVICE - GmmaFP8AccumulationWithScale( - TensorAccum &accum, - uint32_t accum_promotion_interval, - uint32_t mma_count_per_mainloop_iteration) - : accum_(accum), - accum_promotion_interval_(accum_promotion_interval), - mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), - mma_count_(0), - reset_accum_flag_(0) - { - accum_temp_ = cute::make_fragment_like(accum); - } - - // - // Methods (Common) - // - - CUTLASS_DEVICE - TensorAccum& operator()() { - return accum_temp_; - } - - /// prepare the MMA accumulators when initialization or zeroing is required. - CUTLASS_DEVICE - bool prepare_if_needed() { - return reset_accum_flag_; - } - - // - // Methods (for FADD version) - // - - /// promote (add) the results from the MMA accumulators to main accumulator if needed. 
- CUTLASS_DEVICE - void promote_if_needed() { - mma_count_ += mma_count_per_mainloop_iteration_; - reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); - if (reset_accum_flag_) { - promote_core(); - mma_count_ = 0; - } - } - - /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. - CUTLASS_DEVICE - void promote_residue_if_needed() { - if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { - promote_core(); - } - } - - // - // Methods (for FFMA version) - // - - /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_if_needed(const cute::Tensor &scale) { - mma_count_ += mma_count_per_mainloop_iteration_; - reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); - if (reset_accum_flag_) { - scale_core(scale); - mma_count_ = 0; - } - } - - /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_residue_if_needed(const cute::Tensor &scale) { - if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { - scale_core(scale); - } - } -}; - -} // namespace cutlass::gemm::collective diff --git a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp deleted file mode 100644 index ce7f47cf72337..0000000000000 --- a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ /dev/null @@ -1,729 +0,0 @@ -// clang-format off -// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp - -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/trace.h" -#include "cutlass/numeric_types.h" - -#include "cute/arch/cluster_sm90.hpp" -#include "cute/arch/copy_sm80.hpp" -#include "cute/arch/copy_sm90.hpp" -#include "cute/algorithm/functional.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cute/algorithm/gemm.hpp" -#include "cute/numeric/arithmetic_tuple.hpp" - -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { -using namespace cute; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// WarpSpecialized Mainloop -template < - int Stages, - class ClusterShape, - class KernelSchedule, - int ScaleGranularityM_, - class TileShape_, - class ElementA_, - class StrideA_, - class ElementB_, - class StrideB_, - class TiledMma_, - class GmemTiledCopyA_, - class SmemLayoutAtomA_, - class SmemCopyAtomA_, - class TransformA_, - class GmemTiledCopyB_, - class SmemLayoutAtomB_, - class SmemCopyAtomB_, - class TransformB_> -struct CollectiveMma< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8, - TileShape_, - ElementA_, - StrideA_, - ElementB_, - StrideB_, - TiledMma_, - GmemTiledCopyA_, - SmemLayoutAtomA_, - SmemCopyAtomA_, - TransformA_, - GmemTiledCopyB_, - SmemLayoutAtomB_, - SmemCopyAtomB_, - TransformB_> -{ - // - // Type Aliases - // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; - using TileShape = TileShape_; - using ElementA = ElementA_; - using StrideA = StrideA_; - using ElementB = ElementB_; - using StrideB = StrideB_; - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - using ElementBlockScale = ElementAccumulator; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - - using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); - using MainloopPipeline = cutlass::PipelineTmaAsync; - using PipelineState = cutlass::PipelineState; - using PipelineParams = typename MainloopPipeline::Params; - - // Two threads per CTA are producers (1 for operand tile and 32 for scales) - static constexpr int NumProducerThreadEvents = 33; - - static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? 
size<0>(TileShape{}) : ScaleGranularityM_; - static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; - - static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); - - // Tile along modes in a way that maximizes the TMA box size. - using SmemLayoutA = decltype(tile_to_shape( - SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); - using SmemLayoutB = decltype(tile_to_shape( - SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); - - // Block scaling gmem-to-smem copy atom - using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; - using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; - - // Block scaling smem layout - using SmemLayoutScaleA = Layout, Int>>; - using SmemLayoutScaleB = Layout>, Stride<_1>>; // `ScaleNsPerTile` is always 1. 
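A Python rendering of the constexpr scale-granularity logic above may help: ScaleGranularityM_ == 0 collapses to one scale per tile, while smaller granularities split the tile's M extent into ScaleMsPerTile groups (a sketch, not part of the deleted source):

def scale_ms_per_tile(blk_m: int, scale_granularity_m: int) -> int:
    # ScaleGranularityM_ == 0 means "one scale covers the full tile";
    # otherwise the tile is split into blk_m / ScaleGranularityM groups.
    g = blk_m if scale_granularity_m == 0 else scale_granularity_m
    assert blk_m % g == 0, "granularity must evenly divide the tile M extent"
    return blk_m // g

assert scale_ms_per_tile(128, 0) == 1    # per-tile scaling
assert scale_ms_per_tile(128, 1) == 128  # per-row scaling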
- - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(cute::is_base_of::value && - cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - static_assert(cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(cute::is_same_v, - "ElementAccumulator and ElementBlockScale should be same datatype"); - - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128> { - cute::array_aligned> smem_A; // mxk - cute::array_aligned> smem_B; // nxk - cute::array_aligned> smem_scale_A; // ScaleMsPerTile x k - cute::array_aligned> smem_scale_B; // 1xk - } tensors; - - using PipelineStorage = typename MainloopPipeline::SharedStorage; - PipelineStorage pipeline; - }; - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - // Host side kernel arguments - struct Arguments { - ElementA const* ptr_A; - StrideA dA; - ElementB const* ptr_B; - StrideB dB; - ElementBlockScale const* ptr_scale_A; - ElementBlockScale const* ptr_scale_B; - }; - - // Device side kernel params - struct Params { - // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy_A_sm90( - GmemTiledCopyA{}, - make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), - SmemLayoutA{}(_,_,0), - TileShape{}, - ClusterShape{})); - // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy_B_sm90( - GmemTiledCopyB{}, - make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), - SmemLayoutB{}(_,_,0), - TileShape{}, - ClusterShape{})); - TMA_A tma_load_a; - TMA_B tma_load_b; - uint32_t tma_transaction_bytes = TmaTransactionBytes; - uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; - uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; - // Block scaling factors for A and B - ElementBlockScale const* ptr_scale_A; - ElementBlockScale const* ptr_scale_B; - }; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - (void) workspace; - - // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - - auto ptr_A = reinterpret_cast(args.ptr_A); - auto ptr_B = reinterpret_cast(args.ptr_B); - - Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); - Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); - typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( - GmemTiledCopyA{}, - tensor_a, - SmemLayoutA{}(_,_,cute::Int<0>{}), - TileShape{}, - ClusterShape{}); - typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( - GmemTiledCopyB{}, - tensor_b, - SmemLayoutB{}(_,_,cute::Int<0>{}), - TileShape{}, - ClusterShape{}); - uint32_t transaction_bytes_mk = TmaTransactionBytesMK; - uint32_t transaction_bytes_nk = TmaTransactionBytesNK; - uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; - - return { - tma_load_a, - tma_load_b, - transaction_bytes, - transaction_bytes_mk, - transaction_bytes_nk, - args.ptr_scale_A, - args.ptr_scale_B - }; - } - - 
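For orientation, the math this deleted mainloop accelerated is blockwise-dequantized GEMM: each FP8 fragment of A and B is multiplied by its per-block scale during fp32 accumulation. A reference sketch in PyTorch, assuming the conventional 128-wide block layout (the exact scale layouts are an assumption here):

import torch

def blockwise_fp8_gemm_ref(A_q: torch.Tensor, As: torch.Tensor,
                           B_q: torch.Tensor, Bs: torch.Tensor,
                           block: int = 128) -> torch.Tensor:
    # A_q: (M, K) fp8 activations, As: (M, K // block) float32 scales
    # B_q: (N, K) fp8 weights,     Bs: (N // block, K // block) float32 scales
    A = A_q.float() * As.repeat_interleave(block, dim=1)
    B = B_q.float() * Bs.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return A @ B.t()  # (M, N), accumulated in fp32 like ElementAccumulator above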
template - static bool - can_implement( - ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - constexpr int tma_alignment_bits = 128; - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - - bool implementable = true; - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - } - return implementable; - } - - static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr int K_PIPE_MMAS = 1; - static constexpr uint32_t TmaTransactionBytesMK = - cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); - static constexpr uint32_t TmaTransactionBytesNK = - cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); - static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - } - - /// Set up the data needed by this collective for load and mma. - /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract - /// Returned tuple must contain at least two elements, with the first two elements being: - /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) - /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) - template - CUTLASS_DEVICE auto - load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { - using X = Underscore; - // Separate out problem shape for convenience - auto [M,N,K,L] = problem_shape_MNKL; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) - Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) - - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) - - constexpr auto scales_m = Int{}; - auto tM = get<2>(gA_mkl.shape()); - auto tN = get<2>(gB_nkl.shape()); - auto tK = get<3>(gA_mkl.shape()); - - // Make the tiled views of scale tensors - auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) - auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); - auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) - auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{}); - - // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and - // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. 
- Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l) - Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) - - return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Producer Perspective - template < - class TensorA, class TensorB, - class TensorScaleA, class TensorScaleB, - class KTileIterator, class BlockCoord - > - CUTLASS_DEVICE void - load( - Params const& mainloop_params, - MainloopPipeline pipeline, - PipelineState smem_pipe_write, - cute::tuple const& load_inputs, - BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, - uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) { - int lane_predicate = cute::elect_one_sync(); - - // Blockscaling: Tma loads for load_input and CpAsync for load_scale - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) - - // - // Prepare the TMA loads for A and B - // - - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); - - // Partition the inputs based on the current block coordinates. 
- auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - - // Block scaling: load_scale has scaling tensors in global memory which are not tiled - Tensor mScaleA_mkl = get<2>(load_inputs); - Tensor mScaleB_nkl = get<3>(load_inputs); - auto scales_m = get<0>(mScaleA_mkl.shape()); - - Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape()); - - Tensor gScaleA = local_tile( - mScaleA_mkl, make_tile(Int{}), - make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) - Tensor cScaleA = local_tile( - cScaleA_mkl, make_tile(Int{}), - make_coord(m_coord,_,l_coord)); - Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) - - // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128 - TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, - Layout>{}, Layout>{}); // (1,1,1) - TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, - Layout>{}, Layout>{}); // (1,1,1) - ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); - ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); - - Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); - Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA); - Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); - - Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); - Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); - - // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) - - uint16_t mcast_mask_a = 0; - uint16_t mcast_mask_b = 0; - - // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors - // Maps the tile -> block, value - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); - } - } - - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); - } - } - - // Allocate predicate tensors for a_scales (since we can't guarantee that - // all scales are valid, since we could have a partial tiles along M) - Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_,_,0))); - #pragma unroll - for (int i = 0; i < size(tApA_ScaleA); ++i) { - tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m; - } - - // Mainloop - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { - // LOCK smem_pipe_write for _writing_ - pipeline.producer_acquire(smem_pipe_write); - - // - // Copy gmem to smem for *k_tile_iter - // - int write_stage = smem_pipe_write.index(); - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); - - // Copy operands A and B from global memory to shared memory - if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); - - 
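// Aside: a minimal sketch of the multicast masks used in the two TMA copies
// above, assuming block_id = x + y * cluster_x, which matches the
// (m,n) -> block_id layouts built before this mainloop. For A (multicast
// across the cluster's N dimension), CTA (x, y) targets every CTA that
// shares its x coordinate:
//
//   uint16_t mcast_mask_for_A(int x, int cluster_x, int cluster_y) {
//     uint16_t mask = 0;
//     for (int n = 0; n < cluster_y; ++n) {
//       mask |= uint16_t(1) << (x + n * cluster_x);  // same column, every row
//     }
//     return mask;
//   }
//
// Each set bit names one CTA in the cluster that receives the same A tile,
// so a single TMA request feeds all CTAs sharing this block's M-tile.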
// Copy scale tensors from global memory to shared memory - copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); - copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); - pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); - - ++k_tile_iter; - - // Advance smem_pipe_write - ++smem_pipe_write; - } - } - - /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void - load_tail( - MainloopPipeline pipeline, - PipelineState smem_pipe_write) { - int lane_predicate = cute::elect_one_sync(); - - // Issue the epilogue waits - if (lane_predicate) { - /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all - * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was - * still inverted from make_producer_start_state - */ - pipeline.producer_tail(smem_pipe_write); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Consumer Perspective - template < - class FrgTensorC - > - CUTLASS_DEVICE void - mma(MainloopPipeline pipeline, - PipelineState smem_pipe_read, - FrgTensorC& accum, - int k_tile_count, - int thread_idx, - TensorStorage& shared_tensors, - Params const& mainloop_params) { - - - static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - - // Block scaling - Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), - Layout< - Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, - Stride, _0, Int> - >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) - - // - // Define C accumulators and A/B partitioning - // - - // Layout of warp group to thread mapping - - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and - stride<0>(typename TiledMma::BLayout{}) == 0 and - size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, - "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); - - constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, - Int{}); - - int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); - - TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); - - Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. 
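// Aside: the broadcast behind sScaleAViewAsC above, as a small standalone
// sketch (hypothetical shapes, not from this patch). Viewing an M-length
// scale vector through an (M, N) layout with stride (1, 0) repeats each
// per-row scale across all N columns at zero memory cost, so partition_C
// then yields exactly one scale per accumulator element:
//
//   auto v = make_tensor(scale_ptr,
//                        make_layout(make_shape(M, N),
//                                    make_stride(Int<1>{}, Int<0>{})));
//   // v(m, n) == scale_ptr[m] for every n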
- - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - - // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE - - // - // PIPELINED MAIN LOOP - // - static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), - "ERROR : Incorrect number of MMAs in flight"); - - // We release buffers to producer warps(dma load) with some mmas in flight - PipelineState smem_pipe_release = smem_pipe_read; - - // Per block scale values for operand A and B - - using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. - using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above - - Tensor tCrScaleAViewAsC = make_tensor(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) - ElementBlockScale scale_b; - - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - - GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA)); - warpgroup_fence_operand(accumulation()); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - if (accumulation.prepare_if_needed()) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - - int read_stage = smem_pipe_read.index(); - - // Load per block scale values from shared memory to registers. - scale_b = sScaleB[read_stage]; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); - } - if constexpr (ScaleMsPerTile == 1) { - static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
- } else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; - } - } - - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` - accumulation.scale_if_needed(tCrScaleAViewAsC); - - ++smem_pipe_read; - } - - warpgroup_fence_operand(accumulation()); - // Mainloop GMMAs - k_tile_count -= prologue_mma_count; - - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - // - // Compute on k_tile - // - - int read_stage = smem_pipe_read.index(); - - // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) - scale_b = sScaleB[read_stage]; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); - } - if constexpr (ScaleMsPerTile == 1) { - static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
- } else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; - } - } - - if (accumulation.prepare_if_needed()) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - - warpgroup_fence_operand(accumulation()); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed - warpgroup_wait(); - warpgroup_fence_operand(accumulation()); - - // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` - accumulation.scale_if_needed(tCrScaleAViewAsC); - - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - - // Advance smem_pipe_read and smem_pipe_release - ++smem_pipe_read; - ++smem_pipe_release; - } - - accumulation.scale_residue_if_needed(tCrScaleAViewAsC); - - warpgroup_fence_operand(accumulation()); - } - - /// Perform a Consumer Epilogue to release all buffers - CUTLASS_DEVICE void - mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - k_tile_count -= prologue_mma_count; - - smem_pipe_release.advance(k_tile_count); - - // Wait on all GMMAs to complete - warpgroup_wait<0>(); - - for (int count = 0; count < prologue_mma_count; ++count) { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - ++smem_pipe_release; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/csrc/cutlass_extensions/gemm/dispatch_policy.hpp deleted file mode 100644 index df809e27a3efe..0000000000000 --- a/csrc/cutlass_extensions/gemm/dispatch_policy.hpp +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include "cutlass/gemm/dispatch_policy.hpp" - -namespace cutlass::gemm { - -////////////////////////////////////////////////////////////////////////////// - -// FP8 related policies (including Blocked Scaled Accumulation) -// `ScaleGranularityM` specifies scaling granularity along M, while zero-value -// `ScaleGranularityM` indicates that scaling granularity is -// `size<0>(TileShape_MNK{})` along M. -template -struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum - : KernelTmaWarpSpecializedCooperative {}; - -// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp -// specialized dynamic schedule For FP8 kernels with Block Scaling -template , - class KernelSchedule = KernelTmaWarpSpecialized, - int ScaleGranularityM = - 0 // `ScaleGranularityM` specifies scaling granularity along M, - // while zero-value `ScaleGranularityM` indicates that scaling - // granularity is `size<0>(TileShape_MNK{})` along M. 
- > -struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8 - : MainloopSm90TmaGmmaWarpSpecialized { - static_assert( - cute::is_same_v< - KernelSchedule, - KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< - ScaleGranularityM>>, - "KernelSchedule must be one of the warp specialized policies"); -}; - -////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_collective_builder.cuh b/csrc/cutlass_extensions/vllm_collective_builder.cuh index e7fbba4cd4b0d..085ee1290031f 100644 --- a/csrc/cutlass_extensions/vllm_collective_builder.cuh +++ b/csrc/cutlass_extensions/vllm_collective_builder.cuh @@ -1,6 +1,6 @@ #pragma once -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" namespace cutlass::gemm::collective { using namespace cute; diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index f051eb0702228..05be023de0f28 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -140,6 +140,211 @@ fused_add_rms_norm_kernel( } } +/* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. + + _f16VecPN struct extends _f16Vec to add operations specifically required for + polynomial normalization (poly norm). + The original _f16Vec does not include the sum-of-powers computation or + in-place polynomial normalization logic. */ +template +struct alignas(16) _f16VecPN : _f16Vec { + using Base = _f16Vec; + using Converter = typename Base::Converter; + using T1 = typename Base::T1; + using T2 = typename Base::T2; + using Base::data; + + __device__ auto sum_pows() const { + float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f; + +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + float x2 = z.x * z.x; + float x4 = x2 * x2; + float x6 = x4 * x2; + + float y2 = z.y * z.y; + float y4 = y2 * y2; + float y6 = y4 * y2; + + s2 += x2 + y2; + s4 += x4 + y4; + s6 += x6 + y6; + } + return std::make_tuple(s2, s4, s6); + } + + __device__ void poly_norm_inplace(const float w2_inv_std, + const float w1_inv_std2, + const float w0_inv_std3, const float bias) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + + float x2 = z.x * z.x; + float x3 = x2 * z.x; + z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias; + + float y2 = z.y * z.y; + float y3 = y2 * z.y; + z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias; + + auto out = Converter::convert(z); + data[i] = out.x; + data[i + 1] = out.y; + } + } +}; + +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [3] + const scalar_t* __restrict__ bias, // [1] + const float epsilon, const int hidden_size) { + // Sanity checks on our vector struct and type-punned pointer arithmetic + static_assert(std::is_pod_v<_f16VecPN>); + static_assert(sizeof(_f16VecPN) == sizeof(scalar_t) * width); + + /* These and the argument pointers are all declared `restrict` as they are + not aliased in practice. 
Argument pointers should not be dereferenced + in this kernel as that would be undefined behavior */ + auto* __restrict__ input_v = + reinterpret_cast*>(input); + const int vec_hidden_size = hidden_size / width; + float variance = 0.0f; + float variance2 = 0.0f; + float variance3 = 0.0f; + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16VecPN temp = input_v[id]; + auto [x2, x4, x6] = temp.sum_pows(); + + variance += x2; + variance2 += x4; + variance3 += x6; + } + + float3 thread_variances = make_float3(variance, variance2, variance3); + + struct SumOp { + __device__ float3 operator()(const float3& a, const float3& b) const { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + } + }; + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + float3 block_variances = + BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x); + + variance = block_variances.x; + variance2 = block_variances.y; + variance3 = block_variances.z; + + __shared__ float s_w2_inv_std; + __shared__ float s_w1_inv_std2; + __shared__ float s_w0_inv_std3; + __shared__ float s_bias; + + if (threadIdx.x == 0) { + float w0 = (float)weight[0]; + float w1 = (float)weight[1]; + float w2 = (float)weight[2]; + s_bias = (float)bias[0]; + + s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon); + s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon); + s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon); + } + __syncthreads(); + + auto* __restrict__ out_v = reinterpret_cast<_f16VecPN*>(out); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16VecPN temp = input_v[id]; + temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias); + out_v[id] = temp; + } +} + +/* Generic poly_norm_kernel + The width field is not used here but necessary for other specializations. 
+ */ +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [3] + const scalar_t* __restrict__ bias, // [1] + const float epsilon, const int hidden_size) { + float variance = 0.0f; + float variance2 = 0.0f; + float variance3 = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + float x2 = x * x; + float x4 = x2 * x2; + float x6 = x4 * x2; + + variance += x2; + variance2 += x4; + variance3 += x6; + } + + float3 thread_variances = make_float3(variance, variance2, variance3); + + struct SumOp { + __device__ float3 operator()(const float3& a, const float3& b) const { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + } + }; + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + float3 block_variances = + BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x); + + variance = block_variances.x; + variance2 = block_variances.y; + variance3 = block_variances.z; + + __shared__ float s_w2_inv_std; + __shared__ float s_w1_inv_std2; + __shared__ float s_w0_inv_std3; + __shared__ float s_bias; + + if (threadIdx.x == 0) { + float w0 = (float)weight[0]; + float w1 = (float)weight[1]; + float w2 = (float)weight[2]; + s_bias = (float)bias[0]; + + s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon); + s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon); + s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + float x2 = x * x; + float x3 = x2 * x; + + out[blockIdx.x * hidden_size + idx] = + (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 + + s_bias); + } +} + } // namespace vllm void rms_norm(torch::Tensor& out, // [..., hidden_size] @@ -219,3 +424,49 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] LAUNCH_FUSED_ADD_RMS_NORM(0); } } + +#define LAUNCH_FUSED_POLY_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \ + vllm::poly_norm_kernel<<>>( \ + out.data_ptr(), input.data_ptr(), \ + weight.data_ptr(), bias.data_ptr(), epsilon, \ + hidden_size); \ + }); + +void poly_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [3] + torch::Tensor& bias, // [1] + double epsilon) { + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.data_ptr() != input.data_ptr()); + + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + /* This kernel is memory-latency bound in many scenarios. + When num_tokens is large, a smaller block size allows + for increased block occupancy on CUs and better latency + hiding on global mem ops. */ + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + /*If the tensor types are FP16/BF16, try to use the optimized kernel + with packed + vectorized ops. 
+ Max optimization is achieved with a width-8 vector of FP16/BF16s + since we can load at most 128 bits at once in a global memory op. + However, this requires each tensor's data to be aligned to 16 + bytes. + */ + auto inp_ptr = reinterpret_cast(input.data_ptr()); + auto out_ptr = reinterpret_cast(out.data_ptr()); + bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0; + if (ptrs_are_aligned && hidden_size % 8 == 0) { + LAUNCH_FUSED_POLY_NORM(8); + } else { + LAUNCH_FUSED_POLY_NORM(0); + } +} diff --git a/csrc/ops.h b/csrc/ops.h index a288112e21000..c65bf431640d5 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -92,6 +92,9 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + torch::Tensor& bias, double epsilon); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, @@ -119,12 +122,6 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); -void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - std::optional key, - int64_t head_size, torch::Tensor& cos_sin_cache, - bool is_neox, int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets); - void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, @@ -136,6 +133,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& input_global_scale); #endif +void silu_mul_fp8_quant_deep_gemm_cuda( + const at::Tensor& input, // (E, T, 2*H) + const at::Tensor& counts, // (E) + at::Tensor& y_q, // (E, T, H) [OUT] + at::Tensor& y_s, // (E, T, H//group_size) [OUT] + int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens); void mul_and_silu(torch::Tensor& out, torch::Tensor& input); @@ -353,4 +356,4 @@ void qr_open_handles(fptr_t _fa, const std::vector& handles); void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); int64_t qr_max_size(); -#endif \ No newline at end of file +#endif diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 266f2a0667a24..b5645b33b9073 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel( token_idx, query_stride, key_stride, head_stride); } -template -__global__ void batched_rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or - // [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, - // head_size] or [num_tokens, num_heads, - // head_size] - scalar_t* __restrict__ key, // nullptr or - // [batch_size, seq_len, num_kv_heads, - // head_size] or [num_tokens, num_kv_heads, - // head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // - // 2] - const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] - const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int64_t head_stride, const int num_heads, const int num_kv_heads, - const int head_size) { - // Each thread block is responsible for one token. 
- const int token_idx = blockIdx.x; - int64_t pos = positions[token_idx]; - int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; - const scalar_t* cache_ptr = - cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; - - apply_rotary_embedding( - query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride, head_stride); -} - } // namespace vllm void rotary_embedding( @@ -211,96 +182,3 @@ void rotary_embedding( } }); } - -/* -Batched version of rotary embedding, pack multiple LoRAs together -and process in batched manner. -*/ -void batched_rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] - std::optional - key, // null or - // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] - int64_t head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size] -) { - // num_tokens = batch_size * seq_len - int64_t num_tokens = cos_sin_cache_offsets.size(0); - TORCH_CHECK( - positions.size(0) == num_tokens || positions.numel() == num_tokens, - "positions must have the same num_tokens or batch_size as " - "cos_sin_cache_offsets"); - - int positions_ndim = positions.dim(); - // Make sure num_tokens dim is consistent across positions, query, and key - TORCH_CHECK( - positions_ndim == 1 || positions_ndim == 2, - "positions must have shape [num_tokens] or [batch_size, seq_len]"); - if (positions_ndim == 1) { - TORCH_CHECK(query.size(0) == positions.size(0) && - (!key.has_value() || key->size(0) == positions.size(0)), - "query, key and positions must have the same number of tokens"); - } - if (positions_ndim == 2) { - TORCH_CHECK( - query.size(0) == positions.size(0) && - (!key.has_value() || key->size(0) == positions.size(0)) && - query.size(1) == positions.size(1) && - (!key.has_value() || key->size(1) == positions.size(1)), - "query, key and positions must have the same batch_size and seq_len"); - } - - // Make sure head_size is valid for query and key - int query_hidden_size = query.numel() / num_tokens; - int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0; - TORCH_CHECK(query_hidden_size % head_size == 0); - TORCH_CHECK(key_hidden_size % head_size == 0); - - // Make sure query and key have concistent number of heads - int num_heads = query_hidden_size / head_size; - int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; - TORCH_CHECK(num_heads % num_kv_heads == 0); - - int seq_dim_idx = positions_ndim - 1; - int64_t query_stride = query.stride(seq_dim_idx); - int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; - // Determine head stride: for [*, heads, head_size] use stride of last dim; - // for flat [*, heads*head_size], heads blocks are contiguous of size - // head_size - int query_ndim = query.dim(); - int64_t head_stride = - (query_ndim == positions_ndim + 2) ? 
query.stride(-2) : head_size; - - dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { - if (is_neox) { - vllm::batched_rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, head_stride, num_heads, num_kv_heads, head_size); - } else { - vllm::batched_rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, head_stride, num_heads, num_kv_heads, head_size); - } - }); -} diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 8bc2b9bff3d5a..9ddb5af3052fa 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -9,6 +9,26 @@ #include "quantization/fp8/common.cuh" +#include + +#ifndef USE_ROCM + #include + #include + #include +#else + #include + #include + #include + +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16_raw __nv_bfloat16_raw; + +typedef __hip_fp8_e4m3 __nv_fp8_e4m3; +typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3; +#endif + +#include "core/registration.h" namespace vllm { template @@ -87,6 +107,337 @@ __global__ void act_and_mul_quant_kernel( } } } + +__device__ __forceinline__ float silu(float x) { + return (__fdividef(x, (1.f + expf(-x)))); +} + +__device__ __forceinline__ float2 silu2(float2 x) { + return make_float2(silu(x.x), silu(x.y)); +} + +#ifndef USE_ROCM +__device__ __forceinline__ float warp_max(float v) { + static constexpr unsigned FULL_MASK = 0xffffffffu; + for (int offset = 1; offset < WARP_SIZE; offset *= 2) { + v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, offset)); + } + return v; +} + +__device__ __forceinline__ __nv_bfloat16 warp_max(__nv_bfloat16 v) { + static constexpr unsigned FULL_MASK = 0xffffffffu; + for (int offset = 1; offset < WARP_SIZE; offset *= 2) { + v = __hmax(v, __shfl_xor_sync(FULL_MASK, v, offset)); + } + return v; +} +#endif + +template +__device__ __forceinline__ void cp_async4(T* _smem_ptr, const U* _glob_ptr) { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + auto smem_ptr = reinterpret_cast(_smem_ptr); + auto glob_ptr = reinterpret_cast(_glob_ptr); + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +#else + _smem_ptr[0] = _glob_ptr[0]; +#endif +} + +__device__ __forceinline__ void cp_async_fence() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.commit_group;\n" ::); +#else +#endif +} + +template +__device__ __forceinline__ void cp_async_wait() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#else +#endif +} + +template <> +__device__ __forceinline__ void cp_async_wait<0>() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_all;\n" ::); +#else +#endif +} + +__device__ __forceinline__ float clip(float v, float mmin, float mmax) { +#if 
__CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
+  return fminf(mmax, fmaxf(v, mmin));
+#else
+  return fminf(mmax, fmaxf(v, mmin));
+#endif
+}
+
+__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,
+                                              __nv_bfloat16 mmin,
+                                              __nv_bfloat16 mmax) {
+  return __hmin(mmax, __hmax(v, mmin));
+}
+
+__device__ __forceinline__ __nv_bfloat162 clip(__nv_bfloat162 v,
+                                               __nv_bfloat162 mmin,
+                                               __nv_bfloat162 mmax) {
+  return __hmin2(mmax, __hmax2(v, mmin));
+}
+
+// We use the following values for fp8 min/max:
+//   __nv_fp8_e4m3 = (-448, +448)
+//   __nv_fp8_e4m3uz = (-240.0, +240.0)
+// It is currently assumed that only
+template
+constexpr __nv_bfloat16 get_fp8_max() {
+  static_assert(std::is_same_v ||
+                std::is_same_v);
+  if constexpr (std::is_same_v) {
+    return __nv_bfloat16(__nv_bfloat16_raw{.x = 17376});
+  } else {
+    return __nv_bfloat16(__nv_bfloat16_raw{.x = 17264});
+  }
+}
+
+template
+constexpr __nv_bfloat16 get_fp8_min() {
+  static_assert(std::is_same_v ||
+                std::is_same_v);
+  if constexpr (std::is_same_v) {
+    return __nv_bfloat16(__nv_bfloat16_raw{.x = 50144});
+  } else {
+    return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032});
+  }
+}
+#ifndef USE_ROCM
+template
+__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
+    const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
+    float* __restrict__ _y_s, const int32_t* __restrict__ counts,
+
+    // sizes
+    int H, int G,
+
+    // strides (in elements)
+    Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
+    Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
+    Idx_t stride_ys_g, Idx_t stride_counts_e) {
+  static constexpr __nv_bfloat16 fp8_min = get_fp8_min();
+  static constexpr __nv_bfloat16 fp8_max = get_fp8_max();
+  // We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
+  static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});
+
+  // We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
+  static constexpr int32_t BFLOAT16_PER_GROUP = 8;
+
+  // We split the shared memory in half, corresponding to gate and up
+  // matrices: [...gate_i, ...up_i] where 0 <= i < stages.
+  static constexpr int32_t S_NUM_128 =
+      2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES;
+  static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE;
+  static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2;
+  static constexpr int32_t S_NUM_64 = S_NUM_128 * 2;
+  __shared__ __int128_t __align__(16) s_buff_128[S_NUM_128];
+
+  const int32_t tid = threadIdx.x;
+  const int32_t warp_id = tid / WARP_SIZE;
+  const int32_t lane_id = tid % WARP_SIZE;
+
+  auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128);
+
+  // Each block handles one (expert e, group g) pair.
+  int32_t pid = blockIdx.x;
+  int32_t e = pid / G;
+  int32_t g = pid % G;
+
+  const int32_t n_tokens = counts[e * stride_counts_e];
+
+  if (!n_tokens) {
+    return;  // Exit ASAP.
+  }
+
+  const Idx_t stride_i_t_128 = stride_i_t / 8u;
+
+  int32_t n_tokens_lower, n_tokens_upper;
+
+  // Each block along grid.y handles a slice of this expert's n_tokens =
+  // expert_counts[i] tokens. Chunk sizes are n_tokens / NUM_PARALLEL_TOKENS
+  // plus a residual spread over the leading chunks, rather than
+  // updiv(n_tokens, NUM_PARALLEL_TOKENS), which schedules better.
+  if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) {
+    // Specialized single-token path; it could likely be fused with the
+    // general case below.
+    if (blockIdx.y >= NUM_PARALLEL_TOKENS) {
+      return;
+    }
+    n_tokens_lower = blockIdx.y;
+    n_tokens_upper = blockIdx.y + 1;
+  } else {
+    auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS;
+    auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS;
+    auto calc_id = [&](int32_t id) {
+      if (id < residual) {
+        return min(n_tokens, id * (chunk_size + 1));
+      } else {
+        return min(n_tokens, id * chunk_size + residual);
+      }
+    };
+    n_tokens_lower = calc_id(blockIdx.y);
+    n_tokens_upper = calc_id(blockIdx.y + 1);
+  }
+
+  if (n_tokens_lower >= n_tokens_upper) {
+    return;
+  }
+
+  // We do calculations here, using constexpr wherever possible.
+  const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h;
+  const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g;
+  const Idx_t base_yq =
+      e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h;
+  Idx_t gate_off_128 = (base_i / static_cast(8u));
+  auto input_128_ptr = reinterpret_cast(_input);
+  auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) +
+                      stride_i_t_128 * n_tokens_lower;
+  auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u;
+  auto y_s_ptr =
+      _y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t;
+  auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE +
+                 stride_yq_t * n_tokens_lower + 4 * lane_id;
+  int32_t t_load = n_tokens_lower, load_stage_id = 0;
+  auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT);
+  auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u;
+  int32_t stage_offset{};
+
+  static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2);
+  static constexpr int32_t LOAD_STAGE_MOD =
+      NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2);
+
+  // The two halves of all threads in a block conduct global loads for gate
+  // and up, respectively.
+  auto load_and_advance_y_pred = [&] {
+    if (t_load < n_tokens_upper) {
+      auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset;
+      auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset;
+
+      // It is very important that LOAD_STAGE_SIZE is constexpr to avoid
+      // unnecessary ALU ops.
+      stage_offset += LOAD_STAGE_SIZE;
+      stage_offset %= LOAD_STAGE_MOD;
+
+      if (tid < HALF_THREAD_COUNT) {
+        cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr);
+        gate_128_ptr += stride_i_t_128;
+      } else {
+        cp_async4(s_up_stage_128_staged_ptr, up_128_ptr);
+        up_128_ptr += stride_i_t_128;
+      }
+      ++t_load;
+      ++load_stage_id;
+    }
+    // We fence even if there is nothing to load to simplify pipelining.
+    cp_async_fence();
+  };
+
+  #pragma unroll
+  for (int i = 0; i < NUM_STAGES - 1; i++) {
+    load_and_advance_y_pred();
+  }
+
+  __int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>(
+                              s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) +
+                          lane_id;
+  __int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2;
+
+  static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u;
+  static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES;
+
+  int32_t compute_pipeline_offset_64 = 0;
+
+  for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) {
+    __nv_bfloat16 y_max_bf16 = EPS;
+    __nv_bfloat162 results_bf162[2];
+
+    cp_async_wait();
+    __syncthreads();
+
+    // We double-buffer pipelined loads so that the next load can run
+    // concurrently with compute without overwrites.
+    load_and_advance_y_pred();
+
+    auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64;
+    auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64;
+
+    // STAGE_SIZE must also be constexpr!
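// Aside: a minimal model of the stage rotation performed just below. The
// constexpr step and modulus let the compiler lower the wrap to an add and
// a compare/select instead of an integer division (names are illustrative):
//
//   int advance(int offset) {
//     offset += STAGE_SIZE;                     // step to the next stage slot
//     return offset == STAGE_MOD ? 0 : offset;  // wrap without a division
//   }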
+    compute_pipeline_offset_64 += STAGE_SIZE;
+    compute_pipeline_offset_64 %= STAGE_MOD;
+
+    // Each thread loads 2 x 4 packed bfloat16 values (one 64-bit word each
+    // for gate and up) into registers.
+    __int64_t gate64 = *s_gate_compute_64;
+    __nv_bfloat162* s_gate_compute_32 =
+        reinterpret_cast<__nv_bfloat162*>(&gate64);
+
+    __int64_t up64 = *s_up_compute_64;
+    __nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64);
+
+  #pragma unroll
+    for (int i = 0; i < 2; i++) {
+      // For silu, we make sure that a div is emitted.
+      float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i]));
+      results_bf162[i] = __float22bfloat162_rn(gate);
+    }
+
+  #pragma unroll
+    for (int i = 0; i < 2; i++) {
+      results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]);
+    }
+
+    auto _y_max2 =
+        __hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1]));
+
+    y_max_bf16 = __hmax(_y_max2.x, _y_max2.y);
+
+    // An entire group is assigned to a single warp, so a simple warp reduce
+    // is used.
+    __nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max;
+
+    if constexpr (USE_UE8M0) {
+      y_s = hexp2(hceil(hlog2(y_s)));
+    }
+
+    auto inv_y = __float2bfloat16_rn(1.f) / y_s;
+
+    auto y_s2 = make_bfloat162(inv_y, inv_y);
+
+  #pragma unroll
+    for (int32_t i = 0; i < 2; ++i) {
+      results_bf162[i] =
+          clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min),
+               __bfloat162bfloat162(fp8_max));
+    }
+
+    auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]);
+    *reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4;
+    y_q_ptr += stride_yq_t;
+
+    if (lane_id == 0) {
+      *y_s_ptr = y_s;
+      y_s_ptr += stride_ys_t;
+    }
+  }
+}
+#endif
+
 } // namespace vllm
 // Launch activation, gating, and quantize kernel.
@@ -119,3 +470,117 @@ void silu_and_mul_quant(torch::Tensor& out,  // [..., d]
   TORCH_CHECK(input.size(-1) % 2 == 0);
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }
+
+void silu_mul_fp8_quant_deep_gemm_cuda(
+    const at::Tensor& input,   // (E, T, 2*H)
+    const at::Tensor& counts,  // (E)
+    at::Tensor& y_q,           // (E, T, H) [OUT]
+    at::Tensor& y_s,           // (E, T, H//group_size) [OUT]
+    int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) {
+#ifndef USE_ROCM
+  // This kernel relies heavily on cp.async and fp8 support.
+  // This kernel currently only supports H % 128 == 0 and assumes a
+  // fixed GROUP_SIZE of 128.
+  TORCH_CHECK(input.dtype() == torch::kBFloat16);
+  TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
+              y_q.dtype() == torch::kFloat8_e4m3fnuz);
+  TORCH_CHECK(y_s.dtype() == torch::kFloat32);
+  TORCH_CHECK(input.size(-1) % 256 == 0);
+
+  // Check that num_parallel_tokens is a power of 2 between 1 and 64.
+ TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64); + TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1))); + + using Idx_t = int64_t; + + Idx_t E = input.size(0); + Idx_t T = input.size(1); + Idx_t H = input.size(2) / 2; + Idx_t stride_i_e = input.stride(0); + Idx_t stride_i_t = input.stride(1); + Idx_t stride_i_h = input.stride(2); + Idx_t stride_yq_e = y_q.stride(0); + Idx_t stride_yq_t = y_q.stride(1); + Idx_t stride_yq_h = y_q.stride(2); + Idx_t stride_ys_e = y_s.stride(0); + Idx_t stride_ys_t = y_s.stride(1); + Idx_t stride_ys_g = y_s.stride(2); + + Idx_t stride_counts_e = counts.stride(0); + + static constexpr int GROUP_SIZE = 128; + + #define KERNEL_FN \ + if (use_ue8m0) { \ + vllm::silu_mul_fp8_quant_deep_gemm_kernel \ + <<>>( \ + reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ + (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ + reinterpret_cast(counts.data_ptr()), H, G, \ + stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \ + stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \ + stride_counts_e); \ + } else { \ + vllm::silu_mul_fp8_quant_deep_gemm_kernel \ + <<>>( \ + reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ + (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ + reinterpret_cast(counts.data_ptr()), H, G, \ + stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \ + stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \ + stride_counts_e); \ + } + + #define KERNEL_CALL_H \ + if (H % (4 * GROUP_SIZE) == 0) { \ + static constexpr int NUM_WARPS = 4; \ + populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \ + KERNEL_FN \ + } else { \ + static constexpr int NUM_WARPS = 1; \ + populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \ + KERNEL_FN \ + } + + #define KERNEL_CALL_TOP_LEVEL \ + if (num_parallel_tokens == 1) { \ + static constexpr int NUM_PARALLEL_TOKENS = 1; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 2) { \ + static constexpr int NUM_PARALLEL_TOKENS = 2; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 4) { \ + static constexpr int NUM_PARALLEL_TOKENS = 4; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 8) { \ + static constexpr int NUM_PARALLEL_TOKENS = 8; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 16) { \ + static constexpr int NUM_PARALLEL_TOKENS = 16; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 32) { \ + static constexpr int NUM_PARALLEL_TOKENS = 32; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 64) { \ + static constexpr int NUM_PARALLEL_TOKENS = 64; \ + KERNEL_CALL_H \ + } + + Idx_t G; + dim3 block, grid; + auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) { + G = H / Idx_t(group_size * num_warps); + grid = dim3(E * G, _num_parallel_tokens); + block = dim3(num_warps * WARP_SIZE); + }; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(), + "silu_mul_fp8_quant_deep_gemm_kernel", + [&] { KERNEL_CALL_TOP_LEVEL }); + +#endif +} diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh index c841125dbb734..939879b2c59fa 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -14,9 +14,6 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include 
"cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" - #include "cutlass_gemm_caller.cuh" namespace vllm { diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh index d50a83ae1cd48..78d5cf37fa6d0 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh @@ -14,9 +14,6 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" - #include "cutlass_gemm_caller.cuh" namespace vllm { diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh index e089c3d4be2cc..86220264151e7 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -13,27 +13,18 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" - #include "cutlass_gemm_caller.cuh" namespace vllm { using namespace cute; -template > +// clang-format off +template struct cutlass_3x_gemm_fp8_blockwise { - using GroupSizeM = Int; - using GroupSizeN = Int; - using GroupSizeK = Int; - using TileSizeM = Int; - - static_assert(TileSizeM_ % GroupSizeM_ == 0, - "TileSizeM must be a multiple of GroupSizeM"); - using ElementAB = cutlass::float_e4m3_t; using ElementA = ElementAB; @@ -45,52 +36,67 @@ struct cutlass_3x_gemm_fp8_blockwise { static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; using ElementD = OutType; - using StrideD = Stride, Int<0>>; + using LayoutD = cutlass::layout::RowMajor; static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - using ElementC = void; - using StrideC = StrideD; + using ElementC = void; // TODO: support bias + using LayoutC = LayoutD; static constexpr int AlignmentC = AlignmentD; using ElementAccumulator = float; - using ElementBlockScale = float; using ElementCompute = float; + using ElementBlockScale = float; + + using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>; + + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; - using TileShape = Shape; - using KernelSchedule = cutlass::gemm:: - KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< - GroupSizeM_>; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; - using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementScalar = float; + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + 
MmaTileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueScheduler, + DefaultOperation + >::CollectiveOp; - using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< - cutlass::epilogue::fusion::Sm90AccFetch>; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, - ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC, - ElementD, StrideD, AlignmentD, EpilogueSchedule, - StoreEpilogueCompute>::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, - LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp; using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - SchedulerType>>; + Shape, CollectiveMainloop, CollectiveEpilogue>>; struct GemmKernel : public KernelType {}; - - using StrideA = typename GemmKernel::StrideA; - using StrideB = typename GemmKernel::StrideB; }; template @@ -99,76 +105,54 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales) { using GemmKernel = typename Gemm::GemmKernel; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutSFA = typename Gemm::LayoutSFA; + using LayoutSFB = typename Gemm::LayoutSFB; + using ScaleConfig = typename Gemm::ScaleConfig; using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; - auto prob_shape = c3x::get_problem_shape(a, b); - int32_t m = get<0>(prob_shape), n = get<1>(prob_shape), - k = get<2>(prob_shape); + int32_t m = a.size(0), n = b.size(1), k = a.size(1); - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); + TORCH_CHECK(m % 4 == 0, "m must be divisible by 4"); - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; + StrideA a_stride; + StrideB b_stride; + StrideC c_stride; + a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1)); - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + LayoutSFA layout_SFA = + ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); + LayoutSFB layout_SFB = + ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); auto a_ptr = static_cast(a.data_ptr()); auto b_ptr = static_cast(b.data_ptr()); 
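// Aside (mirrors the checks dispatch_scaled_mm performs further down, not
// additional validation introduced here): with the 1x128x128 ScaleConfig
// used by this dispatch, the expected scale extents are one scale per row
// of A per 128-wide K group, and one per 128x128 block of B:
//
//   TORCH_CHECK(a_scales.size(0) == m);
//   TORCH_CHECK(a_scales.size(1) == (k + 127) / 128);
//   TORCH_CHECK(b_scales.size(0) == (k + 127) / 128);
//   TORCH_CHECK(b_scales.size(1) == (n + 127) / 128);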
auto a_scales_ptr = static_cast(a_scales.data_ptr()); auto b_scales_ptr = static_cast(b_scales.data_ptr()); - // Check is the t is contiguous and is 1D or 2D with one of the dimensions - // being 1 (i.e. a row or column vector) - auto is_contiguous_vector = [](const torch::Tensor& t) { - auto t_sizes = t.sizes(); - return t.is_contiguous() && - (t.dim() == 1 || - (t.dim() == 2 && - *std::min_element(t_sizes.begin(), t_sizes.end()) == 1)); - }; - - // TODO(lucas): lets clean-up the kernel so that we pass in Strides so - // we don't have to deal with enforcing implicit layouts - TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value); - TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value); - TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales), - "a_scales must be M major"); - TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value); - TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value); - TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales), - "b_scales must be K major"); - typename GemmKernel::MainloopArguments mainloop_args{ - a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr}; + auto mainloop_args = [&](){ + return typename GemmKernel::MainloopArguments{ + a_ptr, a_stride, b_ptr, b_stride, + a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB + }; + }(); + auto prob_shape = cute::make_shape(m, n, k, 1); auto c_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ {}, c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::TileSchedulerArguments scheduler; - - static constexpr bool UsesStreamKScheduler = - cute::is_same_v; - - if constexpr (UsesStreamKScheduler) { - using DecompositionMode = typename cutlass::gemm::kernel::detail:: - PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - using ReductionMode = typename cutlass::gemm::kernel::detail:: - PersistentTileSchedulerSm90StreamKParams::ReductionMode; - - scheduler.decomposition_mode = DecompositionMode::StreamK; - scheduler.reduction_mode = ReductionMode::Nondeterministic; - } - c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, - epilogue_args, scheduler); + epilogue_args); } template @@ -177,18 +161,12 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales) { - auto k = a.size(1); - auto n = b.size(1); - - if (k > 3 * n) { - cutlass_gemm_caller_blockwise>( - out, a, b, a_scales, b_scales); - } else { - cutlass_gemm_caller_blockwise>( - out, a, b, a_scales, b_scales); - } + // TODO: better heuristics + cutlass_gemm_caller_blockwise, + Shape<_1, _2, _1>, cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>>( + out, a, b, a_scales, b_scales); } } // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp index 2ee6a19407f92..3af59267bd60c 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp @@ -32,7 +32,7 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); int32_t version_num = get_sm_version_num(); - if (version_num >= 100) { + if (version_num >= 90) { TORCH_CHECK( a.size(0) == a_scales.size(0) && 
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1), @@ -41,32 +41,6 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), "b_scale_group_shape must be [128, 128]."); - } else { - // TODO: Remove this after using cutlass sm90 blockwise scaling gemm - // kernel, or introducing ceil_div to the load_init() of mainloop. - using GroupShape = std::array; - auto make_group_shape = [](torch::Tensor const& x, - torch::Tensor const& s) -> GroupShape { - TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); - return {cuda_utils::ceil_div(x.size(0), s.size(0)), - cuda_utils::ceil_div(x.size(1), s.size(1))}; - }; - - GroupShape a_scale_group_shape = make_group_shape(a, a_scales); - GroupShape b_scale_group_shape = make_group_shape(b, b_scales); - - // 1x128 per-token group scales for activations - // 128x128 blockwise scales for weights - TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && - b_scale_group_shape == GroupShape{128, 128} && - a.dtype() == torch::kFloat8_e4m3fn && - b.dtype() == torch::kFloat8_e4m3fn), - "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" - "a_scale_group_shape must be [1, 128]. Got: [", - a_scale_group_shape[0], ", ", a_scale_group_shape[1], - "]\n" - "b_scale_group_shape must be [128, 128]. Got: [", - b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); } TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index 1aad6330c44b8..7838f211c59db 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -5,7 +5,9 @@ #include -#ifdef USE_ROCM +#ifndef USE_ROCM + #include "nvidia/quant_utils.cuh" +#else #include "amd/quant_utils.cuh" #endif @@ -48,7 +50,9 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val, float r = fmaxf(-quant_type_max_v, fminf(x, quant_type_max_v)); #ifndef USE_ROCM - return static_cast(r); + // Use hardware cvt instruction for fp8 on nvidia + // Currently only support fp8_type = c10::Float8_e4m3fn + return fp8::vec_conversion(r); #else // Use hardware cvt instruction for fp8 on rocm return fp8::cvt_c10(r); diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh index f8cd1dcba4ab3..5b9c2df8468cb 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -12,13 +12,26 @@ namespace vllm { namespace fp8 { #ifdef ENABLE_FP8 - #if 0 // Disable the following code to reduce the binary size. template -__inline__ __device__ Tout -vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { +__inline__ __device__ Tout vec_conversion( + const Tin& x, const __nv_fp8_interpretation_t fp8_type = __NV_E4M3) { return x; } +// float -> c10::Float8_e4m3fn +template <> +__inline__ __device__ c10::Float8_e4m3fn +vec_conversion( + const float& a, const __nv_fp8_interpretation_t fp8_type) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return static_cast(a); + #else + return c10::Float8_e4m3fn(__nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type), + c10::Float8_e4m3fn::from_bits()); + #endif +} + + #if 0 // Disable the following code to reduce the binary size. 
// fp8 -> half template <> __inline__ __device__ uint16_t vec_conversion( diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 95fb5b197f534..81aca7b8860d5 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #define stride_tag #endif + ops.def( + "silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! " + "y_q, Tensor! y_s, int group_size, " + "bool use_ue8m0, int num_parallel_tokens) -> ()"); + ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA, + &silu_mul_fp8_quant_deep_gemm_cuda); + ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); @@ -168,6 +175,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); + // Polynomial Normalization. + ops.def( + "poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float " + "epsilon) -> ()"); + ops.impl("poly_norm", torch::kCUDA, &poly_norm); + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " @@ -208,16 +221,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); - // Apply GPT-NeoX or GPT-J style rotary embedding to query and key - // (supports multiple loras). - ops.def( - "batched_rotary_embedding(Tensor positions, Tensor! query," - " Tensor!? key, int head_size," - " Tensor cos_sin_cache, bool is_neox," - " int rot_dim," - " Tensor cos_sin_cache_offsets) -> ()"); - ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); - // Quantization ops #ifndef USE_ROCM // Quantized GEMM for AWQ. @@ -516,10 +519,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // SM100 CUTLASS MLA decode ops.def( - "sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," - " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," - " Tensor page_table, Tensor workspace, float " - "scale," + "sm100_cutlass_mla_decode(Tensor! out, Tensor! 
lse, Tensor q_nope," + " Tensor q_pe, Tensor kv_c_and_k_pe_cache," + " Tensor seq_lens, Tensor page_table," + " Tensor workspace, float scale," " int num_kv_splits) -> ()"); // conditionally compiled so impl in source file diff --git a/docker/Dockerfile b/docker/Dockerfile index b78d7d88f1f83..307e9658f7175 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -519,7 +519,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ else \ BITSANDBYTES_VERSION="0.46.1"; \ fi; \ - uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3] + uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' boto3 runai-model-streamer runai-model-streamer[s3] ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index f164857325043..063fc49693288 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -47,6 +47,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements /requirements COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite # ----------------------- @@ -71,7 +72,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace RUN cd /vllm-workspace \ && rm -rf vllm \ && python3 -m pip install -e tests/vllm_test_utils \ - && python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \ + && python3 -m pip install lm-eval[api]==0.4.4 \ && python3 -m pip install pytest-shard # ----------------------- @@ -100,8 +101,10 @@ ARG COMMON_WORKDIR # Copy over the benchmark scripts as well COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples +COPY --from=export_vllm /docker ${COMMON_WORKDIR}/vllm/docker ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 +ENV RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 ENV TOKENIZERS_PARALLELISM=false # ENV that can improve safe tensor loading, and end-to-end time diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 3414c0aa845cb..2ba5461dfe551 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -1,18 +1,16 @@ -ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete -ARG HIPBLASLT_BRANCH="db8e93b4" -ARG HIPBLAS_COMMON_BRANCH="7c1566b" +ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.4.1-complete +ARG HIPBLASLT_BRANCH="aa0bda7b" +ARG HIPBLAS_COMMON_BRANCH="9b80ba8e" ARG LEGACY_HIPBLASLT_OPTION= -ARG RCCL_BRANCH="648a58d" -ARG RCCL_REPO="https://github.com/ROCm/rccl" ARG TRITON_BRANCH="e5be006" ARG TRITON_REPO="https://github.com/triton-lang/triton.git" -ARG PYTORCH_BRANCH="295f2ed4" +ARG PYTORCH_BRANCH="f717b2af" ARG PYTORCH_VISION_BRANCH="v0.21.0" -ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" +ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="1a7f4dfa" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="916bf3c" +ARG AITER_BRANCH="4822e675" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base @@ -45,7 +43,7 @@ RUN 
apt-get update -y \ && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version -RUN pip install -U packaging 'cmake<4' ninja wheel setuptools pybind11 Cython +RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython FROM base AS build_hipblaslt ARG HIPBLASLT_BRANCH @@ -53,6 +51,7 @@ ARG HIPBLAS_COMMON_BRANCH # Set to "--legacy_hipblas_direct" for ROCm<=6.2 ARG LEGACY_HIPBLASLT_OPTION RUN git clone https://github.com/ROCm/hipBLAS-common.git +RUN apt-get remove -y hipblaslt && apt-get autoremove -y && apt-get autoclean -y RUN cd hipBLAS-common \ && git checkout ${HIPBLAS_COMMON_BRANCH} \ && mkdir build \ @@ -69,24 +68,17 @@ RUN cd hipBLASLt \ && make package RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install -FROM base AS build_rccl -ARG RCCL_BRANCH -ARG RCCL_REPO -RUN git clone ${RCCL_REPO} -RUN cd rccl \ - && git checkout ${RCCL_BRANCH} \ - && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH} -RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install - FROM base AS build_triton ARG TRITON_BRANCH ARG TRITON_REPO RUN git clone ${TRITON_REPO} RUN cd triton \ && git checkout ${TRITON_BRANCH} \ - && cd python \ - && python3 setup.py bdist_wheel --dist-dir=dist -RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install + && if [ ! -f setup.py ]; then cd python; fi \ + && python3 setup.py bdist_wheel --dist-dir=dist \ + && mkdir -p /app/install && cp dist/*.whl /app/install +RUN if [ -d triton/python/triton_kernels ]; then pip install build && cd triton/python/triton_kernels \ + && python3 -m build --wheel && cp dist/*.whl /app/install; fi FROM base AS build_amdsmi RUN cd /opt/rocm/share/amd_smi \ @@ -132,15 +124,25 @@ RUN cd aiter \ RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install +FROM base AS debs +RUN mkdir /app/debs +RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \ + cp /install/*.deb /app/debs +RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs +RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs +RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs +RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs + FROM base AS final RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \ dpkg -i /install/*deb \ - && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ - && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status -RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \ - dpkg -i /install/*deb \ - && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \ - && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status + && perl -p -i -e 's/, hipblas-common-dev \([^)]*?\), /, /g' /var/lib/dpkg/status \ + && perl -p -i -e 's/, hipblaslt-dev \([^)]*?\), /, /g' /var/lib/dpkg/status \ + && perl -p -i -e 's/, hipblaslt \([^)]*?\), /, /g' /var/lib/dpkg/status RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \ pip install /install/*.whl RUN 
--mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ @@ -154,8 +156,6 @@ ARG BASE_IMAGE ARG HIPBLAS_COMMON_BRANCH ARG HIPBLASLT_BRANCH ARG LEGACY_HIPBLASLT_OPTION -ARG RCCL_BRANCH -ARG RCCL_REPO ARG TRITON_BRANCH ARG TRITON_REPO ARG PYTORCH_BRANCH @@ -170,8 +170,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ && echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \ && echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \ && echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \ - && echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \ - && echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \ && echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \ && echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \ && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \ @@ -180,4 +178,4 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \ - && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt + && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt \ No newline at end of file diff --git a/docker/Dockerfile.s390x b/docker/Dockerfile.s390x index 9270b48c54d4b..9942b7626f81e 100644 --- a/docker/Dockerfile.s390x +++ b/docker/Dockerfile.s390x @@ -16,7 +16,8 @@ ENV LANG=C.UTF-8 \ RUN microdnf install -y \ which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \ libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ - openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile && \ + openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile \ + clang llvm-devel llvm-static clang-devel && \ microdnf clean all # Python Installation @@ -191,7 +192,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \ -DCOMPILER_RT_BUILD_ORC=OFF \ -DCOMPILER_RT_INCLUDE_TESTS=OFF \ ${CMAKE_ARGS} -GNinja ../llvm \ - && ninja install . 
&& \ # build llvmlite cd ../../llvmlite && python setup.py bdist_wheel && \ @@ -200,6 +200,45 @@ RUN --mount=type=cache,target=/root/.cache/uv \ sed -i '/#include "internal\/pycore_atomic.h"/i\#include "dynamic_annotations.h"' numba/_dispatcher.cpp; \ fi && python setup.py bdist_wheel +# Edit aws-lc-sys to support s390x +FROM python-install AS aws-lc-sys-editor +WORKDIR /tmp +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" +ARG AWS_LC_VERSION=v0.30.0 +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + git clone --recursive https://github.com/aws/aws-lc-rs.git && \ + cd aws-lc-rs && \ + git checkout tags/aws-lc-sys/${AWS_LC_VERSION} && \ + git submodule sync && \ + git submodule update --init --recursive && \ + cd aws-lc-sys && \ + sed -i '682 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c && \ + sed -i '712 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c && \ + sed -i '747 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c + +# Build Outlines Core +FROM python-install AS outlines-core-builder +WORKDIR /tmp +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" +ARG OUTLINES_CORE_VERSION=0.2.10 +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + --mount=type=bind,from=aws-lc-sys-editor,source=/tmp/aws-lc-rs/aws-lc-sys,target=/tmp/aws-lc-sys,rw \ + git clone https://github.com/dottxt-ai/outlines-core.git && \ + cd outlines-core && \ + git checkout tags/${OUTLINES_CORE_VERSION} && \ + sed -i "s/version = \"0.0.0\"/version = \"${OUTLINES_CORE_VERSION}\"/" Cargo.toml && \ + echo '[patch.crates-io]' >> Cargo.toml && \ + echo 'aws-lc-sys = { path = "/tmp/aws-lc-sys" }' >> Cargo.toml && \ + uv pip install maturin && \ + python -m maturin build --release --out dist # Final build stage FROM python-install AS vllm-cpu @@ -230,6 +269,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ --mount=type=bind,from=numba-builder,source=/tmp/llvmlite/dist,target=/tmp/llvmlite-wheels/ \ --mount=type=bind,from=numba-builder,source=/tmp/numba/dist,target=/tmp/numba-wheels/ \ + --mount=type=bind,from=outlines-core-builder,source=/tmp/outlines-core/dist,target=/tmp/outlines-core/dist/ \ sed -i '/^torch/d' requirements/build.txt && \ ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl) && \ VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl) && \ @@ -237,6 +277,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl) && \ LLVM_WHL_FILE=$(ls /tmp/llvmlite-wheels/*.whl) && \ NUMBA_WHL_FILE=$(ls /tmp/numba-wheels/*.whl) && \ + OUTLINES_CORE_WHL_FILE=$(ls /tmp/outlines-core/dist/*.whl) && \ uv pip install -v \ $ARROW_WHL_FILE \ $VISION_WHL_FILE \ @@ -244,6 +285,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ $TORCH_WHL_FILE \ $LLVM_WHL_FILE \ $NUMBA_WHL_FILE \ + $OUTLINES_CORE_WHL_FILE \ --index-strategy unsafe-best-match \ -r requirements/build.txt \ -r requirements/cpu.txt diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 65d2e5036b783..ef422352509a9 
100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -1,12 +1,10 @@ FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04 AS vllm-base -RUN rm /etc/apt/sources.list.d/intel-graphics.list +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ + echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \ + add-apt-repository -y ppa:kobuk-team/intel-graphics RUN apt clean && apt-get update -y && \ - apt-get install -y software-properties-common && \ - add-apt-repository ppa:deadsnakes/ppa && \ - apt-get install -y python3.10 python3.10-distutils && \ - curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && \ apt-get install -y --no-install-recommends --fix-missing \ curl \ ffmpeg \ @@ -17,17 +15,29 @@ RUN apt clean && apt-get update -y && \ libgl1 \ lsb-release \ numactl \ - python3.10-dev \ - wget + wget \ + vim \ + python3.12 \ + python3.12-dev \ + python3-pip +RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 +RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1 -RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 -RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 +RUN apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing + +RUN wget https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.4/intel-oneccl-2021.15.4.11_offline.sh +RUN bash intel-oneccl-2021.15.4.11_offline.sh -a --silent --eula accept && echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc +SHELL ["bash", "-c"] +CMD ["bash", "-c", "source /root/.bashrc && exec bash"] WORKDIR /workspace/vllm COPY requirements/xpu.txt /workspace/vllm/requirements/xpu.txt COPY requirements/common.txt /workspace/vllm/requirements/common.txt +# suppress the python externally managed environment error +RUN python3 -m pip config set global.break-system-packages true + RUN --mount=type=cache,target=/root/.cache/pip \ pip install --no-cache-dir \ -r requirements/xpu.txt @@ -54,8 +64,9 @@ FROM vllm-base AS vllm-openai RUN --mount=type=cache,target=/root/.cache/pip \ pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope -ENV VLLM_USAGE_SOURCE production-docker-image \ - TRITON_XPU_PROFILE 1 +RUN --mount=type=cache,target=/root/.cache/pip \ + pip uninstall oneccl oneccl-devel -y + # install development dependencies (for testing) RUN python3 -m pip install -e tests/vllm_test_utils ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/docs/.nav.yml b/docs/.nav.yml index dbac0e12f1bf2..c103ed476d76d 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -32,10 +32,7 @@ nav: - models/pooling_models.md - models/extensions - Hardware Supported Models: models/hardware_supported_models - - Features: - - features/compatibility_matrix.md - - features/* - - features/quantization + - Features: features - Developer Guide: - contributing/README.md - General: @@ -47,11 +44,12 @@ nav: - contributing/model/registration.md - contributing/model/tests.md - contributing/model/multimodal.md + - contributing/model/transcription.md - CI: contributing/ci - Design Documents: design - API Reference: - api/README.md - - api/vllm/* + - api/vllm - CLI Reference: cli - Community: - community/* diff --git 
a/docs/README.md b/docs/README.md index 683e1d37563f5..ae95717def4cd 100644 --- a/docs/README.md +++ b/docs/README.md @@ -56,7 +56,7 @@ vLLM is flexible and easy to use with: - Tensor, pipeline, data and expert parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators. +- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend. - Prefix caching support - Multi-LoRA support diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index c853fcf92941e..5807d787cf531 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -230,6 +230,20 @@ Multi-modal IPC caching is automatically enabled when there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes, to avoid repeatedly transferring the same multi-modal inputs between them. +#### Key-Replicated Cache + +By default, IPC caching uses a **key-replicated cache**, where cache keys exist +in both the API (`P0`) and engine core (`P1`) processes, but the actual cache +data resides only in `P1`. + +#### Shared Memory Cache + +When multiple worker processes are involved (e.g., when TP > 1), a +**shared-memory cache** is more efficient. This can be enabled by setting +`mm_processor_cache_type="shm"`. In this mode, cache keys are stored +on `P0`, while the cache data itself lives in shared memory accessible by all +processes. + ### Configuration You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB). @@ -244,6 +258,12 @@ Examples: llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=8) +# Use a shared-memory based IPC cache +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + tensor_parallel_size=2, + mm_processor_cache_type="shm", + mm_processor_cache_gb=8) + # Disable the cache llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=0) @@ -253,11 +273,12 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows: -| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory | -|-------------------|-------------|------------|------------|-------------| -| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` | -| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` | -| ❌ | ❌ | N/A | N/A | `0` | +| mm_processor_cache_type | Cache Type | `P0` Cache | `P1` Engine Cache | `P1` Worker Cache | Max. 
Memory | +|-------------------|-------------|------------|------------|-------------|-------------| +| lru | Processor Caching | K + V | N/A | N/A | `mm_processor_cache_gb * data_parallel_size` | +| lru | Key-Replicated Caching | K | K + V | N/A | `mm_processor_cache_gb * api_server_count` | +| shm | Shared Memory Caching | K | N/A | V | `mm_processor_cache_gb * api_server_count` | +| N/A | Disabled | N/A | N/A | N/A | `0` | K: Stores the hashes of multi-modal items V: Stores the processed tensor data of multi-modal items diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md index 0ca77fa499db7..6c013738ac1ec 100644 --- a/docs/contributing/model/README.md +++ b/docs/contributing/model/README.md @@ -15,6 +15,7 @@ Read through these pages for a step-by-step guide: - [Registering a Model](registration.md) - [Unit Testing](tests.md) - [Multi-Modal Support](multimodal.md) +- [Speech-to-Text Support](transcription.md) !!! tip If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues) diff --git a/docs/contributing/model/transcription.md b/docs/contributing/model/transcription.md new file mode 100644 index 0000000000000..62e58e5c6ac58 --- /dev/null +++ b/docs/contributing/model/transcription.md @@ -0,0 +1,276 @@ +# Speech-to-Text (Transcription/Translation) Support + +This document walks you through the steps to add support for speech-to-text (ASR) models to vLLM’s transcription and translation APIs by implementing [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription]. +Please refer to the [supported models](../../models/supported_models.md#transcription) for further guidance. + +## Update the base vLLM model + +It is assumed you have already implemented your model in vLLM according to the basic model guide. Extend your model with the [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription] interface and implement the following class attributes and methods. + +### `supported_languages` and `supports_transcription_only` + +Declare supported languages and capabilities: + +- The `supported_languages` mapping is validated at init time. +- Set `supports_transcription_only=True` if the model should not serve text generation (eg Whisper). + +??? code "supported_languages and supports_transcription_only" + ```python + from typing import ClassVar, Mapping, Optional, Literal + import numpy as np + import torch + from torch import nn + + from vllm.config import ModelConfig, SpeechToTextConfig + from vllm.inputs.data import PromptType + from vllm.model_executor.models.interfaces import SupportsTranscription + + class YourASRModel(nn.Module, SupportsTranscription): + # Map of ISO 639-1 language codes to language names + supported_languages: ClassVar[Mapping[str, str]] = { + "en": "English", + "it": "Italian", + # ... add more as needed + } + + # If your model only supports audio-conditioned generation + # (no text-only generation), enable this flag. + supports_transcription_only: ClassVar[bool] = True + ``` + +Provide an ASR configuration via [get_speech_to_text_config][vllm.model_executor.models.interfaces.SupportsTranscription.get_speech_to_text_config]. + +This is for controlling general behavior of the API when serving your model: + +??? code "get_speech_to_text_config()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... 
+ + @classmethod + def get_speech_to_text_config( + cls, + model_config: ModelConfig, + task_type: Literal["transcribe", "translate"], + ) -> SpeechToTextConfig: + return SpeechToTextConfig( + sample_rate=16_000, + max_audio_clip_s=30, + # Set to None to disable server-side chunking if your + # model/processor handles it already + min_energy_split_window_size=None, + ) + ``` + +See [Audio preprocessing and chunking](#audio-preprocessing-and-chunking) for what each field controls. + +Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.data.PromptType]. There are two common patterns: + +#### Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n) + +Return a dict containing `multi_modal_data` with the audio, and either a `prompt` string or `prompt_token_ids`: + +??? code "get_generation_prompt()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: + # Example with a free-form instruction prompt + task_word = "Transcribe" if task_type == "transcribe" else "Translate" + prompt = ( + "user\n" + f"{task_word} this audio: " + "\nmodel\n" + ) + + return { + "multi_modal_data": {"audio": (audio, stt_config.sample_rate)}, + "prompt": prompt, + } + ``` + + For further clarification on multi modal inputs, please refer to [Multi-Modal Inputs](../../features/multimodal_inputs.md). + +#### Encoder–decoder audio-only (e.g., Whisper) + +Return a dict with separate `encoder_prompt` and `decoder_prompt` entries: + +??? code "get_generation_prompt()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: + if language is None: + raise ValueError("Language must be specified") + + prompt = { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (audio, stt_config.sample_rate), + }, + }, + "decoder_prompt": ( + (f"<|prev|>{request_prompt}" if request_prompt else "") + + f"<|startoftranscript|><|{language}|>" + + f"<|{task_type}|><|notimestamps|>" + ), + } + return cast(PromptType, prompt) + ``` + +### `validate_language` (optional) + +Language validation via [validate_language][vllm.model_executor.models.interfaces.SupportsTranscription.validate_language] + +If your model requires a language and you want a default, override this method (see Whisper): + +??? code "validate_language()" + ```python + @classmethod + def validate_language(cls, language: Optional[str]) -> Optional[str]: + if language is None: + logger.warning( + "Defaulting to language='en'. 
If you wish to transcribe audio in a different language, pass the `language` field.") + language = "en" + return super().validate_language(language) + ``` + +### `get_num_audio_tokens` (optional) + +Token accounting for streaming via [get_num_audio_tokens][vllm.model_executor.models.interfaces.SupportsTranscription.get_num_audio_tokens] + +Provide a fast duration→token estimate to improve streaming usage statistics: + +??? code "get_num_audio_tokens()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> Optional[int]: + # Return None if unknown; otherwise return an estimate. + return int(audio_duration_s * stt_config.sample_rate // 320) # example + ``` + +## Audio preprocessing and chunking + +The API server takes care of basic audio I/O and optional chunking before building prompts: + +- Resampling: Input audio is resampled to `SpeechToTextConfig.sample_rate` using `librosa`. +- Chunking: If `SpeechToTextConfig.allow_audio_chunking` is True and the duration exceeds `max_audio_clip_s`, the server splits the audio into overlapping chunks and generates a prompt per chunk. Overlap is controlled by `overlap_chunk_second`. +- Energy-aware splitting: When `min_energy_split_window_size` is set, the server finds low-energy regions to minimize cutting within words. + +Relevant server logic: + +??? code "_preprocess_speech_to_text()" + ```python + # vllm/entrypoints/openai/speech_to_text.py + async def _preprocess_speech_to_text(...): + language = self.model_cls.validate_language(request.language) + ... + y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) + duration = librosa.get_duration(y=y, sr=sr) + do_split_audio = (self.asr_config.allow_audio_chunking + and duration > self.asr_config.max_audio_clip_s) + chunks = [y] if not do_split_audio else self._split_audio(y, int(sr)) + prompts = [] + for chunk in chunks: + prompt = self.model_cls.get_generation_prompt( + audio=chunk, + stt_config=self.asr_config, + model_config=self.model_config, + language=language, + task_type=self.task_type, + request_prompt=request.prompt, + to_language=to_language, + ) + prompts.append(prompt) + return prompts, duration + ``` + +## Exposing tasks automatically + +vLLM automatically advertises transcription support if your model implements the interface: + +```python +if supports_transcription(model): + if model.supports_transcription_only: + return ["transcription"] + supported_tasks.append("transcription") +``` + +When enabled, the server initializes the transcription and translation handlers: + +```python +state.openai_serving_transcription = OpenAIServingTranscription(...) if "transcription" in supported_tasks else None +state.openai_serving_translation = OpenAIServingTranslation(...) if "transcription" in supported_tasks else None +``` + +No extra registration is required beyond having your model class available via the model registry and implementing `SupportsTranscription`. 
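+
+As a quick sanity check, you can verify that vLLM detects the interface on your
+class before serving it. A minimal sketch (`YourASRModel` is the illustrative
+class from the snippets above, not a class that ships with vLLM):
+
+```python
+from vllm.model_executor.models.interfaces import supports_transcription
+
+# Both checks should pass for the hypothetical YourASRModel defined earlier.
+assert supports_transcription(YourASRModel)
+assert YourASRModel.supports_transcription_only  # audio-only, Whisper-style
+```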
+ +## Examples in-tree + +- Whisper encoder–decoder (audio-only): +- Voxtral decoder-only (audio embeddings + LLM): +- Gemma3n decoder-only with fixed instruction prompt: + +## Test with the API + +Once your model implements `SupportsTranscription`, you can test the endpoints (API mimics OpenAI): + +- Transcription (ASR): + + ```bash + curl -s -X POST \ + -H "Authorization: Bearer $VLLM_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F "file=@/path/to/audio.wav" \ + -F "model=$MODEL_ID" \ + http://localhost:8000/v1/audio/transcriptions + ``` + +- Translation (source → English unless otherwise supported): + + ```bash + curl -s -X POST \ + -H "Authorization: Bearer $VLLM_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F "file=@/path/to/audio.wav" \ + -F "model=$MODEL_ID" \ + http://localhost:8000/v1/audio/translations + ``` + +Or check out more examples in . + +!!! note + - If your model handles chunking internally (e.g., via its processor or encoder), set `min_energy_split_window_size=None` in the returned `SpeechToTextConfig` to disable server-side chunking. + - Implementing `get_num_audio_tokens` improves accuracy of streaming usage metrics (`prompt_tokens`) without an extra forward pass. + - For multilingual behavior, keep `supported_languages` aligned with actual model capabilities. diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index dffd62385e017..5b83d93274f0d 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -19,7 +19,7 @@ When using `vllm bench serve`, you can enable profiling by passing the `--profil Traces can be visualized using . !!! tip -You can directly call bench module without installing vllm using `python -m vllm.entrypoints.cli.main bench`. + You can directly call bench module without installing vLLM using `python -m vllm.entrypoints.cli.main bench`. !!! tip Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly. diff --git a/docs/deployment/frameworks/anything-llm.md b/docs/deployment/frameworks/anything-llm.md index 0b41e73b030cc..40a463a8a596c 100644 --- a/docs/deployment/frameworks/anything-llm.md +++ b/docs/deployment/frameworks/anything-llm.md @@ -1,41 +1,53 @@ -# Anything LLM +# AnythingLLM -[Anything LLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting. +[AnythingLLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting. It allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints. ## Prerequisites -- Setup vLLM environment +Set up the vLLM environment: + +```bash +pip install vllm +``` ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with a supported chat-completion model, for example: -```bash -vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 -``` + ```bash + vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 + ``` -- Download and install [Anything LLM desktop](https://anythingllm.com/desktop). +1. Download and install [AnythingLLM Desktop](https://anythingllm.com/desktop). 
-- On the bottom left of open settings, AI Providers --> LLM: - - LLM Provider: Generic OpenAI - - Base URL: http://{vllm server host}:{vllm server port}/v1 - - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` +1. Configure the AI provider: -![](../../assets/deployment/anything-llm-provider.png) + - At the bottom, click the 🔧 wrench icon -> **Open settings** -> **AI Providers** -> **LLM**. + - Enter the following values: + - LLM Provider: Generic OpenAI + - Base URL: `http://{vllm server host}:{vllm server port}/v1` + - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` -- Back to home page, New Workspace --> create `vllm` workspace, and start to chat: + ![set AI providers](../../assets/deployment/anything-llm-provider.png) -![](../../assets/deployment/anything-llm-chat-without-doc.png) +1. Create a workspace: -- Click the upload button: - - upload the doc - - select the doc and move to the workspace - - save and embed + 1. At the bottom, click the ↺ back icon and back to workspaces. + 1. Create a workspace (e.g., `vllm`) and start chatting. -![](../../assets/deployment/anything-llm-upload-doc.png) + ![create a workspace](../../assets/deployment/anything-llm-chat-without-doc.png) -- Chat again: +1. Add a document. -![](../../assets/deployment/anything-llm-chat-with-doc.png) + 1. Click the 📎 attachment icon. + 1. Upload a document. + 1. Select and move the document into your workspace. + 1. Save and embed it. + + ![add a document](../../assets/deployment/anything-llm-upload-doc.png) + +1. Chat using your document as context. + + ![chat with your context](../../assets/deployment/anything-llm-chat-with-doc.png) diff --git a/docs/deployment/frameworks/autogen.md b/docs/deployment/frameworks/autogen.md index c255a85d38401..7517ee771c097 100644 --- a/docs/deployment/frameworks/autogen.md +++ b/docs/deployment/frameworks/autogen.md @@ -4,9 +4,7 @@ ## Prerequisites -- Setup vLLM environment - -- Setup [AutoGen](https://microsoft.github.io/autogen/0.2/docs/installation/) environment +Set up the vLLM and [AutoGen](https://microsoft.github.io/autogen/0.2/docs/installation/) environment: ```bash pip install vllm @@ -18,14 +16,14 @@ pip install -U "autogen-agentchat" "autogen-ext[openai]" ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -python -m vllm.entrypoints.openai.api_server \ - --model mistralai/Mistral-7B-Instruct-v0.2 -``` + ```bash + python -m vllm.entrypoints.openai.api_server \ + --model mistralai/Mistral-7B-Instruct-v0.2 + ``` -- Call it with AutoGen: +1. Call it with AutoGen: ??? code diff --git a/docs/deployment/frameworks/chatbox.md b/docs/deployment/frameworks/chatbox.md index cbca6e6282fc6..002935da56009 100644 --- a/docs/deployment/frameworks/chatbox.md +++ b/docs/deployment/frameworks/chatbox.md @@ -6,27 +6,31 @@ It allows you to deploy a large language model (LLM) server with vLLM as the bac ## Prerequisites -- Setup vLLM environment +Set up the vLLM environment: + +```bash +pip install vllm +``` ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve qwen/Qwen1.5-0.5B-Chat -``` + ```bash + vllm serve qwen/Qwen1.5-0.5B-Chat + ``` -- Download and install [Chatbox desktop](https://chatboxai.app/en#download). +1. Download and install [Chatbox desktop](https://chatboxai.app/en#download). -- On the bottom left of settings, Add Custom Provider +1. 
On the bottom left of settings, Add Custom Provider - API Mode: `OpenAI API Compatible` - Name: vllm - API Host: `http://{vllm server host}:{vllm server port}/v1` - API Path: `/chat/completions` - Model: `qwen/Qwen1.5-0.5B-Chat` -![](../../assets/deployment/chatbox-settings.png) + ![](../../assets/deployment/chatbox-settings.png) -- Go to `Just chat`, and start to chat: +1. Go to `Just chat`, and start to chat: -![](../../assets/deployment/chatbox-chat.png) + ![](../../assets/deployment/chatbox-chat.png) diff --git a/docs/deployment/frameworks/dify.md b/docs/deployment/frameworks/dify.md index 35f02c33cb02b..820ef0cbed9fa 100644 --- a/docs/deployment/frameworks/dify.md +++ b/docs/deployment/frameworks/dify.md @@ -8,44 +8,50 @@ This guide walks you through deploying Dify using a vLLM backend. ## Prerequisites -- Setup vLLM environment -- Install [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) +Set up the vLLM environment: + +```bash +pip install vllm +``` + +And install [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/). ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve Qwen/Qwen1.5-7B-Chat -``` + ```bash + vllm serve Qwen/Qwen1.5-7B-Chat + ``` -- Start the Dify server with docker compose ([details](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)): +1. Start the Dify server with docker compose ([details](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)): -```bash -git clone https://github.com/langgenius/dify.git -cd dify -cd docker -cp .env.example .env -docker compose up -d -``` + ```bash + git clone https://github.com/langgenius/dify.git + cd dify + cd docker + cp .env.example .env + docker compose up -d + ``` -- Open the browser to access `http://localhost/install`, config the basic login information and login. +1. Open the browser to access `http://localhost/install`, config the basic login information and login. -- In the top-right user menu (under the profile icon), go to Settings, then click `Model Provider`, and locate the `vLLM` provider to install it. +1. In the top-right user menu (under the profile icon), go to Settings, then click `Model Provider`, and locate the `vLLM` provider to install it. + +1. Fill in the model provider details as follows: -- Fill in the model provider details as follows: - **Model Type**: `LLM` - **Model Name**: `Qwen/Qwen1.5-7B-Chat` - **API Endpoint URL**: `http://{vllm_server_host}:{vllm_server_port}/v1` - **Model Name for API Endpoint**: `Qwen/Qwen1.5-7B-Chat` - **Completion Mode**: `Completion` -![](../../assets/deployment/dify-settings.png) + ![](../../assets/deployment/dify-settings.png) -- To create a test chatbot, go to `Studio → Chatbot → Create from Blank`, then select Chatbot as the type: +1. To create a test chatbot, go to `Studio → Chatbot → Create from Blank`, then select Chatbot as the type: -![](../../assets/deployment/dify-create-chatbot.png) + ![](../../assets/deployment/dify-create-chatbot.png) -- Click the chatbot you just created to open the chat interface and start interacting with the model: +1. 
Click the chatbot you just created to open the chat interface and start interacting with the model: -![](../../assets/deployment/dify-chat.png) + ![](../../assets/deployment/dify-chat.png) diff --git a/docs/deployment/frameworks/haystack.md b/docs/deployment/frameworks/haystack.md index 70b4b48d4543e..836305cf15c42 100644 --- a/docs/deployment/frameworks/haystack.md +++ b/docs/deployment/frameworks/haystack.md @@ -6,7 +6,7 @@ It allows you to deploy a large language model (LLM) server with vLLM as the bac ## Prerequisites -- Setup vLLM and Haystack environment +Set up the vLLM and Haystack environment: ```bash pip install vllm haystack-ai @@ -14,13 +14,13 @@ pip install vllm haystack-ai ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve mistralai/Mistral-7B-Instruct-v0.1 -``` + ```bash + vllm serve mistralai/Mistral-7B-Instruct-v0.1 + ``` -- Use the `OpenAIGenerator` and `OpenAIChatGenerator` components in Haystack to query the vLLM server. +1. Use the `OpenAIGenerator` and `OpenAIChatGenerator` components in Haystack to query the vLLM server. ??? code diff --git a/docs/deployment/frameworks/litellm.md b/docs/deployment/frameworks/litellm.md index c7e514f2276e0..0d6c3729911ad 100644 --- a/docs/deployment/frameworks/litellm.md +++ b/docs/deployment/frameworks/litellm.md @@ -13,7 +13,7 @@ And LiteLLM supports all models on VLLM. ## Prerequisites -- Setup vLLM and litellm environment +Set up the vLLM and litellm environment: ```bash pip install vllm litellm @@ -23,13 +23,13 @@ pip install vllm litellm ### Chat completion -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve qwen/Qwen1.5-0.5B-Chat -``` + ```bash + vllm serve qwen/Qwen1.5-0.5B-Chat + ``` -- Call it with litellm: +1. Call it with litellm: ??? code @@ -51,13 +51,13 @@ vllm serve qwen/Qwen1.5-0.5B-Chat ### Embeddings -- Start the vLLM server with the supported embedding model, e.g. +1. Start the vLLM server with the supported embedding model, e.g. -```bash -vllm serve BAAI/bge-base-en-v1.5 -``` + ```bash + vllm serve BAAI/bge-base-en-v1.5 + ``` -- Call it with litellm: +1. Call it with litellm: ```python from litellm import embedding diff --git a/docs/deployment/frameworks/retrieval_augmented_generation.md b/docs/deployment/frameworks/retrieval_augmented_generation.md index d5f2ec302b6cd..d86ab1600f126 100644 --- a/docs/deployment/frameworks/retrieval_augmented_generation.md +++ b/docs/deployment/frameworks/retrieval_augmented_generation.md @@ -11,7 +11,7 @@ Here are the integrations: ### Prerequisites -- Setup vLLM and langchain environment +Set up the vLLM and langchain environment: ```bash pip install -U vllm \ @@ -22,33 +22,33 @@ pip install -U vllm \ ### Deploy -- Start the vLLM server with the supported embedding model, e.g. +1. Start the vLLM server with the supported embedding model, e.g. -```bash -# Start embedding service (port 8000) -vllm serve ssmits/Qwen2-7B-Instruct-embed-base -``` + ```bash + # Start embedding service (port 8000) + vllm serve ssmits/Qwen2-7B-Instruct-embed-base + ``` -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. 
-```bash -# Start chat service (port 8001) -vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 -``` + ```bash + # Start chat service (port 8001) + vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 + ``` -- Use the script: +1. Use the script: -- Run the script +1. Run the script -```python -python retrieval_augmented_generation_with_langchain.py -``` + ```python + python retrieval_augmented_generation_with_langchain.py + ``` ## vLLM + llamaindex ### Prerequisites -- Setup vLLM and llamaindex environment +Set up the vLLM and llamaindex environment: ```bash pip install vllm \ @@ -60,24 +60,24 @@ pip install vllm \ ### Deploy -- Start the vLLM server with the supported embedding model, e.g. +1. Start the vLLM server with the supported embedding model, e.g. -```bash -# Start embedding service (port 8000) -vllm serve ssmits/Qwen2-7B-Instruct-embed-base -``` + ```bash + # Start embedding service (port 8000) + vllm serve ssmits/Qwen2-7B-Instruct-embed-base + ``` -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -# Start chat service (port 8001) -vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 -``` + ```bash + # Start chat service (port 8001) + vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 + ``` -- Use the script: +1. Use the script: -- Run the script +1. Run the script: -```python -python retrieval_augmented_generation_with_llamaindex.py -``` + ```python + python retrieval_augmented_generation_with_llamaindex.py + ``` diff --git a/docs/deployment/integrations/llamastack.md b/docs/deployment/integrations/llamastack.md index 28031f01f85e8..8eb7f8d81275d 100644 --- a/docs/deployment/integrations/llamastack.md +++ b/docs/deployment/integrations/llamastack.md @@ -1,6 +1,6 @@ # Llama Stack -vLLM is also available via [Llama Stack](https://github.com/meta-llama/llama-stack) . +vLLM is also available via [Llama Stack](https://github.com/llamastack/llama-stack). To install Llama Stack, run @@ -8,9 +8,9 @@ To install Llama Stack, run pip install llama-stack -q ``` -## Inference using OpenAI Compatible API +## Inference using OpenAI-Compatible API -Then start Llama Stack server pointing to your vLLM server with the following configuration: +Then start the Llama Stack server and configure it to point to your vLLM server with the following settings: ```yaml inference: @@ -20,15 +20,15 @@ inference: url: http://127.0.0.1:8000 ``` -Please refer to [this guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) for more details on this remote vLLM provider. +Please refer to [this guide](https://llama-stack.readthedocs.io/en/latest/providers/inference/remote_vllm.html) for more details on this remote vLLM provider. -## Inference via Embedded vLLM +## Inference using Embedded vLLM -An [inline vLLM provider](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline/inference/vllm) +An [inline provider](https://github.com/llamastack/llama-stack/tree/main/llama_stack/providers/inline/inference) is also available. This is a sample of configuration using that method: ```yaml -inference +inference: - provider_type: vllm config: model: Llama3.1-8B-Instruct diff --git a/docs/design/multiprocessing.md b/docs/design/multiprocessing.md index 247072d1cb275..6e92b20d267b4 100644 --- a/docs/design/multiprocessing.md +++ b/docs/design/multiprocessing.md @@ -8,7 +8,7 @@ page for information on known issues and how to solve them. ## Introduction !!! 
important - The source code references are to the state of the code at the time of writing in December, 2024. + The source code references are to the state of the code at the time of writing in December 2024. The use of Python multiprocessing in vLLM is complicated by: diff --git a/docs/examples/README.md b/docs/examples/README.md index 3cf93027f4209..94f5efc92f386 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -2,6 +2,6 @@ vLLM's examples are split into three categories: -- If you are using vLLM from within Python code, see [Offline Inference](./offline_inference) -- If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving) -- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others) +- If you are using vLLM from within Python code, see the *Offline Inference* section. +- If you are using vLLM from an HTTP application or client, see the *Online Serving* section. +- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see the *Others* section. diff --git a/docs/features/compatibility_matrix.md b/docs/features/README.md similarity index 95% rename from docs/features/compatibility_matrix.md rename to docs/features/README.md index 5b08b3810776c..d8e26ec02aecc 100644 --- a/docs/features/compatibility_matrix.md +++ b/docs/features/README.md @@ -1,4 +1,6 @@ -# Compatibility Matrix +# Features + +## Compatibility Matrix The tables below show mutually exclusive features and the support on some hardware. @@ -12,7 +14,7 @@ The symbols used have the following meanings: !!! note Check the ❌ or 🟠 with links to see tracking issue for unsupported feature/hardware combination. 
-## Feature x Feature +### Feature x Feature -| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU | -|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------| -| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | -| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ | -| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ | +| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | Google TPU | +|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------| +| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | +| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | +| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | +| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | +| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | +| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. - ✅︎ indicates that the quantization method is supported on the specified hardware. diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index d9a785eb73fbe..d518e7f0cff43 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -15,6 +15,7 @@ vLLM currently supports the following reasoning models: | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | | [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ | | [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `guided_json`, `guided_regex` | ✅ | +| [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `guided_json`, `guided_regex` | ✅ | !!! note IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index afc605a504b3d..a8c0db0a7ac13 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -169,7 +169,7 @@ All Llama 3.1, 3.2 and 4 models should be supported. 
The tool calling that is supported is the [JSON-based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. As for Llama 4 models, it is recommended to use the `llama4_pythonic` tool parser. -Other tool calling formats like the built in python tool calling or custom tool calling are not supported. +Other tool calling formats like the built-in python tool calling or custom tool calling are not supported. Known issues: @@ -311,6 +311,15 @@ Flags: * For non-reasoning: `--tool-call-parser hunyuan_a13b` * For reasoning: `--tool-call-parser hunyuan_a13b --reasoning-parser hunyuan_a13b --enable_reasoning` +### GLM-4.5 Models (`glm45`) + +Supported models: + +* `ZhipuAI/GLM-4.5` +* `ZhipuAI/GLM-4.5-Air` + +Flags: `--tool-call-parser glm45` + ### Models with Pythonic Tool Calls (`pythonic`) A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. diff --git a/docs/getting_started/installation/.nav.yml b/docs/getting_started/installation/.nav.yml index d4a727c926406..ba1f8099a6456 100644 --- a/docs/getting_started/installation/.nav.yml +++ b/docs/getting_started/installation/.nav.yml @@ -3,5 +3,3 @@ nav: - gpu.md - cpu.md - google_tpu.md - - intel_gaudi.md - - aws_neuron.md diff --git a/docs/getting_started/installation/README.md b/docs/getting_started/installation/README.md index 8a658b7a9103f..5e57d23f4a1df 100644 --- a/docs/getting_started/installation/README.md +++ b/docs/getting_started/installation/README.md @@ -12,7 +12,6 @@ vLLM supports the following hardware platforms: - [Apple silicon](cpu.md#apple-silicon) - [IBM Z (S390X)](cpu.md#ibm-z-s390x) - [Google TPU](google_tpu.md) -- [AWS Neuron](aws_neuron.md) ## Hardware Plugins diff --git a/docs/getting_started/installation/aws_neuron.md b/docs/getting_started/installation/aws_neuron.md deleted file mode 100644 index ff2500f035270..0000000000000 --- a/docs/getting_started/installation/aws_neuron.md +++ /dev/null @@ -1,147 +0,0 @@ -# AWS Neuron - -[AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/) is the software development kit (SDK) used to run deep learning and -generative AI workloads on AWS Inferentia and AWS Trainium powered Amazon EC2 instances and UltraServers (Inf1, Inf2, Trn1, Trn2, -and Trn2 UltraServer). Both Trainium and Inferentia are powered by fully-independent heterogeneous compute-units called NeuronCores. -This describes how to set up your environment to run vLLM on Neuron. - -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. 
- -## Requirements - -- OS: Linux -- Python: 3.9 or newer -- Pytorch 2.5/2.6 -- Accelerator: NeuronCore-v2 (in trn1/inf2 chips) or NeuronCore-v3 (in trn2 chips) -- AWS Neuron SDK 2.23 - -## Configure a new environment - -### Launch a Trn1/Trn2/Inf2 instance and verify Neuron dependencies - -The easiest way to launch a Trainium or Inferentia instance with pre-installed Neuron dependencies is to follow this -[quick start guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/neuron-setup/multiframework/multi-framework-ubuntu22-neuron-dlami.html#setup-ubuntu22-multi-framework-dlami) using the Neuron Deep Learning AMI (Amazon machine image). - -- After launching the instance, follow the instructions in [Connect to your instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AccessingInstancesLinux.html) to connect to the instance -- Once inside your instance, activate the pre-installed virtual environment for inference by running - -```bash -source /opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/bin/activate -``` - -Refer to the [NxD Inference Setup Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/nxdi-setup.html) -for alternative setup instructions including using Docker and manually installing dependencies. - -!!! note - NxD Inference is the default recommended backend to run inference on Neuron. If you are looking to use the legacy [transformers-neuronx](https://github.com/aws-neuron/transformers-neuronx) - library, refer to [Transformers NeuronX Setup](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/setup/index.html). - -## Set up using Python - -### Pre-built wheels - -Currently, there are no pre-built Neuron wheels. - -### Build wheel from source - -To build and install vLLM from source, run: - -```bash -git clone https://github.com/vllm-project/vllm.git -cd vllm -pip install -U -r requirements/neuron.txt -VLLM_TARGET_DEVICE="neuron" pip install -e . -``` - -AWS Neuron maintains a [Github fork of vLLM](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2) at -, which contains several features in addition to what's -available on vLLM V0. Please utilize the AWS Fork for the following features: - -- Llama-3.2 multi-modal support -- Multi-node distributed inference - -Refer to [vLLM User Guide for NxD Inference](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/vllm-user-guide.html) - for more details and usage examples. - -To install the AWS Neuron fork, run the following: - -```bash -git clone -b neuron-2.23-vllm-v0.7.2 https://github.com/aws-neuron/upstreaming-to-vllm.git -cd upstreaming-to-vllm -pip install -r requirements/neuron.txt -VLLM_TARGET_DEVICE="neuron" pip install -e . -``` - -Note that the AWS Neuron fork is only intended to support Neuron hardware; compatibility with other hardwares is not tested. - -## Set up using Docker - -### Pre-built images - -Currently, there are no pre-built Neuron images. - -### Build image from source - -See [deployment-docker-build-image-from-source][deployment-docker-build-image-from-source] for instructions on building the Docker image. - -Make sure to use in place of the default Dockerfile. 
- -## Extra information - -[](){ #feature-support-through-nxd-inference-backend } - -### Feature support through NxD Inference backend - -The current vLLM and Neuron integration relies on either the `neuronx-distributed-inference` (preferred) or `transformers-neuronx` backend -to perform most of the heavy lifting which includes PyTorch model initialization, compilation, and runtime execution. Therefore, most -[features supported on Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html) are also available via the vLLM integration. - -To configure NxD Inference features through the vLLM entrypoint, use the `override_neuron_config` setting. Provide the configs you want to override -as a dictionary (or JSON object when starting vLLM from the CLI). For example, to disable auto bucketing, include - -```python -override_neuron_config={ - "enable_bucketing":False, -} -``` - -or when launching vLLM from the CLI, pass - -```bash ---override-neuron-config "{\"enable_bucketing\":false}" -``` - -Alternatively, users can directly call the NxDI library to trace and compile your model, then load the pre-compiled artifacts -(via `NEURON_COMPILED_ARTIFACTS` environment variable) in vLLM to run inference workloads. - -### Known limitations - -- EAGLE speculative decoding: NxD Inference requires the EAGLE draft checkpoint to include the LM head weights from the target model. Refer to this - [guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html#eagle-checkpoint-compatibility) - for how to convert pretrained EAGLE model checkpoints to be compatible for NxDI. -- Quantization: the native quantization flow in vLLM is not well supported on NxD Inference. It is recommended to follow this - [Neuron quantization guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/custom-quantization.html) - to quantize and compile your model using NxD Inference, and then load the compiled artifacts into vLLM. -- Multi-LoRA serving: NxD Inference only supports loading of LoRA adapters at server startup. Dynamic loading of LoRA adapters at - runtime is not currently supported. Refer to [multi-lora example](https://github.com/aws-neuron/upstreaming-to-vllm/blob/neuron-2.23-vllm-v0.7.2/examples/offline_inference/neuron_multi_lora.py) -- Multi-modal support: multi-modal support is only available through the AWS Neuron fork. This feature has not been upstreamed - to vLLM main because NxD Inference currently relies on certain adaptations to the core vLLM logic to support this feature. -- Multi-node support: distributed inference across multiple Trainium/Inferentia instances is only supported on the AWS Neuron fork. Refer - to this [multi-node example](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2/examples/neuron/multi_node) - to run. Note that tensor parallelism (distributed inference across NeuronCores) is available in vLLM main. -- Known edge case bug in speculative decoding: An edge case failure may occur in speculative decoding when sequence length approaches - max model length (e.g. when requesting max tokens up to the max model length and ignoring eos). In this scenario, vLLM may attempt - to allocate an additional block to ensure there is enough memory for number of lookahead slots, but since we do not have good support - for paged attention, there isn't another Neuron block for vLLM to allocate. 
A workaround fix (to terminate 1 iteration early) is
-  implemented in the AWS Neuron fork but is not upstreamed to vLLM main as it modifies core vLLM logic.
-
-### Environment variables
-
-- `NEURON_COMPILED_ARTIFACTS`: set this environment variable to point to your pre-compiled model artifacts directory to avoid
-  compilation time upon server initialization. If this variable is not set, the Neuron module will perform compilation and save the
-  artifacts under `neuron-compiled-artifacts/{unique_hash}/` subdirectory in the model path. If this environment variable is set,
-  but the directory does not exist, or the contents are invalid, Neuron will also fall back to a new compilation and store the artifacts
-  under this specified path.
-- `NEURON_CONTEXT_LENGTH_BUCKETS`: Bucket sizes for context encoding. (Only applicable to `transformers-neuronx` backend).
-- `NEURON_TOKEN_GEN_BUCKETS`: Bucket sizes for token generation. (Only applicable to `transformers-neuronx` backend).
diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md
index 275232e12e08c..01c5f5fc02f3e 100644
--- a/docs/getting_started/installation/gpu/cuda.inc.md
+++ b/docs/getting_started/installation/gpu/cuda.inc.md
@@ -165,14 +165,14 @@ There are scenarios where the PyTorch dependency cannot be easily installed with

 - Building vLLM with PyTorch nightly or a custom PyTorch build.
 - Building vLLM with aarch64 and CUDA (GH200), where the PyTorch wheels are not available on PyPI. Currently, only the PyTorch nightly has wheels for aarch64 with CUDA. You can run `uv pip install --index-url https://download.pytorch.org/whl/nightly/cu128 torch torchvision torchaudio` to [install PyTorch nightly](https://pytorch.org/get-started/locally/) and then build vLLM on top of it.

-To build vLLM using an existing PyTorch installation:
+To build vLLM using an existing PyTorch installation, it is recommended to use `uv`, which has [a unique mechanism](https://docs.astral.sh/uv/concepts/projects/config/#disabling-build-isolation) for disabling build isolation for specific packages; vLLM uses this mechanism to disable build isolation for `torch`.

 ```bash
+# install PyTorch first, either from PyPI or from source
 git clone https://github.com/vllm-project/vllm.git
 cd vllm
-python use_existing_torch.py
-uv pip install -r requirements/build.txt
-uv pip install --no-build-isolation -e .
+# `pip install -e .` does not work directly here; only uv can do this
+uv pip install -e .
 ```

 ##### Use the local cutlass for compilation
diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md
index 80e99d3034d39..37c6647929b51 100644
--- a/docs/getting_started/installation/gpu/rocm.inc.md
+++ b/docs/getting_started/installation/gpu/rocm.inc.md
@@ -1,6 +1,6 @@
 # --8<-- [start:installation]

-vLLM supports AMD GPUs with ROCm 6.3.
+vLLM supports AMD GPUs with ROCm 6.3 or above.

 !!! tip
     [Docker](#set-up-using-docker) is the recommended way to use vLLM on ROCm.

@@ -11,8 +11,9 @@ vLLM supports AMD GPUs with ROCm 6.3.
 # --8<-- [end:installation]
 # --8<-- [start:requirements]

-- GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100/1101), Radeon RX 9000 series (gfx1200/1201)
-- ROCm 6.3
+- GPU: MI200s (gfx90a), MI300 (gfx942), MI350 (gfx950), Radeon RX 7900 series (gfx1100/1101), Radeon RX 9000 series (gfx1200/1201)
+- ROCm 6.3 or above
+    - MI350 requires ROCm 7.0 or above

 # --8<-- [end:requirements]
 # --8<-- [start:set-up-using-python]
@@ -32,35 +33,35 @@ Currently, there are no pre-built ROCm wheels.
     - [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html)
     - [PyTorch](https://pytorch.org/)

-    For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3.
+    For installing PyTorch, you can start from a fresh docker image, e.g., `rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0`, `rocm/pytorch-nightly`. If you are using a docker image, you can skip to Step 3.

     Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/). Example:

     ```bash
     # Install PyTorch
     pip uninstall torch -y
-    pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3
+    pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4
     ```

-1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton)
+1. Install [Triton for ROCm](https://github.com/triton-lang/triton)

-    Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from [ROCm/triton](https://github.com/ROCm/triton/blob/triton-mlir/README.md)
+    Install ROCm's Triton (the default triton-mlir branch) following the instructions from [ROCm/triton](https://github.com/ROCm/triton/blob/triton-mlir/README.md)

     ```bash
     python3 -m pip install ninja cmake wheel pybind11
     pip uninstall -y triton
-    git clone https://github.com/OpenAI/triton.git
+    git clone https://github.com/triton-lang/triton.git
     cd triton
     git checkout e5be006
-    cd python
-    pip3 install .
+    if [ ! -f setup.py ]; then cd python; fi
+    python3 setup.py install
     cd ../..
     ```

     !!! note
         If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent.

-2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention)
+2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/Dao-AILab/flash-attention)

     Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention#amd-rocm-support)
     Alternatively, wheels intended for vLLM use can be accessed under the releases.

@@ -68,9 +69,9 @@ Currently, there are no pre-built ROCm wheels.
     For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.

     ```bash
-    git clone https://github.com/ROCm/flash-attention.git
+    git clone https://github.com/Dao-AILab/flash-attention.git
     cd flash-attention
-    git checkout b7d29fb
+    git checkout 1a7f4dfa
     git submodule update --init
     GPU_ARCHS="gfx90a" python3 setup.py install
     cd ..
     ```

@@ -119,7 +120,7 @@ Currently, there are no pre-built ROCm wheels.
    This may take 5-10 minutes.
Currently, `pip install .` does not work for ROCm installation.

 !!! tip
-    - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
+    - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm-up step before collecting perf numbers.
     - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
     - To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention.
     - The ROCm version of PyTorch, ideally, should match the ROCm driver version.

@@ -194,16 +195,6 @@ To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default:
 DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.rocm -t vllm-rocm .
 ```

-To build vllm on ROCm 6.3 for Radeon RX7900 series (gfx1100), you should pick the alternative base image:
-
-```bash
-DOCKER_BUILDKIT=1 docker build \
-    --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" \
-    -f docker/Dockerfile.rocm \
-    -t vllm-rocm \
-    .
-```
-
 To run the above docker image `vllm-rocm`, use the below command:

 ??? console "Command"
@@ -218,8 +209,7 @@ To run the above docker image `vllm-rocm`, use the below command:
        --device /dev/kfd \
        --device /dev/dri \
        -v <path/to/model>:/app/model \
-       vllm-rocm \
-       bash
+       vllm-rocm
    ```

 Where the `<path/to/model>` is the location where the model is stored, for example, the weights for llama2 or llama3 models.
diff --git a/docs/getting_started/installation/gpu/xpu.inc.md b/docs/getting_started/installation/gpu/xpu.inc.md
index b77c4e00cf0c4..ed1dc0418cf7e 100644
--- a/docs/getting_started/installation/gpu/xpu.inc.md
+++ b/docs/getting_started/installation/gpu/xpu.inc.md
@@ -3,13 +3,16 @@
 vLLM initially supports basic model inference and serving on Intel GPU platform.

 !!! warning
-    There are no pre-built wheels or images for this device, so you must build vLLM from source.
+    There are no pre-built wheels for this device, so you need to build vLLM from source. Alternatively, you can use pre-built images, which are based on vLLM release versions.

 # --8<-- [end:installation]
 # --8<-- [start:requirements]

 - Supported Hardware: Intel Data Center GPU, Intel ARC GPU
-- OneAPI requirements: oneAPI 2025.0
+- OneAPI requirements: oneAPI 2025.1
+- Python: 3.12
+!!! warning
+    The provided IPEX wheel is built specifically for Python 3.12, so this Python version is required.

 # --8<-- [end:requirements]
 # --8<-- [start:set-up-using-python]
@@ -24,7 +27,7 @@ Currently, there are no pre-built XPU wheels.
 # --8<-- [end:pre-built-wheels]
 # --8<-- [start:build-wheel-from-source]

-- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.0 or later.
+- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.1 or later.
 - Second, install Python packages for vLLM XPU backend building:

 ```bash
@@ -40,14 +43,10 @@
 pip install -v -r requirements/xpu.txt
 VLLM_TARGET_DEVICE=xpu python setup.py install
 ```

-!!! note
-    - FP16 is the default data type in the current XPU backend. The BF16 data
-      type is supported on Intel Data Center GPU, not supported on Intel Arc GPU yet.
-
 # --8<-- [end:build-wheel-from-source]
 # --8<-- [start:pre-built-images]

-Currently, there are no pre-built XPU images.
+Currently, we release pre-built XPU images on Docker [Hub](https://hub.docker.com/r/intel/vllm/tags) based on vLLM release versions. For more information, please refer to the release [notes](https://github.com/intel/ai-containers/blob/main/vllm).

 # --8<-- [end:pre-built-images]
 # --8<-- [start:build-image-from-source]
@@ -65,14 +64,14 @@ docker run -it \

 # --8<-- [end:build-image-from-source]
 # --8<-- [start:supported-features]

-XPU platform supports **tensor parallel** inference/serving and also supports **pipeline parallel** as a beta feature for online serving. We require Ray as the distributed runtime backend. For example, a reference execution like following:
+The XPU platform supports **tensor parallel** inference/serving and also supports **pipeline parallel** as a beta feature for online serving. **Pipeline parallel** is supported on a single node with `mp` as the backend. For example, a reference execution looks like the following:

 ```bash
 python -m vllm.entrypoints.openai.api_server \
      --model=facebook/opt-13b \
      --dtype=bfloat16 \
      --max_model_len=1024 \
-     --distributed-executor-backend=ray \
+     --distributed-executor-backend=mp \
      --pipeline-parallel-size=2 \
      -tp=8
```
diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py
index 051a2d904406d..91454ec272b81 100644
--- a/docs/mkdocs/hooks/generate_argparse.py
+++ b/docs/mkdocs/hooks/generate_argparse.py
@@ -165,6 +165,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
     # Generate documentation for each parser
     for stem, parser in parsers.items():
         doc_path = ARGPARSE_DOC_DIR / f"{stem}.md"
-        with open(doc_path, "w") as f:
+        # Specify encoding for building on Windows
+        with open(doc_path, "w", encoding="utf-8") as f:
             f.write(parser.format_help())
         logger.info("Argparse generated: %s", doc_path.relative_to(ROOT_DIR))
diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py
index 881df791698e2..0cbaebb598a34 100644
--- a/docs/mkdocs/hooks/generate_examples.py
+++ b/docs/mkdocs/hooks/generate_examples.py
@@ -106,13 +106,41 @@ class Example:

     def determine_title(self) -> str:
         if not self.is_code:
-            with open(self.main_file) as f:
+            # Specify encoding for building on Windows
+            with open(self.main_file, encoding="utf-8") as f:
                 first_line = f.readline().strip()
                 match = re.match(r'^#\s+(?P<title>.+)$', first_line)
                 if match:
                     return match.group('title')
         return fix_case(self.path.stem.replace("_", " ").title())

+    def fix_relative_links(self, content: str) -> str:
+        """
+        Fix relative links in markdown content by converting them to gh-file
+        format.
+ + Args: + content (str): The markdown content to process + + Returns: + str: Content with relative links converted to gh-file format + """ + # Regex to match markdown links [text](relative_path) + # This matches links that don't start with http, https, ftp, or # + link_pattern = r'\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)' + + def replace_link(match): + link_text = match.group(1) + relative_path = match.group(2) + + # Make relative to repo root + gh_file = (self.main_file.parent / relative_path).resolve() + gh_file = gh_file.relative_to(ROOT_DIR) + + return f'[{link_text}](gh-file:{gh_file})' + + return re.sub(link_pattern, replace_link, content) + def generate(self) -> str: content = f"# {self.title}\n\n" content += f"Source <gh-file:{self.path.relative_to(ROOT_DIR)}>.\n\n" @@ -120,14 +148,16 @@ class Example: # Use long code fence to avoid issues with # included files containing code fences too code_fence = "``````" - # Skip the title from md snippets as it's been included above - start_line = 2 + if self.is_code: - content += f"{code_fence}{self.main_file.suffix[1:]}\n" - start_line = 1 - content += f'--8<-- "{self.main_file}:{start_line}"\n' - if self.is_code: - content += f"{code_fence}\n" + content += (f"{code_fence}{self.main_file.suffix[1:]}\n" + f'--8<-- "{self.main_file}"\n' + f"{code_fence}\n") + else: + with open(self.main_file) as f: + # Skip the title from md snippets as it's been included above + main_content = f.readlines()[1:] + content += self.fix_relative_links("".join(main_content)) content += "\n" if not self.other_files: @@ -174,6 +204,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): doc_path = EXAMPLE_DOC_DIR / example.category / example_name if not doc_path.parent.exists(): doc_path.parent.mkdir(parents=True) - with open(doc_path, "w+") as f: + # Specify encoding for building on Windows + with open(doc_path, "w+", encoding="utf-8") as f: f.write(example.generate()) logger.debug("Example generated: %s", doc_path.relative_to(ROOT_DIR)) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index bdb29aac333c1..6295a2aa8dc2f 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -322,6 +322,7 @@ th { | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `ApertusForCausalLM` | Apertus | `swiss-ai/Apertus-8B-2509`, `swiss-ai/Apertus-70B-Instruct-2509`, etc. | ✅︎ | ✅︎ | ✅︎ | | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | @@ -382,11 +383,13 @@ th { | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MotifForCausalLM` | Motif-1-Tiny | `Motif-Technologies/Motif-2.6B`, `Motif-Technologies/Motif-2.6b-v1.1-LC`, etc. 
| ✅︎ | ✅︎ | | | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | @@ -401,6 +404,7 @@ th { | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3NextForCausalLM` | Qwen3NextMoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | @@ -763,8 +767,9 @@ Speech2Text models trained specifically for Automatic Speech Recognition. | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | | -| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | | ✅︎ | ✅︎ | +| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | ✅︎ | +| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | ### Pooling Models diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index 280b3322b11c3..494d2ad021e71 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -123,18 +123,46 @@ When enabled, vLLM collects load statistics with every forward pass and periodic ### EPLB Parameters +Configure EPLB with the `--eplb-config` argument, which accepts a JSON string. 
The available keys and their descriptions are:

 | Parameter | Description | Default |
 |-----------|-------------|---------|
-| `--eplb-window-size` | Number of engine steps to track for rebalancing decisions | - |
-| `--eplb-step-interval` | Frequency of rebalancing (every N engine steps) | - |
-| `--eplb-log-balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` |
-| `--num-redundant-experts` | Additional global experts per EP rank beyond equal distribution | `0` |
+| `window_size` | Number of engine steps to track for rebalancing decisions | 1000 |
+| `step_interval` | Frequency of rebalancing (every N engine steps) | 3000 |
+| `log_balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` |
+| `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` |
+
+For example:
+
+```bash
+vllm serve Qwen/Qwen3-30B-A3B \
+  --enable-eplb \
+  --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}'
+```
+
+??? tip "Prefer individual arguments instead of JSON?"
+
+    ```bash
+    vllm serve Qwen/Qwen3-30B-A3B \
+      --enable-eplb \
+      --eplb-config.window_size 1000 \
+      --eplb-config.step_interval 3000 \
+      --eplb-config.num_redundant_experts 2 \
+      --eplb-config.log_balancedness true
+    ```

 ### Expert Distribution Formula

 - **Default**: Each EP rank has `NUM_TOTAL_EXPERTS ÷ NUM_EP_RANKS` experts
 - **With redundancy**: Each EP rank has `(NUM_TOTAL_EXPERTS + NUM_REDUNDANT_EXPERTS) ÷ NUM_EP_RANKS` experts

+### Memory Footprint Overhead
+
+EPLB uses redundant experts that need to fit in GPU memory. This means that EPLB may not be a good fit for memory-constrained environments or when KV cache space is at a premium.
+
+This overhead equals `NUM_MOE_LAYERS * BYTES_PER_EXPERT * (NUM_TOTAL_EXPERTS + NUM_REDUNDANT_EXPERTS) ÷ NUM_EP_RANKS`.
+For DeepSeekV3, this is approximately `2.4 GB` for one redundant expert per EP rank (see the worked sketch below).
+
 ### Example Command

 Single node deployment with EPLB enabled:
@@ -146,12 +174,10 @@ VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 vllm serve deepseek-ai/DeepSeek-V
     --data-parallel-size 8 \                 # Data parallelism
     --enable-expert-parallel \               # Enable EP
     --enable-eplb \                          # Enable load balancer
-    --eplb-log-balancedness \            # Log balancing metrics
-    --eplb-window-size 1000 \            # Track last 1000 engine steps
-    --eplb-step-interval 3000            # Rebalance every 3000 steps
+    --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}'
 ```

-For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--num-redundant-experts` to 32 in large scale use cases so the most popular experts are always available.
+For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--eplb-config '{"num_redundant_experts":32}'` in large-scale use cases so the most popular experts are always available.

 ## Disaggregated Serving (Prefill/Decode Split)
diff --git a/docs/serving/parallelism_scaling.md b/docs/serving/parallelism_scaling.md
index fa7fc1b290d50..cef1127fc5c15 100644
--- a/docs/serving/parallelism_scaling.md
+++ b/docs/serving/parallelism_scaling.md
@@ -66,7 +66,7 @@ Ray is a distributed computing framework for scaling Python programs.

 Multi-node vLLM uses Ray to manage the distributed execution of tasks across multiple nodes and control where execution happens.
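To make the EPLB memory-footprint formula above concrete, here is a minimal back-of-the-envelope sketch. The model numbers (58 MoE layers, 256 routed experts, roughly 44 MiB per expert per layer in FP8, 64 EP ranks) are illustrative assumptions for a DeepSeek-V3-sized MoE, not values taken from this change:

```python
# Rough per-rank expert-memory estimate under EPLB, following the formula
# NUM_MOE_LAYERS * BYTES_PER_EXPERT * (NUM_TOTAL_EXPERTS + NUM_REDUNDANT_EXPERTS) / NUM_EP_RANKS.
# All model numbers below are illustrative assumptions, not measured values.

NUM_MOE_LAYERS = 58            # assumed MoE layer count
NUM_TOTAL_EXPERTS = 256        # assumed routed experts per layer
NUM_EP_RANKS = 64              # assumed expert-parallel world size
BYTES_PER_EXPERT = 44 * 2**20  # assumed ~44 MiB per expert per layer (FP8)


def expert_bytes_per_rank(num_redundant_experts: int) -> float:
    """Expert weight bytes each EP rank must hold."""
    experts_per_rank = (NUM_TOTAL_EXPERTS + num_redundant_experts) / NUM_EP_RANKS
    return NUM_MOE_LAYERS * BYTES_PER_EXPERT * experts_per_rank


# "One redundant expert per EP rank" means NUM_EP_RANKS extra experts globally.
overhead = expert_bytes_per_rank(NUM_EP_RANKS) - expert_bytes_per_rank(0)
print(f"EPLB overhead: ~{overhead / 2**30:.1f} GiB per rank")  # ~2.5 GiB here
```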
-Ray also offers high-level APIs for large-scale [offline batch inference](https://docs.ray.io/en/latest/data/working-with-llms.html) and [online serving](https://docs.ray.io/en/latest/serve/llm/serving-llms.html) that can leverage vLLM as the engine. These APIs add production-grade fault tolerance, scaling, and distributed observability to vLLM workloads. +Ray also offers high-level APIs for large-scale [offline batch inference](https://docs.ray.io/en/latest/data/working-with-llms.html) and [online serving](https://docs.ray.io/en/latest/serve/llm) that can leverage vLLM as the engine. These APIs add production-grade fault tolerance, scaling, and distributed observability to vLLM workloads. For details, see the [Ray documentation](https://docs.ray.io/en/latest/index.html). @@ -104,7 +104,7 @@ Note that `VLLM_HOST_IP` is unique for each worker. Keep the shells running thes From any node, enter a container and run `ray status` and `ray list nodes` to verify that Ray finds the expected number of nodes and GPUs. !!! tip - Alternatively, set up the Ray cluster using KubeRay. For more information, see [KubeRay vLLM documentation](https://docs.ray.io/en/latest/cluster/kubernetes/examples/vllm-rayservice.html). + Alternatively, set up the Ray cluster using KubeRay. For more information, see [KubeRay vLLM documentation](https://docs.ray.io/en/latest/cluster/kubernetes/examples/rayserve-llm-example.html). ### Running vLLM on a Ray cluster diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md index 4945927e3d787..6e700d1faaa9c 100644 --- a/docs/usage/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -40,6 +40,34 @@ If other strategies don't solve the problem, it's likely that the vLLM instance - `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL. - `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. Do not use this flag unless absolutely needed for debugging, it will cause significant delays in startup time. +## Breakpoints + +Setting normal `pdb` breakpoints may not work in vLLM's codebase if they are executed in a subprocess. You will experience something like: + +``` text + File "/usr/local/uv/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/bdb.py", line 100, in trace_dispatch + return self.dispatch_line(frame) + ^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/uv/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/bdb.py", line 125, in dispatch_line + if self.quitting: raise BdbQuit + ^^^^^^^^^^^^^ +bdb.BdbQuit +``` + +One solution is using [forked-pdb](https://github.com/Lightning-AI/forked-pdb). Install with `pip install fpdb` and set a breakpoint with something like: + +``` python +__import__('fpdb').ForkedPdb().set_trace() +``` + +Another option is to disable multiprocessing entirely, with the `VLLM_ENABLE_V1_MULTIPROCESSING` environment variable. +This keeps the scheduler in the same process, so you can use stock `pdb` breakpoints: + +``` python +import os +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +``` + ## Incorrect network setup The vLLM instance cannot get the correct IP address if you have a complicated network config. You can find a log such as `DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl` and the IP address should be the correct one. 
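The Breakpoints section above shows the two ingredients separately; here is a minimal end-to-end sketch of the in-process variant. The model name is just an example, and the key point is that `VLLM_ENABLE_V1_MULTIPROCESSING` is set before the engine is constructed:

```python
import os

# Keep the scheduler in this process so stock pdb works; this must be set
# before engine construction. A sketch, assuming a small model is available.
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

from vllm import LLM

llm = LLM(model="facebook/opt-125m")
breakpoint()  # drops into pdb in the same process as the scheduler
outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)
```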
@@ -296,3 +324,4 @@ This indicates vLLM failed to initialize the NCCL communicator, possibly due to - In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759). - To address a memory overhead issue in older NCCL versions (see [bug](https://github.com/NVIDIA/nccl/issues/1234)), vLLM versions `>= 0.4.3, <= 0.10.1.1` would set the environment variable `NCCL_CUMEM_ENABLE=0`. External processes connecting to vLLM also needed to set this variable to prevent hangs or crashes. Since the underlying NCCL bug was fixed in NCCL 2.22.3, this override was removed in newer vLLM versions to allow for NCCL performance optimizations. +- In some PCIe machines (e.g. machines without NVLink), if you see an error like `transport/shm.cc:590 NCCL WARN Cuda failure 217 'peer access is not supported between these two devices'`, it's likely caused by a driver bug. See [this issue](https://github.com/NVIDIA/nccl/issues/1838) for more details. In that case, you can try to set `NCCL_CUMEM_HOST_ENABLE=0` to disable the feature, or upgrade your driver to the latest version. diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 525f740d12a7f..d404c87e8f5a7 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the | Model Type | Status | |-----------------------------|------------------------------------------------------------------------------------| | **Decoder-only Models** | <nobr>🚀 Optimized</nobr> | -| **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> | +| **Encoder-Decoder Models** | <nobr>🟢 Whisper only</nobr> | | **Embedding Models** | <nobr>🟢 Functional</nobr> | | **Mamba Models** | <nobr>🟢 (Mamba-2), 🟢 (Mamba-1)</nobr> | | **Multimodal Models** | <nobr>🟢 Functional</nobr> | @@ -118,8 +118,9 @@ Please note that prefix caching is not yet supported for any of the above models #### Encoder-Decoder Models -Models requiring cross-attention between separate encoder and decoder (e.g., `BartForConditionalGeneration`, `MllamaForConditionalGeneration`) -are not yet supported. +Whisper is supported. Other models requiring cross-attention between separate +encoder and decoder (e.g., `BartForConditionalGeneration`, +`MllamaForConditionalGeneration`) are not yet supported. ### Features diff --git a/examples/offline_inference/chat_with_tools.py b/examples/offline_inference/chat_with_tools.py index 6e56e24f2092c..3a95b1fdfbabc 100644 --- a/examples/offline_inference/chat_with_tools.py +++ b/examples/offline_inference/chat_with_tools.py @@ -143,5 +143,5 @@ outputs = llm.chat(messages, sampling_params, tools=tools) print(outputs[0].outputs[0].text.strip()) # yields -# 'The weather in Dallas, TX is 85 degrees fahrenheit. ' +# 'The weather in Dallas, TX is 85 degrees Fahrenheit. ' # 'It is partly cloudly, with highs in the 90's.' diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py index df6c1eaf4a21e..957db3c23b863 100644 --- a/examples/offline_inference/encoder_decoder.py +++ b/examples/offline_inference/encoder_decoder.py @@ -5,6 +5,8 @@ Demonstrate prompting of text-to-text encoder/decoder models, specifically BART and mBART. This script is refactored to allow model selection via command-line arguments. 
+ +NOTE: This example is not yet supported in V1. """ import argparse diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index 655f9f3fce7ae..35e9203d1caf0 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -5,6 +5,7 @@ This example shows how to use vLLM for running offline inference with the explicit/implicit prompt format on enc-dec LMMs for text generation. """ +import os import time from collections.abc import Sequence from dataclasses import asdict @@ -130,6 +131,8 @@ def run_mllama(): def run_whisper(): + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + engine_args = EngineArgs( model="openai/whisper-large-v3-turbo", max_model_len=448, diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py index 5d629fabf0a27..418c40645f9f2 100644 --- a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams def main(): torch.set_default_dtype(torch.float16) - image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501 + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 img_prompt = dict( data=image_url, @@ -36,7 +36,7 @@ def main(): # to avoid the model going OOM. # The maximum number depends on the available GPU memory max_num_seqs=32, - io_processor_plugin="prithvi_to_tiff_india", + io_processor_plugin="prithvi_to_tiff", model_impl="terratorch", ) diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 65621023ab6ce..360fd79b55aad 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -28,12 +28,15 @@ Learn more about Ray placement groups: https://docs.ray.io/en/latest/placement-groups.html """ +import gc import os import ray import torch +import zmq from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from torch.multiprocessing.reductions import reduce_tensor from vllm import LLM @@ -86,20 +89,72 @@ class RayTrainingActor: from vllm.platforms import current_platform self.device_uuid = current_platform.get_device_uuid(0) + self.zmq_context = zmq.Context() + self.zmq_address_counter = 0 + self.zmq_handle = None def report_device_id(self) -> str: return self.device_uuid - def get_weight_ipc_handles(self): - from torch.multiprocessing.reductions import reduce_tensor + def get_zmq_handles(self) -> dict[str, str]: + suffix = f"{self.device_uuid}-{self.zmq_address_counter}" + self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock" + self.zmq_address_counter += 1 + return {self.device_uuid: self.zmq_handle} - data = {} - for name, p in self.model.named_parameters(): - # A training actor might hold only a subset of the weights and may - # need to gather weights from other actors. For demonstration - # purposes, each training actor owns the full weight set. 
- data[name] = reduce_tensor(p.detach()) - return {self.device_uuid: data} + def update_weights(self): + # align size to avoid misaligned address + align_size = 256 + + def get_size(p: torch.Tensor) -> int: + return (p.nbytes + align_size - 1) // align_size * align_size + + named_parameters: dict[str, torch.nn.Parameter] = dict( + self.model.named_parameters() + ) + max_tensor_size = max(get_size(p) for p in named_parameters.values()) + # use max_tensor_size * 2 as buffer size + buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0") + s = self.zmq_context.socket(zmq.REQ) + s.bind(self.zmq_handle) + handle = reduce_tensor(buffer) + + offset = 0 + buckets: list[tuple[list[dict], list[torch.Tensor]]] = [] + named_tensors: list[dict] = [] + real_tensors: list[torch.Tensor] = [] + for name, p in named_parameters.items(): + size = get_size(p) + if offset + size > buffer.numel(): + buckets.append((named_tensors, real_tensors)) + named_tensors, real_tensors = [], [] + offset = 0 + # assume tensors are contiguous + named_tensors.append( + {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset} + ) + real_tensors.append(p) + offset += size + if named_tensors: + buckets.append((named_tensors, real_tensors)) + s.send_pyobj(handle) + s.recv() + for named_tensors, real_tensors in buckets: + offset = 0 + for p in real_tensors: + buffer[offset : offset + p.nbytes].data.copy_( + p.data.view(-1).view(dtype=torch.uint8), non_blocking=True + ) + offset += get_size(p) + torch.cuda.synchronize() + s.send_pyobj(named_tensors) + s.recv() + s.send_pyobj(None) + s.recv() + s.close() + del buffer + gc.collect() + torch.cuda.empty_cache() # Ray manages four GPUs. @@ -175,18 +230,22 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0] # the second inference engine. 
assert training_actor_device_ids[2:] == inference_engine_device_ids[1] -print("Gather all the IPC handles from the training actors.") -ipc_handles = {} +print("Gather all the ZMQ handles from the training actors.") +zmq_handles = {} for actor in training_actors: - ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote())) + zmq_handles.update(ray.get(actor.get_zmq_handles.remote())) + +print(f"ZMQ handles: {zmq_handles}") print("Update the weights of the inference engines.") -for llm in inference_engines: - ray.get( - llm.collective_rpc.remote( - "update_weights_from_ipc_handles", args=(ipc_handles,) - ) - ) +ray.get( + [actor.update_weights.remote() for actor in training_actors] + + [ + llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,)) + for llm in inference_engines + ] +) + print("Check if the weights are updated.") for llm in inference_engines: assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple())) diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index d2a8419ffabcd..c0e60b9793407 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -1,6 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +from typing import Callable, Optional, TypedDict + import torch +import zmq def stateless_init_process_group(master_address, master_port, rank, world_size, device): @@ -66,6 +70,27 @@ class WorkerExtension: return weights_updated +def rebuild_ipc( + handle: tuple[Callable, tuple], device_id: Optional[int] = None +) -> torch.Tensor: + func, args = handle + list_args = list(args) + if device_id is not None: + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + buffer = func(*list_args) + return buffer + + +class FlattenedTensorMetadata(TypedDict): + name: str + shape: torch.Size + dtype: torch.dtype + # specify the start offset of this tensor in shared ipc_buffer tensor + offset: int + + class ColocateWorkerExtension: """ The class for vLLM's worker to inherit from, in the colocate setting. @@ -76,27 +101,62 @@ class ColocateWorkerExtension: should pass the full qualified name as `worker_extension_cls` argument. """ + def update_weights_from_ipc(self, zmq_handles: dict[str, str]): + from vllm.model_executor.model_loader.utils import process_weights_after_loading + + assert self.device is not None + if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None: + self._zmq_ctx = zmq.Context() + socket = self._zmq_ctx.socket(zmq.REP) + socket.connect(zmq_handles[self.report_device_id()]) + buffer: Optional[torch.Tensor] = None + while True: + payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = ( + socket.recv_pyobj() + ) + if payload is None: + # means the update is done + process_weights_after_loading( + self.model_runner.model, self.model_config, self.device + ) + torch.cuda.synchronize() + socket.send(b"") + break + if isinstance(payload, tuple): + # an ipc handle that vLLM can use `func, args = handle` + # and `func(*args)` to rebuild GPU tensor. 
+ buffer = rebuild_ipc(payload, self.device.index) + assert buffer.dtype == torch.uint8 + socket.send(b"") + continue + assert isinstance(payload, list) + assert buffer is not None + weights = [] + for item in payload: + shape = item["shape"] + if isinstance(shape, (list, tuple)): + shape = torch.Size(shape) + assert isinstance(shape, torch.Size) + dtype, offset = item["dtype"], item["offset"] + size = dtype.itemsize * shape.numel() + tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) + weights.append((item["name"], tensor)) + self.model_runner.model.load_weights(weights=weights) + del weights + torch.cuda.synchronize() + socket.send(b"") + + socket.close() + del buffer + gc.collect() + torch.cuda.empty_cache() + def report_device_id(self) -> str: from vllm.platforms import current_platform self.device_uuid = current_platform.get_device_uuid(self.device.index) return self.device_uuid - def update_weights_from_ipc_handles(self, ipc_handles): - handles = ipc_handles[self.device_uuid] - device_id = self.device.index - weights = [] - for name, handle in handles.items(): - func, args = handle - list_args = list(args) - # the key is to change device id to the current device id - # in case two processes have different CUDA_VISIBLE_DEVICES - list_args[6] = device_id - tensor = func(*list_args) - weights.append((name, tensor)) - self.model_runner.model.load_weights(weights=weights) - torch.cuda.synchronize() - def check_weights_changed(self): """ Check if the weights are updated to 0. diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py index 9776f4fe322b9..0093b63b0b1f3 100644 --- a/examples/offline_inference/tpu.py +++ b/examples/offline_inference/tpu.py @@ -42,7 +42,7 @@ def main(): llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct" # Set `enforce_eager=True` to avoid ahead-of-time compilation. - # In real workloads, `enforace_eager` should be `False`. + # In real workloads, `enforce_eager` should be `False`. 
llm = LLM(**llm_args) outputs = llm.generate(prompts, sampling_params) print("-" * 50) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index b104113b88213..4b75eb19fcf94 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1764,6 +1764,7 @@ def apply_image_repeat( probs = [1.0 - image_repeat_prob, image_repeat_prob] inputs = [] + inputs_with_empty_media = [] cur_image = data for i in range(num_prompts): if image_repeat_prob is not None: @@ -1774,14 +1775,25 @@ def apply_image_repeat( new_val = (i // 256 // 256, i // 256, i % 256) cur_image.putpixel((0, 0), new_val) + uuid = "uuid_{}".format(i) + inputs.append( { "prompt": prompts[i % len(prompts)], "multi_modal_data": {modality: cur_image}, + "multi_modal_uuids": {modality: uuid}, } ) - return inputs + inputs_with_empty_media.append( + { + "prompt": prompts[i % len(prompts)], + "multi_modal_data": {modality: None}, + "multi_modal_uuids": {modality: uuid}, + } + ) + + return inputs, inputs_with_empty_media @contextmanager @@ -1860,6 +1872,13 @@ def parse_args(): help="If True, then use different prompt (with the same multi-modal " "data) for each request.", ) + + parser.add_argument( + "--verify-mm-cache-hit-with-uuids", + action="store_true", + help="If True, will send all requests in a second batch with empty mm " + "data to verify cache hits with UUIDs.", + ) return parser.parse_args() @@ -1903,26 +1922,48 @@ def main(args): assert args.num_prompts > 0 if args.num_prompts == 1: # Single inference + uuid = "uuid_0" inputs = { "prompt": prompts[0], "multi_modal_data": {modality: data}, + "multi_modal_uuids": {modality: uuid}, + } + inputs_with_empty_media = { + "prompt": prompts[0], + "multi_modal_data": {modality: None}, + "multi_modal_uuids": {modality: uuid}, } else: # Batch inference if args.image_repeat_prob is not None: # Repeat images with specified probability of "image_repeat_prob" - inputs = apply_image_repeat( - args.image_repeat_prob, args.num_prompts, data, prompts, modality + inputs, inputs_with_empty_media = apply_image_repeat( + args.image_repeat_prob, + args.num_prompts, + data, + prompts, + modality, ) else: # Use the same image for all prompts - inputs = [ - { - "prompt": prompts[i % len(prompts)], - "multi_modal_data": {modality: data}, - } - for i in range(args.num_prompts) - ] + inputs = [] + inputs_with_empty_media = [] + for i in range(args.num_prompts): + uuid = "uuid_{}".format(i) + inputs.append( + { + "prompt": prompts[i % len(prompts)], + "multi_modal_data": {modality: data}, + "multi_modal_uuids": {modality: uuid}, + } + ) + inputs_with_empty_media.append( + { + "prompt": prompts[i % len(prompts)], + "multi_modal_data": {modality: None}, + "multi_modal_uuids": {modality: uuid}, + } + ) # Add LoRA request if applicable lora_request = ( @@ -1942,6 +1983,26 @@ def main(args): print(generated_text) print("-" * 50) + if args.verify_mm_cache_hit_with_uuids: + try: + # Verify cache hits with UUIDs + print( + "Sending a second batch of requests with empty media" + " and matching UUIDs." + ) + outputs = llm.generate( + inputs_with_empty_media, + sampling_params=sampling_params, + lora_request=lora_request, + ) + print("-" * 50) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + print("-" * 50) + except Exception as e: + print(f"Failed to verify cache hits with UUIDs. 
Error: {e}") + if __name__ == "__main__": args = parse_args() diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py index f238c66234dcc..9fd55fc9ddc94 100644 --- a/examples/online_serving/kv_events_subscriber.py +++ b/examples/online_serving/kv_events_subscriber.py @@ -6,6 +6,8 @@ import msgspec import zmq from msgspec.msgpack import Decoder +from vllm.v1.core.kv_cache_utils import BlockHash + # # Types copied from vllm.distributed.kv_events @@ -22,8 +24,8 @@ class KVCacheEvent( class BlockStored(KVCacheEvent): - block_hashes: list[int] - parent_block_hash: Optional[int] + block_hashes: list[BlockHash] + parent_block_hash: Optional[BlockHash] token_ids: list[int] block_size: int lora_id: Optional[int] @@ -31,7 +33,7 @@ class BlockStored(KVCacheEvent): class BlockRemoved(KVCacheEvent): - block_hashes: list[int] + block_hashes: list[BlockHash] medium: Optional[str] diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py index c6eed64838ea4..611a7cbc89fa2 100644 --- a/examples/online_serving/prithvi_geospatial_mae.py +++ b/examples/online_serving/prithvi_geospatial_mae.py @@ -18,11 +18,11 @@ import requests # --model-impl terratorch # --task embed --trust-remote-code # --skip-tokenizer-init --enforce-eager -# --io-processor-plugin prithvi_to_tiff_india +# --io-processor-plugin prithvi_to_tiff def main(): - image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501 + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 server_endpoint = "http://localhost:8000/pooling" request_payload_url = { diff --git a/examples/tool_chat_template_phi4_mini.jinja b/examples/tool_chat_template_phi4_mini.jinja index 83886762c2893..6f40c38c20644 100644 --- a/examples/tool_chat_template_phi4_mini.jinja +++ b/examples/tool_chat_template_phi4_mini.jinja @@ -9,7 +9,7 @@ <|system|> {{ system_message }} {%- if tools %} -In addition to plain text responses, you can chose to call one or more of the provided functions. +In addition to plain text responses, you can choose to call one or more of the provided functions. Use the following rule to decide when to call a function: * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so @@ -19,7 +19,7 @@ If you decide to call functions: * prefix function calls with functools marker (no closing marker required) * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples - * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 + * respect the argument type formatting. 
E.g., if the type is number and format is float, write value 7 as 7.0 * make sure you pick the right functions that match the user intent diff --git a/pyproject.toml b/pyproject.toml index e63f8aeae2787..f5a44f56f416e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,7 @@ skip_gitignore = true [tool.pytest.ini_options] markers = [ + "slow_test", "skip_global_cleanup", "core_model: enable this model test in each PR instead of only nightly", "hybrid_model: models that contain mamba layers (including pure SSM and hybrid architectures)", @@ -228,6 +229,7 @@ fo = "fo" ba = "ba" [tool.typos.type.py.extend-words] +ba = "ba" [tool.typos.type.cpp] extend-glob = ["*.cu"] @@ -344,3 +346,6 @@ extend-ignore-re = [] windo = "windo" [tool.typos.type.vimscript.extend-words] + +[tool.uv] +no-build-isolation-package = ["torch"] diff --git a/requirements/common.txt b/requirements/common.txt index ce0795488cc1e..b8665104bd09a 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -20,8 +20,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.11.3 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" -outlines_core == 0.2.10 ; platform_machine != "s390x" -outlines == 0.1.11 ; platform_machine == "s390x" +outlines_core == 0.2.11 # required for outlines backend disk cache diskcache == 5.6.3 lark == 1.2.2 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index c3bb65b70a0b8..8e39951210714 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -8,7 +8,7 @@ numba == 0.61.2; python_version > '3.9' boto3 botocore datasets -ray>=2.10.0,<2.45.0 +ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. 
peft pytest-asyncio tensorizer==2.10.1 diff --git a/requirements/test.in b/requirements/test.in index 5db9cd797904f..744cfbe885278 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -21,6 +21,7 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli sentence-transformers # required for embedding tests soundfile # required for audio tests jiwer # required for audio tests +tblib # for pickling test exceptions timm >=1.0.17 # required for internvl and gemma3n-mm test torch==2.8.0 torchaudio==2.8.0 @@ -54,4 +55,4 @@ runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 decord==0.6.0 -terratorch==1.1rc3 # required for PrithviMAE test +terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test diff --git a/requirements/test.txt b/requirements/test.txt index 332a9b9cfbf59..5eebdc788aa3d 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -137,7 +137,7 @@ contourpy==1.3.0 # via matplotlib cramjam==2.9.0 # via fastparquet -cupy-cuda12x==13.3.0 +cupy-cuda12x==13.6.0 # via ray cycler==0.12.1 # via matplotlib @@ -1032,6 +1032,8 @@ tabledata==1.3.3 # via pytablewriter tabulate==0.9.0 # via sacrebleu +tblib==3.1.0 + # via -r requirements/test.in tcolorpy==0.1.6 # via pytablewriter tenacity==9.0.0 @@ -1042,7 +1044,7 @@ tensorboardx==2.6.4 # via lightning tensorizer==2.10.1 # via -r requirements/test.in -terratorch==1.1rc3 +terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e # via -r requirements/test.in threadpoolctl==3.5.0 # via scikit-learn diff --git a/requirements/xpu.txt b/requirements/xpu.txt index c44a2a9c74e50..74f5b05b2382a 100644 --- a/requirements/xpu.txt +++ b/requirements/xpu.txt @@ -11,10 +11,9 @@ jinja2>=3.1.6 datasets # for benchmark scripts numba == 0.60.0 # v0.61 doesn't support Python 3.9. 
Required for N-gram speculative decoding nixl==0.3.0 # for PD disaggregation ---extra-index-url=https://download.pytorch.org/whl/xpu torch==2.8.0+xpu torchaudio torchvision -pytorch-triton-xpu ---extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.8.10+xpu +--extra-index-url=https://download.pytorch.org/whl/xpu + +intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.8.10.post0%2Bxpu-cp312-cp312-linux_x86_64.whl diff --git a/setup.py b/setup.py index 4ea0baa0b2204..eb313b7d219c7 100644 --- a/setup.py +++ b/setup.py @@ -656,8 +656,10 @@ setup( "bench": ["pandas", "datasets"], "tensorizer": ["tensorizer==2.10.1"], "fastsafetensors": ["fastsafetensors >= 0.1.10"], - "runai": - ["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"], + "runai": [ + "runai-model-streamer >= 0.14.0", "runai-model-streamer-gcs", + "google-cloud-storage", "runai-model-streamer-s3", "boto3" + ], "audio": ["librosa", "soundfile", "mistral_common[audio]"], # Required for audio processing "video": [], # Kept for backwards compatibility diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index 90f63e7ea17db..07370a8803291 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copyreg import os import subprocess import sys @@ -10,6 +11,30 @@ from pathlib import Path import pytest import requests +import urllib3.exceptions + + +def _pickle_new_connection_error(obj): + """Custom pickler for NewConnectionError to fix tblib compatibility.""" + # Extract the original message by removing the "conn: " prefix + full_message = obj.args[0] if obj.args else "" + if ': ' in full_message: + # Split off the connection part and keep the actual message + _, actual_message = full_message.split(': ', 1) + else: + actual_message = full_message + return _unpickle_new_connection_error, (actual_message, ) + + +def _unpickle_new_connection_error(message): + """Custom unpickler for NewConnectionError.""" + # Create with None as conn and the actual message + return urllib3.exceptions.NewConnectionError(None, message) + + +# Register the custom pickle/unpickle functions for tblib compatibility +copyreg.pickle(urllib3.exceptions.NewConnectionError, + _pickle_new_connection_error) def _query_server(prompt: str, max_tokens: int = 5) -> dict: @@ -52,6 +77,7 @@ def api_server(distributed_executor_backend: str): uvicorn_process.terminate() +@pytest.mark.timeout(300) @pytest.mark.parametrize("distributed_executor_backend", ["mp", "ray"]) def test_api_server(api_server, distributed_executor_backend: str): """ diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index a3b09cc817917..fba18f197074b 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -62,6 +62,8 @@ def _fix_prompt_embed_outputs( @pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("async_scheduling", [True, False]) +@pytest.mark.parametrize("model_executor", ["uni", "mp"]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models( monkeypatch: pytest.MonkeyPatch, 
@@ -70,6 +72,8 @@ def test_models( backend: str, max_tokens: int, enforce_eager: bool, + async_scheduling: bool, + model_executor: str, enable_prompt_embeds: bool, ) -> None: @@ -77,6 +81,12 @@ def test_models( "VLLM_USE_V1") and envs.VLLM_USE_V1: pytest.skip("enable_prompt_embeds is not supported in v1.") + if not envs.VLLM_USE_V1: + if async_scheduling: + pytest.skip("async_scheduling only supported in v1.") + if model_executor != "uni": + pytest.skip("only test uniproc executor for v0.") + if backend == "XFORMERS" and model == "google/gemma-2-2b-it": pytest.skip( f"{backend} does not support gemma2 with full context length.") @@ -98,11 +108,15 @@ def test_models( prompt_embeds = hf_model.get_prompt_embeddings( example_prompts) - with VllmRunner(model, - max_model_len=8192, - enforce_eager=enforce_eager, - enable_prompt_embeds=enable_prompt_embeds, - gpu_memory_utilization=0.7) as vllm_model: + with VllmRunner( + model, + max_model_len=8192, + enforce_eager=enforce_eager, + enable_prompt_embeds=enable_prompt_embeds, + gpu_memory_utilization=0.7, + async_scheduling=async_scheduling, + distributed_executor_backend=model_executor, + ) as vllm_model: if enable_prompt_embeds: vllm_outputs = vllm_model.generate_greedy( prompt_embeds, max_tokens) diff --git a/tests/benchmarks/test_serve_cli.py b/tests/benchmarks/test_serve_cli.py index bfcf274727e27..5471d6b8e4a5f 100644 --- a/tests/benchmarks/test_serve_cli.py +++ b/tests/benchmarks/test_serve_cli.py @@ -45,3 +45,34 @@ def test_bench_serve(server): print(result.stderr) assert result.returncode == 0, f"Benchmark failed: {result.stderr}" + +@pytest.mark.benchmark +def test_bench_serve_chat(server): + command = [ + "vllm", + "bench", + "serve", + "--model", + MODEL_NAME, + "--host", + server.host, + "--port", + str(server.port), + "--dataset-name", + "random", + "--random-input-len", + "32", + "--random-output-len", + "4", + "--num-prompts", + "5", + "--endpoint", + "/v1/chat/completions", + "--endpoint-type", + "openai-chat", + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/ci_envs.py b/tests/ci_envs.py new file mode 100644 index 0000000000000..d16ecce1ef8dd --- /dev/null +++ b/tests/ci_envs.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +These envs only work for a small part of the tests, fix what you need! +""" + +import os +from typing import TYPE_CHECKING, Any, Callable, Optional + +if TYPE_CHECKING: + VLLM_CI_NO_SKIP: bool = False + VLLM_CI_DTYPE: Optional[str] = None + VLLM_CI_HEAD_DTYPE: Optional[str] = None + VLLM_CI_HF_DTYPE: Optional[str] = None + +environment_variables: dict[str, Callable[[], Any]] = { + # A model family has many models with the same architecture. + # By default, a model family tests only one model. + # Through this flag, all models can be tested. 
+ "VLLM_CI_NO_SKIP": lambda: bool(int(os.getenv("VLLM_CI_NO_SKIP", "0"))), + # Allow changing the dtype used by vllm in tests + "VLLM_CI_DTYPE": lambda: os.getenv("VLLM_CI_DTYPE", None), + # Allow changing the head dtype used by vllm in tests + "VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None), + # Allow changing the head dtype used by transformers in tests + "VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None), +} + + +def __getattr__(name: str): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) + + +def is_set(name: str): + """Check if an environment variable is explicitly set.""" + if name in environment_variables: + return name in os.environ + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 97140a9db7af6..2454f85342eba 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -61,6 +61,16 @@ backend_configs = { "cudagraph_mode": "FULL_AND_PIECEWISE", }, specific_gpu_arch=(9, 0)), + # FlashAttention MLA on Hopper + "FlashAttentionMLA": + BackendConfig(name="FlashAttentionMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + }, + comp_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + specific_gpu_arch=(9, 0)), # Cutlass MLA on Blackwell "CutlassMLA": BackendConfig( @@ -102,7 +112,7 @@ backend_configs = { test_params_full_cudagraph = [] # deepseek-ai/DeepSeek-V2-Lite with MLA -MLA_backends = ["FlashMLA", "CutlassMLA"] +MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"] for mla_backend in MLA_backends: test_params_full_cudagraph.append( pytest.param( diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index aee2acbd490ee..5cfebfce9ea2a 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -4,9 +4,9 @@ Test (piecewise) compilation with a simple model where multiple submodules are compiled and graph captured separately. """ + import torch from torch import nn -from torch.library import Library from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter @@ -15,10 +15,9 @@ from vllm.compilation.decorators import (ignore_torch_compile, from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa +# This import automatically registers `torch.ops.silly.attention` +from .. 
import silly_attention # noqa: F401 BATCH_SIZE = 32 MLP_SIZE = 128 @@ -26,27 +25,6 @@ HIDDEN_SIZE = 1024 RANDOM_SEED = 0 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @support_torch_compile class ParentModel(nn.Module): diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 2d1a72d44ec70..84f4945c82725 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -4,10 +4,10 @@ Test the piecewise compilation with a simple model so that we can exactly calculate the expected output and side effects. """ + import pytest import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile @@ -15,35 +15,9 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.envs import VLLM_USE_V1 from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -global_counter = 0 - -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - - -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - global global_counter - global_counter += 1 - print(f"{global_counter=}") - out.copy_(q) - out[0] += 1 - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) +# This import automatically registers `torch.ops.silly.attention` +from ..silly_attention import get_global_counter, reset_global_counter @support_torch_compile @@ -59,8 +33,7 @@ class SillyModel(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """ Overall effect: - x += 1 - x[0] += 2 + x = 3 * x + 19 global_counter += 2 """ x = x + 1 @@ -78,6 +51,7 @@ class SillyModel(nn.Module): @pytest.mark.parametrize("use_inductor", [True, False]) +@torch.inference_mode() def test_simple_piecewise_compile(use_inductor): assert VLLM_USE_V1 @@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor): model(torch.randn(1).cuda()) input = torch.zeros(2).cuda() - global global_counter - global_counter = 0 + reset_global_counter() with set_forward_context( None, vllm_config=vllm_config, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, batch_descriptor=BatchDescriptor(num_tokens=2, )): output = model(input) - assert global_counter == 2 - assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) + assert get_global_counter() == 2 + assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0])) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index bcfd0d834c5db..cba7517647e51 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -14,38 +14,15 @@ from typing import Any, Optional import pytest import torch from torch import nn -from torch.library 
import Library

 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                          VllmConfig, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import direct_register_custom_op

-# create a library to hold the custom op
-silly_lib = Library("silly", "FRAGMENT")  # noqa
-
-
-def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                    out: torch.Tensor) -> None:
-    out.copy_(q)
-    out += k
-    out += v
-
-
-def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                         out: torch.Tensor) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="attention",
-    op_func=silly_attention,
-    mutates_args=["out"],
-    fake_impl=silly_attention_fake,
-    target_lib=silly_lib,
-)
+# This import automatically registers `torch.ops.silly.attention`
+from .. import silly_attention  # noqa: F401


 @dataclass
diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py
new file mode 100644
index 0000000000000..13eb0bf4b1fa1
--- /dev/null
+++ b/tests/compile/silly_attention.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Shared PyTorch custom silly attention for compilation tests.
+Centralizes custom operation definitions to avoid duplicate registrations.
+"""
+
+import torch
+from torch.library import Library
+
+from vllm.utils import direct_register_custom_op
+
+# Shared library for all compilation test operations.
+# Using "silly" namespace to match existing test expectations.
+# Importing this file automatically registers the
+# torch ops used for testing (like silly.attention).
+silly_lib = Library("silly", "FRAGMENT")
+
+# Global counter that counts the number of times attention is invoked
+_global_counter = 0
+
+
+def get_global_counter():
+    """Get the current global counter value"""
+    return _global_counter
+
+
+def reset_global_counter():
+    """Reset the global counter to 0"""
+    global _global_counter
+    _global_counter = 0
+
+
+def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                    out: torch.Tensor) -> None:
+    """
+    Unified attention implementation that depends on
+    all inputs and affects the output.
+    Always increments a global counter that tests can use or ignore. 
+ """ + global _global_counter + + # Always increment the global counter + _global_counter += 1 + + # Unified implementation that depends on all inputs + out.copy_(q + k + v) + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + """Fake implementation for testing""" + return + + +# Register the unified attention operation +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 422cb94b036ca..fd2b1866e62e1 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -23,7 +23,7 @@ class TestSetting: fullgraph: bool -# we cannot afford testing the full Catesian product +# we cannot afford testing the full Cartesian product # of all models and all levels @pytest.mark.parametrize( "test_setting", @@ -62,8 +62,12 @@ class TestSetting: TestSetting( model="BAAI/bge-multilingual-gemma2", model_args=[ - "--runner", "pooling", "--dtype", "bfloat16", - "--max-model-len", "2048" + "--runner", + "pooling", + "--dtype", + "bfloat16", + "--max-model-len", + "2048", ], pp_size=1, tp_size=1, @@ -71,17 +75,15 @@ class TestSetting: method="encode", fullgraph=True, ), - # TODO: bert models are not supported in V1 yet - # # encoder-based embedding model (BERT) - # TestSetting( - # model="BAAI/bge-base-en-v1.5", - # model_args=["--runner", "pooling"], - # pp_size=1, - # tp_size=1, - # attn_backend="XFORMERS", - # method="encode", - # fullgraph=True, - # ), + TestSetting( + model="BAAI/bge-base-en-v1.5", + model_args=["--runner", "pooling"], + pp_size=1, + tp_size=1, + attn_backend="FLASH_ATTN", + method="encode", + fullgraph=True, + ), # vision language model TestSetting( model="microsoft/Phi-3.5-vision-instruct", @@ -92,7 +94,8 @@ class TestSetting: method="generate_with_image", fullgraph=False, ), - ]) + ], +) def test_compile_correctness( monkeypatch: pytest.MonkeyPatch, test_setting: TestSetting, diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index 51f8ddd566d56..d73586d53ff3e 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import (ignore_torch_compile, @@ -10,36 +9,14 @@ from vllm.compilation.decorators import (ignore_torch_compile, from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa +# This import automatically registers `torch.ops.silly.attention` +from . 
import silly_attention # noqa: F401 BATCH_SIZE = 32 MLP_SIZE = 128 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @torch.inference_mode def run_model(vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode): @@ -151,7 +128,7 @@ def test_ignore_torch_compile_decorator(): run_model(vllm_config, mod_C, cudagraph_runtime_mode) -# Only enable torch.compile if +# Only enable torch.compile if # vllm_config.cache_config.kv_sharing_fast_prefill=True @support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. kv_sharing_fast_prefill) @@ -173,7 +150,7 @@ class B(nn.Module): return x -# Only enable torch.compile if +# Only enable torch.compile if # vllm_config.cache_config.kv_sharing_fast_prefill=False @support_torch_compile(enable_if=lambda vllm_config: not vllm_config. cache_config.kv_sharing_fast_prefill) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index dba668cfa16a6..6baf4bf83f499 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -40,13 +40,12 @@ backend_unfused: Optional[TestBackend] = None @pytest.mark.parametrize( "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]) -@pytest.mark.parametrize( - "use_triton_fa", [True, False] if current_platform.is_rocm() else [False]) +@pytest.mark.parametrize("use_triton_fa", [True, False]) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Only test CUDA and ROCm") -def test_attention_fusion(example_prompts, monkeypatch, model: str, - quant_key: QuantKey, use_triton_fa: bool): +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="V0 attn quant fusion only on ROCm") +def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, + quant_key: QuantKey, use_triton_fa: bool): # Clean Dynamo cache to avoid reusing other test cases # (for some reason the reset at the end is not enough) torch._dynamo.reset() @@ -69,13 +68,17 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, backend="tests.compile.test_fusion_attn.backend_unfused", custom_ops=["+quant_fp8"], ) - vllm_config = VllmConfig(compilation_config=compile_config) + vllm_config = VllmConfig(compilation_config=compile_config, + model_config=ModelConfig( + model=model, + dtype=torch.bfloat16, + )) backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) llm = LLM(model, enforce_eager=True, compilation_config=compile_config, - gpu_memory_utilization=0.9, + gpu_memory_utilization=0.5, max_model_len=2048) sampling_params = SamplingParams(temperature=0.0, @@ -93,7 +96,11 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, backend="tests.compile.test_fusion_attn.backend", custom_ops=["+quant_fp8"], ) - vllm_config = VllmConfig(compilation_config=compile_config) + vllm_config = VllmConfig(compilation_config=compile_config, + model_config=ModelConfig( + model=model, + dtype=torch.bfloat16, + )) # AttnFusionPass needs attention layers to be registered in config upon init # so we initialize it during compilation. 
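A minimal usage sketch of the shared `silly.attention` op introduced above (assumptions: the test tree is importable as the `tests.compile` package, and a CUDA device is available, since vLLM's `direct_register_custom_op` dispatches for CUDA by default; tensor shapes are illustrative):

    import torch

    # Importing the module registers torch.ops.silly.attention as a side effect.
    from tests.compile import silly_attention

    silly_attention.reset_global_counter()

    q = torch.ones(4, 8, device="cuda")
    k, v = torch.ones_like(q), torch.ones_like(q)
    out = torch.empty_like(q)

    # The op mutates `out` in place: out <- q + k + v.
    torch.ops.silly.attention(q, k, v, out)

    assert torch.all(out == 3.0)
    assert silly_attention.get_global_counter() == 1  # one invocation recorded

Because every compile test now shares this single registration, the counter and the `q + k + v` semantics are the same across test_simple, test_multiple_graphs, test_toy_llama, and test_decorator.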
@@ -102,7 +109,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
     llm2 = LLM(model,
                enforce_eager=True,
                compilation_config=compile_config,
-               gpu_memory_utilization=0.9,
+               gpu_memory_utilization=0.5,
                max_model_len=2048)
     # check support
@@ -171,6 +178,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
             cache_config=vllm_config.cache_config,
             prefix="model.layers.0.self_attn.attn",
         )
+        self.attn._k_scale = self.attn._k_scale.to(device)
+        self.attn._v_scale = self.attn._v_scale.to(device)
         self.block_size = 16
@@ -188,7 +197,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
             device=self.device,
         )
-    def build_attn_metadata(self, batch_size: int):
+    def build_attn_metadata(self, batch_size: int, use_hnd: bool):
         """Initialize attention metadata."""
         # Create common attn metadata
@@ -205,10 +214,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
         num_blocks = batch_size * max_blocks
         # Create dummy KV cache for FlashInfer TRTLLM
-        #   - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
-        #   - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
-        # Create kv_cache in HND layout and permute to NHD layout
-        # (later will be permuted back to HND layout in forward pass)
+        #   - NHD: [num_blocks, block_size, num_kv_heads, head_size]
+        #   - HND: [num_blocks, num_kv_heads, block_size, head_size]
         kv_cache = torch.zeros(num_blocks,
                                2,
                                self.num_kv_heads,
@@ -216,7 +223,17 @@ class AttentionQuantPatternModel(torch.nn.Module):
                                self.head_size,
                                dtype=self.kv_cache_dtype,
                                device=self.device)
-        kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
+        if current_platform.is_rocm():
+            # k/v as 1st dimension
+            if use_hnd:
+                kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
+            else:
+                kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
+        else:
+            # k/v as 2nd dimension
+            # Create kv_cache in HND layout and permute to NHD layout
+            # (later will be permuted back to HND layout in forward pass)
+            kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
         self.attn.kv_cache = [kv_cache]
         # Build attn metadata
@@ -296,28 +313,51 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
                                   out_dtype=attn_output.dtype)
-@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
+if current_platform.is_cuda():
+    MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
+               TestAttentionFp8StaticQuantPatternModel),
+              ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
+               TestAttentionNvfp4QuantPatternModel)]
+    HEADS = [(64, 8), (40, 8)]
+elif current_platform.is_rocm():
+    MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV",
+               TestAttentionFp8StaticQuantPatternModel)]
+    HEADS = [(32, 8), (40, 8)]
+else:
+    MODELS = []
+    HEADS = []
+
+
+@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
 @pytest.mark.parametrize("head_size", [128])
-@pytest.mark.parametrize("batch_size", [7, 256, 533])
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("model_name, model_class",
-                         [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
-                           TestAttentionFp8StaticQuantPatternModel),
-                          ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
-                           TestAttentionNvfp4QuantPatternModel)])
-@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
-@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
+@pytest.mark.parametrize("batch_size",
+                         [7, 256, 533] if current_platform.is_cuda() else [8])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("model_name, model_class", MODELS)
+@pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if
+                         current_platform.is_cuda() 
else [_Backend.ROCM_FLASH]) +@pytest.mark.parametrize( + "split_attention", + [False, True] if current_platform.is_rocm() else [False]) +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only test ROCm or CUDA") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), - reason="Only test on SM100(Blackwell)") +@pytest.mark.skipif(current_platform.is_cuda() + and not current_platform.is_device_capability((10, 0)), + reason="On CUDA only test on SM100(Blackwell)") +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only test ROCm or CUDA") def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, head_size: int, batch_size: int, dtype: torch.dtype, model_name: str, model_class: type[AttentionQuantPatternModel], - backend: _Backend, monkeypatch, dist_init): + backend: _Backend, split_attention: bool, + monkeypatch, dist_init): """Test AttentionStaticQuantPattern fusion pass""" monkeypatch.setenv("VLLM_USE_V1", "1") + if split_attention: + monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1") device = torch.device("cuda:0") torch.manual_seed(42) @@ -326,6 +366,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, model_config=ModelConfig( model=model_name, max_model_len=2048, + dtype=dtype, ), scheduler_config=SchedulerConfig(max_num_seqs=1024), compilation_config=CompilationConfig( @@ -368,7 +409,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, forward_ctx = get_forward_context() forward_ctx.attn_metadata = model_unfused.build_attn_metadata( - batch_size) + batch_size, use_hnd=split_attention) # Run model directly without compilation and fusion result_unfused = model_unfused(q, k, v) @@ -389,7 +430,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, model_fused = model_fused.to(device) forward_ctx = get_forward_context() - forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) + forward_ctx.attn_metadata = model_fused.build_attn_metadata( + batch_size, use_hnd=split_attention) # Create test backend with fusion passes enabled noop_pass = NoOpEliminationPass(vllm_config) @@ -404,12 +446,19 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, assert model_compiled.attn._o_scale_float is None result_fused_1 = model_compiled(q, k, v) - # After the 1st round of the forward pass, output quant scale should be - # loaded into the attn layer's _o_scale_float, the 2nd round should - # reuse the loaded _o_scale_float - assert model_compiled.attn._o_scale_float is not None - result_fused_2 = model_compiled(q, k, v) - assert model_compiled.attn._o_scale_float is not None + if backend == _Backend.FLASHINFER: + # With the Flashinfer backend after the 1st round of the forward + # pass, output quant scale should be loaded into the attn layer's + # _o_scale_float, the 2nd round should reuse the loaded + # _o_scale_float + assert model_compiled.attn._o_scale_float is not None + result_fused_2 = model_compiled(q, k, v) + assert model_compiled.attn._o_scale_float is not None + + torch.testing.assert_close(result_unfused, + result_fused_2, + atol=1e-2, + rtol=1e-2) # Check attn fusion support quant_key = model_class.quant_key @@ -444,12 +493,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \ "Attention should have output_block_scale after FP4 fusion" # noqa: 
E501
-    # Check that results are closed
+    # Check that results are close
     torch.testing.assert_close(result_unfused,
                                result_fused_1,
                                atol=1e-2,
                                rtol=1e-2)
-    torch.testing.assert_close(result_unfused,
-                               result_fused_2,
-                               atol=1e-2,
-                               rtol=1e-2)
diff --git a/tests/conftest.py b/tests/conftest.py
index 1052aeb35bac7..0440e859fe02d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# ruff: noqa
+
+from tblib import pickling_support
+
+# Install support for pickling exceptions so that we can nicely propagate
+# failures from tests running in a subprocess.
+# This should be run before any custom exception subclasses are defined.
+pickling_support.install()
+
 import http.server
 import json
 import math
diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index e1a840bb15039..86e08328c43b0 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -10,7 +10,8 @@ import pytest  # noqa
 import torch
 from torch import Use  # noqa
-from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
+from vllm.config import CacheConfig, SchedulerConfig
+from vllm.config.lora import LoRAConfig
 from vllm.core.interfaces import AllocStatus
 from vllm.core.scheduler import Scheduler, SchedulingBudget
 from vllm.lora.request import LoRARequest
diff --git a/tests/detokenizer/test_stop_string_while_stop_model_terminates.py b/tests/detokenizer/test_stop_string_while_stop_model_terminates.py
new file mode 100644
index 0000000000000..9b32a2927f2de
--- /dev/null
+++ b/tests/detokenizer/test_stop_string_while_stop_model_terminates.py
@@ -0,0 +1,103 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.sampling_params import SamplingParams
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.detokenizer import BaseIncrementalDetokenizer
+
+
+@pytest.fixture(params=[True, False])
+def include_stop_str_in_output(request):
+    return request.param
+
+
+class _DummyDetokenizer(BaseIncrementalDetokenizer):
+
+    def __init__(self, request: EngineCoreRequest):
+        super().__init__(request)
+
+    def decode_next(self, next_token_id: int) -> str:
+        # Map token id to single ASCII character for deterministic testing.
+        return chr(next_token_id)
+
+
+def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
+    params = SamplingParams(
+        stop=stop,
+        include_stop_str_in_output=include_stop_str_in_output,
+        min_tokens=min_tokens)
+    # Keep other fields minimal for unit test purposes.
+    req = EngineCoreRequest(
+        request_id="test",
+        prompt_token_ids=[],
+        mm_features=None,
+        sampling_params=params,
+        pooling_params=None,
+        eos_token_id=None,
+        arrival_time=0.0,
+        lora_request=None,
+        cache_salt=None,
+        data_parallel_rank=None,
+    )
+    return req
+
+
+def test_stop_string_while_stop_token_terminates(
+        include_stop_str_in_output: bool):
+    """
+    This test verifies that the detokenizer correctly handles the case where
+    the generated token sequence triggers both:
+    - a stop string
+    - an <eos> stop token
+
+    The detokenizer should respect the stop string and truncate the output
+    accordingly.
+
+    Imagine the following sequence:
+    - "abcdeZ" is generated, where "Z" is the <eos> token.
+    - "cd" is the stop string.
+
+    If include_stop_str_in_output=False, the detokenizer should truncate the
+    output to "ab" because the stop string "cd" is excluded. 
+    If include_stop_str_in_output=True, the detokenizer should include the stop
+    string "cd" in the output, resulting in "abcd".
+
+
+    This verifies the behavioral change introduced in BaseIncrementalDetokenizer
+    where stop-string evaluation occurs before the early-return on
+    stop_terminated.
+    """
+
+    # Generate text "abcdeZ" and tokenize it.
+    generated_text = "abcde"
+    eos_token = "Z"
+    stop_string = "cd"
+    generated_text = generated_text + eos_token
+    token_ids = [ord(c) for c in generated_text]
+
+    # Create a request with the stop string and initialize the detokenizer.
+    req = _make_request(stop=[stop_string],
+                        include_stop_str_in_output=include_stop_str_in_output)
+    detok = _DummyDetokenizer(req)
+
+    # Simulate that the last token ('Z') is a stop token (stop_terminated=True).
+    result = detok.update(new_token_ids=token_ids, stop_terminated=True)
+
+    # The update should report the matched stop string.
+    assert result == stop_string
+
+    # Output text should reflect stop-string handling:
+    # - include_stop_str_in_output=False => exclude "cd" => "ab"
+    # - include_stop_str_in_output=True => include "cd" => "abcd"
+    expected_text = "abcd" if include_stop_str_in_output else "ab"
+    assert detok.output_text == expected_text
+
+    # The skipped final token should still be recorded in token_ids.
+    assert detok.output_token_ids == token_ids
+
+    # get_next_output_text should return the full text when finished=True.
+    # (Buffering only applies during streaming when finished=False.)
+    assert detok.get_next_output_text(finished=True,
+                                      delta=False) == expected_text
diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py
index 666a715cc0da1..7dc4a0cc3d582 100644
--- a/tests/distributed/conftest.py
+++ b/tests/distributed/conftest.py
@@ -8,7 +8,7 @@ import msgspec.msgpack
 import pytest
 import zmq
-from vllm.config import KVEventsConfig
+from vllm.config.kv_events import KVEventsConfig
 from vllm.distributed.kv_events import EventPublisherFactory
 from .test_events import SampleBatch
diff --git a/tests/distributed/test_shm_buffer.py b/tests/distributed/test_shm_buffer.py
new file mode 100644
index 0000000000000..f70028b879609
--- /dev/null
+++ b/tests/distributed/test_shm_buffer.py
@@ -0,0 +1,172 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import traceback
+import unittest
+
+from vllm.distributed.device_communicators.shm_object_storage import (
+    SingleWriterShmRingBuffer)
+
+
+class TestSingleWriterShmRingBuffer(unittest.TestCase):
+    """Test suite for the ring buffer implementation"""
+
+    def setUp(self):
+        """Set up test fixtures"""
+        self.buffer_size = 4096
+        self.ring_buffer = None
+
+    def tearDown(self):
+        """Clean up after tests"""
+        if self.ring_buffer:
+            del self.ring_buffer
+
+    def test_buffer_opening(self):
+        """Test opening an existing buffer"""
+        # First create a buffer
+        self.ring_buffer = SingleWriterShmRingBuffer(
+            data_buffer_size=self.buffer_size, create=True)
+
+        # Then open it with another instance
+        reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())
+        self.assertFalse(reader_buffer.is_writer)
+        self.assertEqual(reader_buffer.shared_memory.name,
+                         self.ring_buffer.shared_memory.name)
+
+    def test_buffer_access(self):
+        """Test accessing allocated buffers"""
+        self.ring_buffer = SingleWriterShmRingBuffer(
+            data_buffer_size=self.buffer_size, create=True)
+
+        size = 100
+        address, monotonic_id = self.ring_buffer.allocate_buf(size)
+
+        # Write some test data
+        test_data = b"Hello, 
World!" * 7 # 91 bytes + with self.ring_buffer.access_buf(address) as (data_buf, metadata): + data_buf[0:len(test_data)] = test_data + + # Read it back + with self.ring_buffer.access_buf(address) as (data_buf2, metadata2): + read_data = bytes(data_buf2[0:len(test_data)]) + read_id = metadata2[0] + + self.assertEqual(read_data, test_data) + self.assertEqual(read_id, monotonic_id) + + def test_memory_error_on_full_buffer(self): + """Test that MemoryError is raised when buffer is full""" + small_buffer_size = 200 + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=small_buffer_size, create=True) + + # Fill up the buffer + self.ring_buffer.allocate_buf(100) + self.ring_buffer.allocate_buf(80) # Total: 196 bytes used + + # This should fail + with self.assertRaises(MemoryError): + self.ring_buffer.allocate_buf(1) # Would exceed buffer capacity + + def test_allocation_and_free(self): + """Test allocation and freeing of buffers""" + small_buffer_size = 200 + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=small_buffer_size, create=True) + + size = 80 + # Write some data + test_data = b"Repeated test data" + for i in range(5): + address, monotonic_id = self.ring_buffer.allocate_buf(size) + with self.ring_buffer.access_buf(address) as (data_buf, metadata): + data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use + data_buf[4:len(test_data) + 4] = test_data + print(self.ring_buffer.metadata) + freed_ids = self.ring_buffer.free_buf(lambda *args: True) + print(f" Freed IDs: {freed_ids}") + self.assertEqual(freed_ids[0], i) + + def test_clear_buffer(self): + """Test clearing the buffer""" + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=self.buffer_size, create=True) + + # Allocate some buffers + for _ in range(3): + self.ring_buffer.allocate_buf(100) + + # Clear the buffer + self.ring_buffer.clear() + + # Check that metadata is empty and IDs reset + self.assertEqual(len(self.ring_buffer.metadata), 0) + self.assertEqual(self.ring_buffer.monotonic_id_start, 0) + self.assertEqual(self.ring_buffer.monotonic_id_end, 0) + self.assertEqual(self.ring_buffer.data_buffer_start, 0) + self.assertEqual(self.ring_buffer.data_buffer_end, 0) + + +def main(): + """Main function demonstrating usage and running tests""" + print("=== SingleWriterShmRingBuffer Test Suite ===\n") + + # Run unit tests + print("Running unit tests...") + unittest.main(argv=[""], exit=False, verbosity=2) + + print("\n" + "=" * 50) + print("=== Manual Demo ===\n") + + # Manual demonstration + try: + print("Creating ring buffer...") + writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, + create=True) + reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle()) + + print(f"Buffer created with name: {writer_buffer.shared_memory.name}") + + # Allocate some buffers + print("\nAllocating buffers...") + address_array = [] + for i in range(3): + size = 100 + i * 50 + try: + writer_buffer.free_buf(lambda *args: True) + address, monotonic_id = writer_buffer.allocate_buf(size) + address_array.append((address, size, monotonic_id)) + + # Write some test data + with writer_buffer.access_buf(address) as (data_buf, metadata): + test_message = f"Test message {i}".encode() + data_buf[0:len(test_message)] = test_message + + except MemoryError as e: + print(f" Failed to allocate {size} bytes: {e}") + + print("\nBuffer state:") + print(f" Data buffer start: {writer_buffer.data_buffer_start}") + print(f" Data buffer end: {writer_buffer.data_buffer_end}") + print(f" Monotonic ID start: 
{writer_buffer.monotonic_id_start}")
+        print(f"  Monotonic ID end: {writer_buffer.monotonic_id_end}")
+        print(f"  Metadata entries: {len(writer_buffer.metadata)}")
+
+        # Try to read back the data
+        print("\nReading back data...")
+        for address, size, monotonic_id in address_array:
+            with reader_buffer.access_buf(address) as (data_buf, metadata):
+                # Find null terminator or read first 50 chars
+                data_bytes = bytes(data_buf[0:size])
+                message = data_bytes.decode()
+                print(f"  ID {monotonic_id}: '{message}'")
+
+    except Exception as e:
+        print(f"Demo error: {e}")
+        traceback.print_exc()
+
+    print("\n=== Demo Complete ===")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/distributed/test_shm_storage.py b/tests/distributed/test_shm_storage.py
new file mode 100644
index 0000000000000..03495222bc1b8
--- /dev/null
+++ b/tests/distributed/test_shm_storage.py
@@ -0,0 +1,327 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import multiprocessing
+import random
+import time
+import traceback
+import unittest
+from multiprocessing import Lock
+
+import torch
+
+# Assuming these are imported from your module
+from vllm.distributed.device_communicators.shm_object_storage import (
+    MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)
+from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
+                                    MultiModalSharedField)
+
+
+def _dummy_elem(modality: str, key: str, size: int):
+    return MultiModalFieldElem(
+        modality=modality,
+        key=key,
+        data=torch.empty((size, ), dtype=torch.int8),
+        field=MultiModalSharedField(1),
+    )
+
+
+def _dummy_item(modality: str, size_by_key: dict[str, int]):
+    return MultiModalKwargsItem.from_elems([
+        _dummy_elem(modality, key, size) for key, size in size_by_key.items()
+    ])
+
+
+class TestSingleWriterShmObjectStorage(unittest.TestCase):
+
+    def setUp(self):
+        """Set up test fixtures before each test method."""
+        ring_buffer = SingleWriterShmRingBuffer(
+            data_buffer_size=1024 * 100,
+            create=True,  # 100 KB buffer
+        )
+        self.storage = SingleWriterShmObjectStorage(
+            max_object_size=1024 * 10,  # 10KB max object
+            n_readers=2,
+            ring_buffer=ring_buffer,
+            serde_class=MsgpackSerde,
+            reader_lock=Lock(),
+        )
+
+    def tearDown(self):
+        """Clean up after each test."""
+        if self.storage:
+            del self.storage
+
+    def test_minimal_put_get_cycle(self):
+        """Test basic put and get operations."""
+        key = "test_key"
+        value = _dummy_item("text", {"field1": 10, "field2": 20})
+
+        # Put operation
+        address, monotonic_id = self.storage.put(key, value)
+
+        # Verify key is in index
+        self.assertIn(key, self.storage.key_index)
+        self.assertEqual(self.storage.key_index[key], (address, monotonic_id))
+        self.assertEqual(self.storage.id_index[monotonic_id], key)
+
+        # Get operation
+        result = self.storage.get(address, monotonic_id)
+
+        # Verify result
+        self.assertEqual(result, value)
+
+    def test_put_same_key_twice(self):
+        """Test behavior when putting the same key multiple times."""
+        key = "duplicate_key"
+        value1 = "first value"
+        value2 = "second value"
+
+        # First put
+        address1, id1 = self.storage.put(key, value1)
+        retrieved1 = self.storage.get(address1, id1)
+        self.assertEqual(retrieved1, value1)
+
+        # should raise an error on second put
+        with self.assertRaises(ValueError) as context:
+            self.storage.put(key, value2)
+
+        self.assertIn("already exists in the storage", str(context.exception))
+
+    def test_large_object_rejection(self):
+        """Test that objects exceeding max_object_size 
are rejected."""
+        # Create an object larger than max_object_size
+        large_data = "x" * (self.storage.max_object_size + 100)
+
+        with self.assertRaises(ValueError) as context:
+            self.storage.put("large_key", large_data)
+
+        self.assertIn("exceeds max object size", str(context.exception))
+
+    def test_buffer_overflow_and_cleanup(self):
+        """Test behavior when buffer fills up and needs cleanup."""
+        # Fill up the buffer with many small objects
+        stored_items = []
+
+        try:
+            for i in range(1000):  # Try to store many items
+                key = f"item_{i}"
+                value = f"data_{i}" * 100  # Make it reasonably sized
+                address, monotonic_id = self.storage.put(key, value)
+                stored_items.append((key, value, address, monotonic_id))
+        except MemoryError:
+            print(f"Buffer filled after {len(stored_items)} items")
+
+        # Verify that some items are still accessible
+        accessible_count = 0
+        for key, original_value, address, monotonic_id in stored_items:
+            for i in range(self.storage.n_readers):
+                retrieved = self.storage.get(address, monotonic_id)
+                if retrieved == original_value:
+                    accessible_count += 1
+
+        self.assertEqual(accessible_count, len(stored_items))
+
+        try:
+            for i in range(len(stored_items), 1000):  # Try to store many items
+                key = f"item_{i}"
+                value = f"data_{i}" * 100  # Make it reasonably sized
+                address, monotonic_id = self.storage.put(key, value)
+                stored_items.append((key, value, address, monotonic_id))
+        except MemoryError:
+            print(f"Buffer filled after {len(stored_items)} items")
+
+        # Verify that some items are still accessible
+        for key, original_value, address, monotonic_id in stored_items:
+            try:
+                for i in range(self.storage.n_readers):
+                    retrieved = self.storage.get(address, monotonic_id)
+                    if retrieved == original_value:
+                        accessible_count += 1
+            except ValueError as e:
+                print(f"Error retrieving {key}: {e}")
+
+        # some items from the first batch may still be accessible
+        self.assertGreaterEqual(accessible_count, len(stored_items))
+
+    def test_blocking_unread_object(self):
+        """Test that an unread object blocks buffer space reclamation."""
+        # Fill up the buffer with many small objects
+        stored_items = []
+
+        try:
+            for i in range(1000):  # Try to store many items
+                key = f"item_{i}"
+                value = f"data_{i}" * 100  # Make it reasonably sized
+                address, monotonic_id = self.storage.put(key, value)
+                stored_items.append((key, value, address, monotonic_id))
+        except MemoryError:
+            print(f"Buffer filled after {len(stored_items)} items")
+
+        # read all items except the first one
+        # to simulate a blocking situation
+        accessible_count = 0
+        for key, original_value, address, monotonic_id in stored_items[1:]:
+            for i in range(self.storage.n_readers):
+                retrieved = self.storage.get(address, monotonic_id)
+                if retrieved == original_value:
+                    accessible_count += 1
+
+        self.assertEqual(accessible_count, len(stored_items) - 1)
+
+        try:
+            key = f"item_{len(stored_items)}"
+            value = f"data_{len(stored_items)}" * 100
+            address, monotonic_id = self.storage.put(key, value)
+        except MemoryError:
+            print(f"Buffer filled after {len(stored_items)} items")
+
+        # read the first item
+        for i in range(self.storage.n_readers):
+            key, original_value, address, monotonic_id = stored_items[0]
+            retrieved = self.storage.get(address, monotonic_id)
+            self.assertEqual(retrieved, original_value)
+
+        try:
+            for i in range(len(stored_items), 1000):  # Try to store many items
+                key = f"item_{i}"
+                value = f"data_{i}" * 100  # Make it reasonably sized
+                address, monotonic_id = self.storage.put(key, value)
+                stored_items.append((key, value, address, 
monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # some items from the first batch may still be accessible + self.assertGreaterEqual(len(stored_items), accessible_count + 10) + + def test_invalid_get_operations(self): + """Test various invalid get operations.""" + # Test with non-existent address + with self.assertRaises(ValueError): # Could be various exceptions + self.storage.get(99999, 1) + + # Store something first + address, monotonic_id = self.storage.put("test", "value") + + # Test with wrong monotonic_id + with self.assertRaises(ValueError) as context: + self.storage.get(address, monotonic_id + 100) + + self.assertIn("has been modified or is invalid", \ + str(context.exception)) + + def test_clear_storage(self): + """Test clearing the storage.""" + # Store some items + for i in range(5): + self.storage.put(f"item_{i}", f"value_{i}") + + # Clear the storage + self.storage.clear() + + # Verify that all indices are empty + self.assertEqual(len(self.storage.key_index), 0) + self.assertEqual(len(self.storage.id_index), 0) + self.assertEqual(len(self.storage.ring_buffer.metadata), 0) + + # Verify that new items can be added after clearing + address, monotonic_id = self.storage.put("new_item", "new_value") + self.assertIn("new_item", self.storage.key_index) + self.assertEqual((address, monotonic_id), (0, 0)) + + +# Reader process function +def reader_process(process_id, storage_handle, items_to_read): + """Reader process that connects to existing shared memory and reads data.""" + reader_storage = SingleWriterShmObjectStorage.create_from_handle( + storage_handle) + + print(f"Reader {process_id} started") + + errors = [] + + for key, original_value, address, monotonic_id in items_to_read: + time.sleep(random.random() / 100) + try: + # Read data from shared memory + retrieved_value = reader_storage.get(address, monotonic_id) + + # Verify data integrity + assert retrieved_value == original_value + print(f"Reader {process_id} retrieved {key}: {retrieved_value}") + except Exception as e: + errors.append((key, str(e), type(e).__name__)) + + +def run_multiprocess_example(): + """Run a minimal working example with real shared memory.""" + print("=== Minimal Object Storage Example ===") + + try: + # Create storage instance + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=1024 * 100, + create=True, # 10 MB buffer + ) + storage = SingleWriterShmObjectStorage( + max_object_size=1024, + n_readers=3, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + reader_lock=Lock(), + ) + + print(f"Created storage (writer: {storage.is_writer})") + + # Test basic data types + test_data = [ + ("user_data", { + "name": "Alice", + "age": 30, + "scores": [95, 87, 92] + }), + ("simple_string", "Hello, World!"), + ("number", 42), + ("list_data", [1, 2, 3, "four", 5.0]), + ] + + stored_items = [] + + # Store all data + for key, value in test_data: + print(f"Storing {key}: {value}") + address, monotonic_id = storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + print(f" -> Stored at address {address}, ID {monotonic_id}") + + print("\n--- Retrieving Data ---") + processes = [] + handle = storage.handle() + # initialize lock for reader processes + handle.reader_lock = Lock() + for i in range(storage.n_readers): + p = multiprocessing.Process(target=reader_process, + args=(i, handle, stored_items)) + processes.append(p) + p.start() + + for p in processes: + p.join(timeout=10) + if p.is_alive(): + p.terminate() + p.join() + + 
except Exception as e: + print(f"Error in minimal example: {e}") + traceback.print_exc() + + +if __name__ == "__main__": + # Run the minimal example first + run_multiprocess_example() + print("\n" + "=" * 50 + "\n") + + # Run the test suite + print("Running comprehensive test suite...") + unittest.main(verbosity=2, exit=False) diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index 8b99d9d6e21fb..3cf4c377fb581 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -63,6 +63,7 @@ def clear_cache(): current_platform.is_cpu(), reason="CPU backend is not currently supported with encoder/decoder models" ) +@pytest.mark.skip(reason="bart not supported in V1") def test_encoder_decoder_e2e( hf_runner, vllm_runner, diff --git a/tests/engine/test_executor.py b/tests/engine/test_executor.py index 15c7a97b50e1f..67064aff3ae92 100644 --- a/tests/engine/test_executor.py +++ b/tests/engine/test_executor.py @@ -25,7 +25,7 @@ class CustomUniExecutor(UniProcExecutor): timeout: Optional[float] = None, args: tuple = (), kwargs: Optional[dict] = None) -> list[Any]: - # Drop marker to show that this was ran + # Drop marker to show that this was run with open(".marker", "w"): ... return super().collective_rpc(method, timeout, args, kwargs) diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index a154bb1059aae..f8ed5dda260ff 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -79,7 +79,7 @@ def test_offline_mode(monkeypatch: pytest.MonkeyPatch): ) # Need to re-import huggingface_hub - # and friends to setup offline mode + # and friends to set up offline mode _re_import_modules() # Cached model files should be used in offline mode for model_config in MODEL_CONFIGS: @@ -136,7 +136,7 @@ def test_model_from_huggingface_offline(monkeypatch: pytest.MonkeyPatch): disable_connect, ) # Need to re-import huggingface_hub - # and friends to setup offline mode + # and friends to set up offline mode _re_import_modules() engine_args = EngineArgs(model="facebook/opt-125m") LLM(**dataclasses.asdict(engine_args)) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index c9947c54a9181..4608850c7dae2 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -12,7 +12,7 @@ import pytest_asyncio import regex as re import requests import torch -from openai import BadRequestError, OpenAI +from openai import BadRequestError from ...utils import RemoteOpenAIServer @@ -968,59 +968,6 @@ async def test_long_seed(client: openai.AsyncOpenAI): or "less_than_equal" in exc_info.value.message) -@pytest.mark.asyncio -async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer): - url = f"http://localhost:{server.port}/v1/chat/completions" - headers = { - "Content-Type": "application/json", - } - data = { - # model_name is avoided here. - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "what is 1+1?" 
- }], - "max_tokens": - 5 - } - - response = requests.post(url, headers=headers, json=data) - response_data = response.json() - print(response_data) - assert response_data.get("model") == MODEL_NAME - choice = response_data.get("choices")[0] - message = choice.get("message") - assert message is not None - content = message.get("content") - assert content is not None - assert len(content) > 0 - - -@pytest.mark.asyncio -async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer): - openai_api_key = "EMPTY" - openai_api_base = f"http://localhost:{server.port}/v1" - - client = OpenAI( - api_key=openai_api_key, - base_url=openai_api_base, - ) - messages = [ - { - "role": "user", - "content": "Hello, vLLM!" - }, - ] - response = client.chat.completions.create( - model="", # empty string - messages=messages, - ) - assert response.model == MODEL_NAME - - @pytest.mark.asyncio async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI): diff --git a/tests/entrypoints/openai/test_encoder_decoder.py b/tests/entrypoints/openai/test_encoder_decoder.py index 9c2aef23e8772..75612962c95f7 100644 --- a/tests/entrypoints/openai/test_encoder_decoder.py +++ b/tests/entrypoints/openai/test_encoder_decoder.py @@ -30,6 +30,7 @@ async def client(server): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="bart is not yet supported in V1") async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): completion = await client.completions.create(model=model_name, prompt="Hello, my name is", diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index 4197583074dfe..bfa3f983cd87e 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -10,7 +10,7 @@ import pytest import regex as re import torch -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.renderer import BaseRenderer from ...utils import RemoteOpenAIServer @@ -27,12 +27,16 @@ async def test_empty_prompt(): with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - with pytest.raises(openai.BadRequestError, - match="decoder prompt cannot be empty"): + with pytest.raises( + openai.BadRequestError, + match= + "Either prompt or prompt_embeds must be provided and non-empty." + ): await client.completions.create(model=model_name, prompt="", max_tokens=5, - temperature=0.0) + temperature=0.0, + extra_body={"prompt_embeds": []}) @pytest.mark.asyncio @@ -83,7 +87,7 @@ def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout, buffer.seek(0) encoded_tensor = pybase64.b64encode(buffer.getvalue()) - loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor) + loaded_prompt_embeds = BaseRenderer.load_prompt_embeds(encoded_tensor) assert len(loaded_prompt_embeds) == 1 loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"] assert loaded_tensor.device.type == "cpu" diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 04805dbca74fa..d219a1f311f15 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -178,7 +178,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, }, { "role": "user", - "content": "What is the weather in Dallas, TX?" + "content": "What is the weather in Dallas, TX with celsius?" 
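+            # NOTE: "with celsius" is presumably here to pin the unit
+            # argument of the resulting tool call, keeping the multi-turn
+            # assertions deterministic.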
}, ] @@ -213,8 +213,12 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, MODEL_NAME = "openai-community/gpt2" +MODEL_NAME_SHORT = "gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" -BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] +BASE_MODEL_PATHS = [ + BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME), + BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT) +] @dataclass @@ -270,6 +274,42 @@ def test_async_serving_chat_init(): assert serving_completion.chat_template == CHAT_TEMPLATE +@pytest.mark.asyncio +async def test_serving_chat_returns_correct_model_name(): + mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + models = OpenAIServingModels(engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=MockModelConfig()) + serving_chat = OpenAIServingChat(mock_engine, + MockModelConfig(), + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None) + messages = [{"role": "user", "content": "what is 1+1?"}] + + async def return_model_name(*args): + return args[3] + + serving_chat.chat_completion_full_generator = return_model_name + + # Test that full name is returned when short name is requested + req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages) + assert await serving_chat.create_chat_completion(req) == MODEL_NAME + + # Test that full name is returned when empty string is specified + req = ChatCompletionRequest(model="", messages=messages) + assert await serving_chat.create_chat_completion(req) == MODEL_NAME + + # Test that full name is returned when no model is specified + req = ChatCompletionRequest(messages=messages) + assert await serving_chat.create_chat_completion(req) == MODEL_NAME + + @pytest.mark.asyncio async def test_serving_chat_should_set_correct_max_tokens(): mock_engine = MagicMock(spec=MQLLMEngineClient) diff --git a/tests/entrypoints/openai/test_skip_tokenizer.py b/tests/entrypoints/openai/test_skip_tokenizer.py index af520ac61d8df..840e0dac81c97 100644 --- a/tests/entrypoints/openai/test_skip_tokenizer.py +++ b/tests/entrypoints/openai/test_skip_tokenizer.py @@ -11,7 +11,7 @@ import torch from ...utils import RemoteOpenAIServer -MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11" +MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" DTYPE = "float16" diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 9d61754059e2f..a324e86666055 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -34,11 +34,11 @@ EXPECTED_MM_BEAM_SEARCH_RES = [ ], [ "The image shows a Venn diagram with three over", - "The image shows a Venn diagram with three intersect", + "This image shows a Venn diagram with three over", ], [ "This image displays a gradient of colors ranging from", - "The image displays a gradient of colors ranging from", + "This image displays a gradient of colors forming a spectrum", ], ] @@ -436,3 +436,197 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, ) message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + 
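+    # indirect=True passes each URL list to the image_urls fixture via
+    # request.param, rather than handing the raw list to the test body.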
indirect=True) +async def test_completions_with_image( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + } + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) +async def test_completions_with_image_with_uuid( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_url + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 + + # Second request, with empty image but the same uuid. + chat_completion_with_empty_image = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": {}, + "uuid": image_url + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion_with_empty_image.choices[ + 0].message.content is not None + assert isinstance( + chat_completion_with_empty_image.choices[0].message.content, str) + assert len( + chat_completion_with_empty_image.choices[0].message.content) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_empty_image_with_uuid_without_cache_hit( + client: openai.AsyncOpenAI, + model_name: str, +): + with pytest.raises(openai.BadRequestError): + _ = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": {}, + "uuid": "uuid_not_previously_seen" + }, + ], + }, + ], + model=model_name, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) +async def test_completions_with_image_with_incorrect_uuid_format( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." 
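+                    # The malformed uuid keys below should simply be ignored;
+                    # the request is still expected to succeed because a real
+                    # image URL is provided.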
+ }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + "incorrect_uuid_key": image_url, + }, + "also_incorrect_uuid_key": image_url, + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 diff --git a/tests/entrypoints/pooling/__init__.py b/tests/entrypoints/pooling/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/pooling/correctness/__init__.py b/tests/entrypoints/pooling/correctness/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/openai/correctness/test_mteb_embed.py b/tests/entrypoints/pooling/correctness/test_mteb_embed.py similarity index 79% rename from tests/entrypoints/openai/correctness/test_mteb_embed.py rename to tests/entrypoints/pooling/correctness/test_mteb_embed.py index 1601c18d9b787..12a4875bdacfd 100644 --- a/tests/entrypoints/openai/correctness/test_mteb_embed.py +++ b/tests/entrypoints/pooling/correctness/test_mteb_embed.py @@ -4,10 +4,9 @@ import os import pytest -from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS, - MTEB_EMBED_TOL, - OpenAIClientMtebEncoder, - run_mteb_embed_task) +from tests.models.language.pooling_mteb_test.mteb_utils import ( + MTEB_EMBED_TASKS, MTEB_EMBED_TOL, OpenAIClientMtebEncoder, + run_mteb_embed_task) from tests.utils import RemoteOpenAIServer os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" diff --git a/tests/entrypoints/openai/correctness/test_mteb_score.py b/tests/entrypoints/pooling/correctness/test_mteb_score.py similarity index 77% rename from tests/entrypoints/openai/correctness/test_mteb_score.py rename to tests/entrypoints/pooling/correctness/test_mteb_score.py index 417f85adc6e06..7c059d16b3863 100644 --- a/tests/entrypoints/openai/correctness/test_mteb_score.py +++ b/tests/entrypoints/pooling/correctness/test_mteb_score.py @@ -4,15 +4,9 @@ import os import pytest -# yapf conflicts with isort for this block -# yapf: disable -from tests.models.language.pooling.mteb_utils import (MTEB_RERANK_LANGS, - MTEB_RERANK_TASKS, - MTEB_RERANK_TOL, - RerankClientMtebEncoder, - ScoreClientMtebEncoder, - run_mteb_rerank) -# yapf: enable +from tests.models.language.pooling_mteb_test.mteb_utils import ( + MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL, + RerankClientMtebEncoder, ScoreClientMtebEncoder, run_mteb_rerank) from tests.utils import RemoteOpenAIServer os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" diff --git a/tests/entrypoints/pooling/llm/__init__.py b/tests/entrypoints/pooling/llm/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/pooling/llm/test_classify.py similarity index 98% rename from tests/entrypoints/llm/test_classify.py rename to tests/entrypoints/pooling/llm/test_classify.py index 6c0c9cd015801..ff5cea11a9182 100644 --- a/tests/entrypoints/llm/test_classify.py +++ b/tests/entrypoints/pooling/llm/test_classify.py @@ -6,11 +6,10 @@ import weakref import pytest import torch +from tests.models.utils import softmax from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory -from ...models.utils import softmax - MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" prompts = ["The chef prepared a delicious 
meal."] diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/pooling/llm/test_embedding.py similarity index 100% rename from tests/entrypoints/llm/test_embedding.py rename to tests/entrypoints/pooling/llm/test_embedding.py diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/pooling/llm/test_encode.py similarity index 100% rename from tests/entrypoints/llm/test_encode.py rename to tests/entrypoints/pooling/llm/test_encode.py diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/pooling/llm/test_reward.py similarity index 97% rename from tests/entrypoints/llm/test_reward.py rename to tests/entrypoints/pooling/llm/test_reward.py index 2cee3c8d94e36..11d164c978a92 100644 --- a/tests/entrypoints/llm/test_reward.py +++ b/tests/entrypoints/pooling/llm/test_reward.py @@ -6,11 +6,10 @@ import weakref import pytest import torch +from tests.models.utils import softmax from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory -from ...models.utils import softmax - MODEL_NAME = "internlm/internlm2-1_8b-reward" prompts = ["The chef prepared a delicious meal."] diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/pooling/llm/test_score.py similarity index 97% rename from tests/entrypoints/llm/test_score.py rename to tests/entrypoints/pooling/llm/test_score.py index f715dacacb8ff..447378f989d09 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/pooling/llm/test_score.py @@ -6,11 +6,10 @@ import weakref import pytest import torch +from tests.models.utils import softmax from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory -from ...models.utils import softmax - MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" diff --git a/tests/entrypoints/pooling/openai/__init__.py b/tests/entrypoints/pooling/openai/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/pooling/openai/test_classification.py similarity index 99% rename from tests/entrypoints/openai/test_classification.py rename to tests/entrypoints/pooling/openai/test_classification.py index 36c96d76c2e5f..26c2c8e6af17d 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/pooling/openai/test_classification.py @@ -6,10 +6,9 @@ import requests import torch import torch.nn.functional as F +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import ClassificationResponse -from ...utils import RemoteOpenAIServer - MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" DTYPE = "float32" # Use float32 to avoid NaN issue diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/pooling/openai/test_embedding.py similarity index 98% rename from tests/entrypoints/openai/test_embedding.py rename to tests/entrypoints/pooling/openai/test_embedding.py index d46ab304ba6d5..37a10e79d4fc7 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/pooling/openai/test_embedding.py @@ -11,14 +11,13 @@ import requests import torch import torch.nn.functional as F +from tests.models.language.pooling.embed_utils import ( + run_embedding_correctness_test) +from tests.models.utils import check_embeddings_close +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.transformers_utils.tokenizer import get_tokenizer -from ...models.language.pooling.embed_utils import ( - 
run_embedding_correctness_test) -from ...models.utils import check_embeddings_close -from ...utils import RemoteOpenAIServer - MODEL_NAME = "intfloat/multilingual-e5-small" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 DTYPE = "bfloat16" diff --git a/tests/entrypoints/openai/test_embedding_dimensions.py b/tests/entrypoints/pooling/openai/test_embedding_dimensions.py similarity index 95% rename from tests/entrypoints/openai/test_embedding_dimensions.py rename to tests/entrypoints/pooling/openai/test_embedding_dimensions.py index 91e91699b92ca..3c7e88daa8ff3 100644 --- a/tests/entrypoints/openai/test_embedding_dimensions.py +++ b/tests/entrypoints/pooling/openai/test_embedding_dimensions.py @@ -9,13 +9,12 @@ from typing import Optional import openai import pytest -from vllm.entrypoints.openai.protocol import EmbeddingResponse - -from ...conftest import HfRunner -from ...models.language.pooling.embed_utils import ( +from tests.conftest import HfRunner +from tests.models.language.pooling.embed_utils import ( run_embedding_correctness_test) -from ...models.utils import EmbedModelInfo -from ...utils import RemoteOpenAIServer +from tests.models.utils import EmbedModelInfo +from tests.utils import RemoteOpenAIServer +from vllm.entrypoints.openai.protocol import EmbeddingResponse MODELS = [ EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), diff --git a/tests/entrypoints/openai/test_embedding_long_text.py b/tests/entrypoints/pooling/openai/test_embedding_long_text.py similarity index 99% rename from tests/entrypoints/openai/test_embedding_long_text.py rename to tests/entrypoints/pooling/openai/test_embedding_long_text.py index 86bd34abb97e0..2d3da238d245e 100644 --- a/tests/entrypoints/openai/test_embedding_long_text.py +++ b/tests/entrypoints/pooling/openai/test_embedding_long_text.py @@ -14,10 +14,9 @@ import openai import pytest import pytest_asyncio +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse -from ...utils import RemoteOpenAIServer - def _generate_random_text(word_count: int) -> str: """Generate random text with approximately the specified word count.""" diff --git a/tests/entrypoints/openai/test_pooling.py b/tests/entrypoints/pooling/openai/test_pooling.py similarity index 99% rename from tests/entrypoints/openai/test_pooling.py rename to tests/entrypoints/pooling/openai/test_pooling.py index 63f4205e0a42b..9f58955cfb40b 100644 --- a/tests/entrypoints/openai/test_pooling.py +++ b/tests/entrypoints/pooling/openai/test_pooling.py @@ -8,11 +8,10 @@ import pytest import requests from tests.models.utils import check_embeddings_close +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import PoolingResponse from vllm.transformers_utils.tokenizer import get_tokenizer -from ...utils import RemoteOpenAIServer - MODEL_NAME = "internlm/internlm2-1_8b-reward" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/pooling/openai/test_rerank.py similarity index 99% rename from tests/entrypoints/openai/test_rerank.py rename to tests/entrypoints/pooling/openai/test_rerank.py index ce4d6c5f5d337..992cb5147ef0d 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/pooling/openai/test_rerank.py @@ -6,10 +6,9 @@ import requests import 
torch import torch.nn.functional as F +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import RerankResponse -from ...utils import RemoteOpenAIServer - MODEL_NAME = "BAAI/bge-reranker-base" DTYPE = "bfloat16" diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/pooling/openai/test_score.py similarity index 99% rename from tests/entrypoints/openai/test_score.py rename to tests/entrypoints/pooling/openai/test_score.py index 4fafcfb45fa22..d676ecccbc87c 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/pooling/openai/test_score.py @@ -8,10 +8,9 @@ import torch import torch.nn.functional as F from torch import tensor +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import ScoreResponse -from ...utils import RemoteOpenAIServer - MODELS = [ { "name": "BAAI/bge-reranker-v2-m3", diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/pooling/openai/test_truncation.py similarity index 100% rename from tests/entrypoints/openai/test_truncation.py rename to tests/entrypoints/pooling/openai/test_truncation.py diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/pooling/openai/test_vision_embedding.py similarity index 98% rename from tests/entrypoints/openai/test_vision_embedding.py rename to tests/entrypoints/pooling/openai/test_vision_embedding.py index dbd403fb7a7b5..48434e36eb265 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/pooling/openai/test_vision_embedding.py @@ -7,11 +7,10 @@ import pytest import requests from transformers import AutoProcessor +from tests.utils import VLLM_PATH, RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.multimodal.utils import encode_image_base64, fetch_image -from ...utils import VLLM_PATH, RemoteOpenAIServer - MODEL_NAME = "TIGER-Lab/VLM2Vec-Full" MAXIMUM_IMAGES = 2 diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py index e4af60a782651..a993e24ff838a 100644 --- a/tests/entrypoints/test_api_server_process_manager.py +++ b/tests/entrypoints/test_api_server_process_manager.py @@ -95,7 +95,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update): assert not proc.is_alive() -@patch("vllm.entrypoints.cli.serve.run_api_server_worker", +@patch("vllm.entrypoints.cli.serve.run_api_server_worker_proc", mock_run_api_server_worker) def test_wait_for_completion_or_failure(api_server_args): """Test that wait_for_completion_or_failure works with failures.""" diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 18db1027c004d..dd33f5c8c1d8e 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -21,7 +21,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, resolve_chat_template_content_format, resolve_hf_chat_template) from vllm.entrypoints.llm import apply_hf_chat_template -from vllm.multimodal import MultiModalDataDict +from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, encode_video_base64) from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -79,6 +79,28 @@ def phi3v_tokenizer(): ) +@pytest.fixture(scope="function") +def qwen2_audio_model_config(): + return ModelConfig( + QWEN2AUDIO_MODEL_ID, + runner="generate", + 
trust_remote_code=True, + limit_mm_per_prompt={ + "audio": 1, + }, + ) + + +@pytest.fixture(scope="module") +def qwen2_audio_tokenizer(): + return TokenizerGroup( + tokenizer_id=QWEN2AUDIO_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + + @pytest.fixture(scope="function") def qwen25omni_model_config_mm_interleaved(): return ModelConfig( @@ -169,6 +191,7 @@ def audio_url(): def _assert_mm_data_is_image_input( mm_data: Optional[MultiModalDataDict], image_count: int, + skipped_image_indices: Optional[list] = None, ) -> None: assert mm_data is not None assert set(mm_data.keys()) == {"image"} @@ -177,6 +200,30 @@ def _assert_mm_data_is_image_input( assert image_data is not None assert isinstance(image_data, list) and len(image_data) == image_count + if skipped_image_indices is not None: + for i in skipped_image_indices: + assert image_data[i] is None + + +def _assert_mm_uuids( + mm_uuids: Optional[MultiModalUUIDDict], + media_count: int, + expected_uuids: list[Optional[str]], + modality: str = "image", +) -> None: + if len(expected_uuids) > 0: + assert mm_uuids is not None + assert modality in mm_uuids + + image_uuids = mm_uuids.get(modality) + assert image_uuids is not None + + assert isinstance(image_uuids, + list) and len(image_uuids) == media_count + + assert image_uuids == expected_uuids + else: + assert mm_uuids is None ModalityType = Literal["image", "video", "audio"] @@ -184,8 +231,10 @@ MultiModalDataCounts = Mapping[ModalityType, int] def _assert_mm_data_inputs( - mm_data: Optional[MultiModalDataDict], - data_count: MultiModalDataCounts, + mm_data: Optional[MultiModalDataDict], + data_count: MultiModalDataCounts, + skipped_media_indices: Optional[dict[ + str, list]] = None, # modality -> list[int] ) -> None: assert mm_data is not None assert set(data_count.keys()) == (set(mm_data.keys())) @@ -195,13 +244,20 @@ def _assert_mm_data_inputs( assert modality_data is not None assert isinstance(modality_data, list) and len(modality_data) == n + if skipped_media_indices is not None: + skipped_media_indices_for_modality = skipped_media_indices.get( + modality) + assert skipped_media_indices_for_modality is not None + for i in skipped_media_indices_for_modality: + assert modality_data[i] is None + def test_parse_chat_messages_single_image( phi3v_model_config, phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", @@ -228,6 +284,470 @@ def test_parse_chat_messages_single_image( "content": "<|image_1|>\nWhat's in the image?" }] _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + +def test_parse_chat_messages_single_image_with_uuid( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" 
+ }] + _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +def test_parse_chat_messages_single_empty_image_with_uuid( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(mm_data, 1, skipped_image_indices=[0]) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +def test_parse_chat_messages_single_image_with_bad_uuid_format( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + "uuid": image_uuid, + }, + "bad_uuid_key": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_images_with_uuids( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid1, + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in the image?", + }] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +def test_parse_chat_messages_multiple_empty_images_with_uuids( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid1, + }, + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in the image?" 
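+                # Both image parts above are None: the parser keeps their
+                # positions (as None entries in mm_data) and relies on the
+                # uuids to identify the content.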
+ }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in the image?", + }] + _assert_mm_data_is_image_input(mm_data, 2, skipped_image_indices=[0, 1]) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +def test_parse_chat_messages_mixed_empty_images_with_uuids( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid1, + }, + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in the image?", + }] + _assert_mm_data_is_image_input(mm_data, 2, skipped_image_indices=[1]) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_single_image_with_uuid_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(await mm_future, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_empty_image_with_uuid_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(await mm_future, + 1, + skipped_image_indices=[0]) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_uuids_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid1, + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in these images?" 
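+                # Part types can be mixed freely: a remote image_url and an
+                # in-memory image_pil each carry their own uuid.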
+ }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?", + }] + _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid1, + }, + { + "type": "image_pil", + "image_pil": None, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?", + }] + _assert_mm_data_is_image_input(await mm_future, + 2, + skipped_image_indices=[0, 1]) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid2 = "my_uuid_2" + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?", + }] + _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, image_uuid2]) def test_parse_chat_messages_empty_system( @@ -235,7 +755,7 @@ def test_parse_chat_messages_empty_system( mistral_tokenizer, ): # Test string format - conversation, _ = parse_chat_messages( + conversation, _, _ = parse_chat_messages( [ { "role": "system", @@ -265,7 +785,7 @@ def test_parse_chat_messages_empty_system( ] # Test openai format - conversation, _ = parse_chat_messages( + conversation, _, _ = parse_chat_messages( [ { "role": "system", @@ -307,7 +827,7 @@ async def test_parse_chat_messages_single_image_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures( + conversation, mm_future, mm_uuids = parse_chat_messages_futures( [{ "role": "user", @@ -334,6 +854,7 @@ async def test_parse_chat_messages_single_image_async( "content": "<|image_1|>\nWhat's in the image?" 
}] _assert_mm_data_is_image_input(await mm_future, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) def test_parse_chat_messages_multiple_images( @@ -341,7 +862,7 @@ def test_parse_chat_messages_multiple_images( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", @@ -374,6 +895,115 @@ def test_parse_chat_messages_multiple_images( "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_empty_pil_image_with_uuid( + phi3v_model_config, + phi3v_tokenizer, +): + uuid = "abcd" + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_pil", + "image_pil": None, + "uuid": uuid + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + }] + _assert_mm_data_is_image_input(mm_data, 1, skipped_image_indices=[0]) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) + + +def test_parse_chat_messages_empty_image_embeds_with_uuid( + phi3v_model_config, + phi3v_tokenizer, +): + uuid = "abcd" + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": None, + "uuid": uuid + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + }] + assert mm_data is not None + assert "image" in mm_data + assert mm_data["image"] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( + phi3v_model_config, + phi3v_tokenizer, +): + uuid = "abcd" + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": None, + "uuid": uuid + }, + { + "type": "text", + "text": "What's in this image?" 
+ }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + }] + mm_data = await mm_future + assert mm_data is not None + assert "image" in mm_data + assert mm_data["image"] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) @pytest.mark.asyncio @@ -382,7 +1012,7 @@ async def test_parse_chat_messages_multiple_images_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures( + conversation, mm_future, mm_uuids = parse_chat_messages_futures( [{ "role": "user", @@ -415,6 +1045,7 @@ async def test_parse_chat_messages_multiple_images_async( "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) def test_parse_chat_messages_placeholder_already_in_prompt( @@ -422,7 +1053,7 @@ def test_parse_chat_messages_placeholder_already_in_prompt( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", @@ -458,6 +1089,7 @@ def test_parse_chat_messages_placeholder_already_in_prompt( "What's in <|image_1|> and how does it compare to <|image_2|>?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) def test_parse_chat_messages_placeholder_one_already_in_prompt( @@ -465,7 +1097,7 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", @@ -503,6 +1135,7 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( "other one?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) def test_parse_chat_messages_multiple_images_across_messages( @@ -510,7 +1143,7 @@ def test_parse_chat_messages_multiple_images_across_messages( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [ { "role": @@ -569,13 +1202,84 @@ def test_parse_chat_messages_multiple_images_across_messages( }, ] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_multiple_images_with_uuids_across_messages( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What about this one?" + }, + ], + }, + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?" + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "<|image_2|>\nWhat about this one?" 
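+        # Placeholder numbering continues across messages: the second user
+        # turn gets <|image_2|>, not a fresh <|image_1|>.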
+ }, + ] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) def test_parse_chat_messages_context_text_format( phi3v_model_config, phi3v_tokenizer, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [ { "role": "user", @@ -621,6 +1325,8 @@ def test_parse_chat_messages_context_text_format( }], }, ] + assert mm_data is None + assert mm_uuids is None def test_parse_chat_messages_rejects_too_many_images_in_one_message( @@ -736,7 +1442,7 @@ def test_parse_chat_messages_multiple_images_uncommon_input( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", @@ -762,6 +1468,7 @@ def test_parse_chat_messages_multiple_images_uncommon_input( "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) def test_parse_chat_messages_multiple_images_interleave( @@ -769,7 +1476,7 @@ def test_parse_chat_messages_multiple_images_interleave( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", @@ -813,6 +1520,7 @@ def test_parse_chat_messages_multiple_images_interleave( "Do they have differences?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @pytest.mark.asyncio @@ -821,7 +1529,7 @@ async def test_parse_chat_messages_multiple_images_interleave_async( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages_futures( + conversation, mm_data, mm_uuids = parse_chat_messages_futures( [{ "role": "user", @@ -865,6 +1573,63 @@ async def test_parse_chat_messages_multiple_images_interleave_async( "Do they have differences?", }] _assert_mm_data_is_image_input(await mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "and this one" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "Do they have differences?" 
+ }, + ], + }], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + }] + _assert_mm_data_is_image_input(await mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) def test_parse_chat_messages_multiple_images_multiple_messages_interleave( @@ -872,7 +1637,7 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [ { "role": @@ -935,6 +1700,81 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( }, ] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave( # noqa: E501 + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "Be accurate." + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + ], + }, + ], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|image_1|>\nBe accurate.", + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "What's on this image?\n<|image_2|>" + }, + ] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( @@ -944,7 +1784,7 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( video_url, audio_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [ { "role": @@ -1030,6 +1870,341 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( ] _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids(mm_uuids, + 2, + modality="image", + expected_uuids=[None, None]) + _assert_mm_uuids(mm_uuids, 1, modality="video", expected_uuids=[None]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave( # noqa: E501 + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" 
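+                    # The same uuid ("image_123") is attached to the image in
+                    # the second user turn below, marking both parts as the
+                    # same underlying asset.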
+ }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": "image_123", + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + }, + "uuid": "audio_123", + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": "image_123", + }, + { + "type": "text", + "text": "And what's in the video?" + }, + { + "type": "video_url", + "video_url": { + "url": video_url + }, + "uuid": "video_123", + }, + ], + }, + ], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] + + _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids(mm_uuids, + 2, + modality="image", + expected_uuids=["image_123", "image_123"]) + _assert_mm_uuids(mm_uuids, + 1, + modality="video", + expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, + 1, + modality="audio", + expected_uuids=["audio_123"]) + + +def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501 + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": None, + "uuid": "image_123", + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": None, + "uuid": "audio_123", + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": None, + "uuid": "image_123", + }, + { + "type": "text", + "text": "And what's in the video?" + }, + { + "type": "video_url", + "video_url": None, + "uuid": "video_123", + }, + ], + }, + ], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 + }, + { + "role": "assistant", + "content": "Some stuff." 
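+        # Every media part in this test was passed as None, yet the rendered
+        # conversation still contains all modality placeholders; the uuids
+        # alone identify the (presumably cached) content.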
+ }, + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] + + _assert_mm_data_inputs(mm_data, { + "image": 2, + "video": 1, + "audio": 1 + }, + skipped_media_indices={ + "image": [0, 1], + "video": [0], + "audio": [0] + }) + _assert_mm_uuids(mm_uuids, + 2, + modality="image", + expected_uuids=["image_123", "image_123"]) + _assert_mm_uuids(mm_uuids, + 1, + modality="video", + expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, + 1, + modality="audio", + expected_uuids=["audio_123"]) + + +def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501 + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": "image_123", + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + } + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "And what's in the video?" + }, + { + "type": "video_url", + "video_url": { + "url": video_url + }, + "uuid": "video_123", + }, + ], + }, + ], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 + }, + { + "role": "assistant", + "content": "Some stuff." 
+ }, + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] + + _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids(mm_uuids, + 2, + modality="image", + expected_uuids=["image_123", None]) + _assert_mm_uuids(mm_uuids, + 1, + modality="video", + expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) def test_parse_chat_messages_multiple_images_interleave_with_placeholders( @@ -1081,7 +2256,7 @@ def test_mllama_single_image( image_url, ): """Ensures that a single image is parsed correctly mllama.""" - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", @@ -1100,6 +2275,7 @@ def test_mllama_single_image( content_format="openai", ) _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) assert conversation == [{ "role": "user", @@ -1121,7 +2297,7 @@ def test_mllama_interleaved_images( image_url, ): """Ensures that multiple image are parsed as interleaved dicts.""" - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", @@ -1147,6 +2323,7 @@ def test_mllama_interleaved_images( content_format="openai", ) _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) assert conversation == [{ "role": "user", @@ -1227,7 +2404,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): # Now parse with vLLMs chat utils & apply the template vllm_conversation = get_conversation(is_hf=False) - conversation, _ = parse_chat_messages( + conversation, _, _ = parse_chat_messages( vllm_conversation, model_config, tokenizer_group, @@ -1518,7 +2695,7 @@ def test_parse_chat_messages_include_thinking_chunk(mistral_model_config, }], }] - conversation_with_thinking, _ = parse_chat_messages( + conversation_with_thinking, _, _ = parse_chat_messages( messages, mistral_model_config, mistral_tokenizer, @@ -1643,3 +2820,82 @@ def test_apply_mistral_chat_template_thinking_chunk(): r"[INST]Thanks, what is 3+3?[/INST]") assert string_tokens == expected_tokens + + +def test_parse_chat_messages_single_empty_audio_with_uuid( + qwen2_audio_model_config, + qwen2_audio_tokenizer, +): + audio_uuid = "abcd" + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "input_audio", + "input_audio": {}, + "uuid": audio_uuid, + }, + { + "type": "text", + "text": "What does the audio say?" + }, + ], + }], + qwen2_audio_model_config, + qwen2_audio_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the audio say?" + }] + _assert_mm_data_inputs(mm_data, {"audio": 1}) + _assert_mm_uuids(mm_uuids, + 1, + modality="audio", + expected_uuids=[audio_uuid]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_single_empty_audio_with_uuid_async( + qwen2_audio_model_config, + qwen2_audio_tokenizer, +): + audio_uuid = "abcd" + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "input_audio", + "input_audio": {}, + "uuid": audio_uuid, + }, + { + "type": "text", + "text": "What does the audio say?" 
+ }, + ], + }], + qwen2_audio_model_config, + qwen2_audio_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the audio say?" + }] + _assert_mm_data_inputs(await mm_future, {"audio": 1}) + _assert_mm_uuids(mm_uuids, + 1, + modality="audio", + expected_uuids=[audio_uuid]) diff --git a/tests/entrypoints/test_renderer.py b/tests/entrypoints/test_renderer.py index 1d80ea6cb4917..1f55b1fba613b 100644 --- a/tests/entrypoints/test_renderer.py +++ b/tests/entrypoints/test_renderer.py @@ -1,13 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import io from dataclasses import dataclass from typing import Optional from unittest.mock import AsyncMock, MagicMock +import pybase64 import pytest +import torch -from vllm.entrypoints.renderer import CompletionRenderer +from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig +from vllm.inputs.data import is_embeds_prompt @dataclass @@ -52,8 +56,8 @@ class TestRenderPrompt: @pytest.mark.asyncio async def test_token_input(self, renderer): tokens = [101, 7592, 2088] - results = await renderer.render_prompt(prompt_or_prompts=tokens, - max_length=100) + results = await renderer.render_prompt( + prompt_or_prompts=tokens, config=RenderConfig(max_length=100)) assert len(results) == 1 assert results[0]["prompt_token_ids"] == tokens @@ -61,8 +65,8 @@ class TestRenderPrompt: @pytest.mark.asyncio async def test_token_list_input(self, renderer): token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] - results = await renderer.render_prompt(prompt_or_prompts=token_lists, - max_length=100) + results = await renderer.render_prompt( + prompt_or_prompts=token_lists, config=RenderConfig(max_length=100)) assert len(results) == 3 assert results[0]["prompt_token_ids"] == [101, 7592, 2088] @@ -76,8 +80,9 @@ class TestRenderPrompt: renderer.async_tokenizer_pool[ renderer.tokenizer] = mock_async_tokenizer - results = await renderer.render_prompt(prompt_or_prompts="Hello world", - max_length=100) + results = await renderer.render_prompt( + prompt_or_prompts="Hello world", + config=RenderConfig(max_length=100)) assert len(results) == 1 assert results[0]["prompt_token_ids"] == [101, 7592, 2088] @@ -92,7 +97,8 @@ class TestRenderPrompt: text_list_input = ["Hello world", "How are you?", "Good morning"] results = await renderer.render_prompt( - prompt_or_prompts=text_list_input, max_length=100) + prompt_or_prompts=text_list_input, + config=RenderConfig(max_length=100)) assert len(results) == 3 for result in results: @@ -106,8 +112,9 @@ class TestRenderPrompt: renderer.async_tokenizer_pool[ renderer.tokenizer] = mock_async_tokenizer - results = await renderer.render_prompt(prompt_or_prompts="Hello world", - max_length=100) + results = await renderer.render_prompt( + prompt_or_prompts="Hello world", + config=RenderConfig(max_length=100)) assert len(results) == 1 call_args = mock_async_tokenizer.call_args @@ -122,8 +129,9 @@ class TestRenderPrompt: renderer.tokenizer] = mock_async_tokenizer results = await renderer.render_prompt(prompt_or_prompts="Hello world", - max_length=100, - truncate_prompt_tokens=50) + config=RenderConfig( + max_length=100, + truncate_prompt_tokens=50)) assert len(results) == 1 call_args = mock_async_tokenizer.call_args @@ -139,8 +147,9 @@ class TestRenderPrompt: renderer.tokenizer] = mock_async_tokenizer results = await 
renderer.render_prompt(prompt_or_prompts="Hello world", - max_length=200, - truncate_prompt_tokens=-1) + config=RenderConfig( + max_length=200, + truncate_prompt_tokens=-1)) assert len(results) == 1 call_args = mock_async_tokenizer.call_args @@ -153,8 +162,9 @@ class TestRenderPrompt: long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens results = await renderer.render_prompt(prompt_or_prompts=long_tokens, - max_length=100, - truncate_prompt_tokens=5) + config=RenderConfig( + max_length=100, + truncate_prompt_tokens=5)) assert len(results) == 1 # Should keep the last 5 tokens: [105, 106, 107, 108, 109] @@ -166,7 +176,7 @@ class TestRenderPrompt: with pytest.raises(ValueError, match="maximum context length"): await renderer.render_prompt(prompt_or_prompts=long_tokens, - max_length=100) + config=RenderConfig(max_length=100)) @pytest.mark.asyncio async def test_no_tokenizer_for_text(self, mock_model_config): @@ -177,4 +187,147 @@ class TestRenderPrompt: with pytest.raises(ValueError, match="No tokenizer available"): await renderer_no_tokenizer.render_prompt( - prompt_or_prompts="Hello world", max_length=100) + prompt_or_prompts="Hello world", + config=RenderConfig(max_length=100)) + + @pytest.mark.asyncio + async def test_token_input_with_needs_detokenization( + self, renderer, mock_async_tokenizer): + # When needs_detokenization=True for token inputs, renderer should + # use the async tokenizer to decode and include the original text + # in the returned prompt object. + mock_async_tokenizer.decode = AsyncMock(return_value="decoded text") + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + tokens = [1, 2, 3, 4] + results = await renderer.render_prompt( + prompt_or_prompts=tokens, + config=RenderConfig(needs_detokenization=True), + ) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == tokens + assert results[0]["prompt"] == "decoded text" + mock_async_tokenizer.decode.assert_awaited_once() + + +class TestRenderEmbedPrompt: + + def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes: + """Helper to create base64-encoded tensor bytes""" + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + return pybase64.b64encode(buffer.read()) + + @pytest.mark.asyncio + async def test_single_prompt_embed(self, renderer): + # Create a test tensor + test_tensor = torch.randn(10, 768, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(cache_salt="test_salt"), + ) + + assert len(results) == 1 + assert is_embeds_prompt(results[0]) + assert torch.allclose(results[0]["prompt_embeds"], test_tensor) + assert results[0]["cache_salt"] == "test_salt" + + @pytest.mark.asyncio + async def test_multiple_prompt_embeds(self, renderer): + # Create multiple test tensors + test_tensors = [ + torch.randn(8, 512, dtype=torch.float32), + torch.randn(12, 512, dtype=torch.float32), + ] + embed_bytes_list = [ + self._create_test_embed_bytes(t) for t in test_tensors + ] + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes_list, + config=RenderConfig(), + ) + + assert len(results) == 2 + for i, result in enumerate(results): + assert is_embeds_prompt(result) + assert torch.allclose(result["prompt_embeds"], test_tensors[i]) + + @pytest.mark.asyncio + async def test_prompt_embed_truncation(self, renderer): + # Create tensor with more tokens than truncation limit + test_tensor = 
torch.randn(20, 768, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(truncate_prompt_tokens=10), + ) + + assert len(results) == 1 + # Should keep last 10 tokens + expected = test_tensor[-10:] + assert torch.allclose(results[0]["prompt_embeds"], expected) + + @pytest.mark.asyncio + async def test_prompt_embed_different_dtypes(self, renderer): + # Test different supported dtypes + dtypes = [torch.float32, torch.float16, torch.bfloat16] + + for dtype in dtypes: + test_tensor = torch.randn(5, 256, dtype=dtype) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(), + ) + + assert len(results) == 1 + assert results[0]["prompt_embeds"].dtype == dtype + + @pytest.mark.asyncio + async def test_prompt_embed_squeeze_batch_dim(self, renderer): + # Test tensor with batch dimension gets squeezed + test_tensor = torch.randn(1, 10, 768, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(), + ) + + assert len(results) == 1 + # Should be squeezed to 2D + assert results[0]["prompt_embeds"].shape == (10, 768) + + @pytest.mark.asyncio + async def test_both_prompts_and_embeds(self, renderer, + mock_async_tokenizer): + # Set up text tokenization + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 102, 103]) + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + # Create embed + test_tensor = torch.randn(5, 256, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_or_prompts="Hello world", + prompt_embeds=embed_bytes, + config=RenderConfig(), + ) + + assert len(results) == 2 + # First should be embed prompt + assert is_embeds_prompt(results[0]) + # Second should be tokens prompt + assert "prompt_token_ids" in results[1] + assert results[1]["prompt_token_ids"] == [101, 102, 103] diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 3c2aaabacae8c..4d969cf992d23 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -22,7 +22,10 @@ def clear_cache(): # Define MLA and non-MLA backends separately DEVICE_MLA_BACKENDS = { - "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"], + "cuda": [ + "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA", + "CUTLASS_MLA" + ], "hip": ["TRITON_MLA", "ROCM_AITER_MLA"], "cpu": [], } @@ -90,8 +93,8 @@ def test_env( with patch("vllm.attention.selector.current_platform", CpuPlatform()): - backend = get_attn_backend(16, torch.float16, torch.float16, - block_size, False) + backend = get_attn_backend(16, torch.float16, None, block_size, + False) assert backend.get_name() == "TORCH_SDPA_VLLM_V1" elif device == "hip": @@ -109,7 +112,7 @@ def test_env( with pytest.raises(ValueError) as exc_info: get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -120,7 +123,7 @@ def test_env( with pytest.raises(ValueError) as exc_info: get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -130,7 +133,7 @@ def test_env( # Valid backend-block_size 
combination backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -139,7 +142,7 @@ def test_env( else: backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -153,6 +156,8 @@ def test_env( # CUDA MLA backend logic: # - CUTLASS_MLA: only supported with block_size == 128 # and Blackwell GPUs (SM 10.0), V1 only + # - FLASHINFER_MLA: only supported on Blackwell GPUs + # (SM 10.0+), V1 only # - FLASHMLA: only supported with block_size == 64 # - FLASH_ATTN_MLA: V1 only # - TRITON_MLA: fallback for other cases @@ -169,12 +174,31 @@ def test_env( else: backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) expected = "CUTLASS_MLA_VLLM_V1" assert backend.get_name() == expected + elif name == "FLASHINFER_MLA": + if not use_v1: + # FlashInfer MLA only supported on V1 engine + pytest.skip( + "FlashInfer MLA only supported on V1 engine") + elif block_size not in [32, 64]: + # FlashInfer MLA only supports block_size 32 or 64 + pytest.skip( + "FlashInfer MLA only supports block_size 32 " + "or 64") + else: + backend = get_attn_backend(16, + torch.float16, + None, + block_size, + False, + use_mla=use_mla) + expected = "FLASHINFER_MLA" + assert backend.get_name() == expected elif name == "FLASHMLA": if block_size != 64: # FlashMLA only supports block_size == 64 @@ -189,7 +213,7 @@ def test_env( else: backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -204,7 +228,7 @@ def test_env( else: backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -214,7 +238,7 @@ def test_env( # TRITON_MLA or other fallback backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -224,7 +248,7 @@ def test_env( elif name == "FLASHINFER": backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -233,7 +257,7 @@ def test_env( else: backend = get_attn_backend(32, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -243,7 +267,7 @@ def test_env( if use_v1: backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -269,15 +293,13 @@ def test_fp32_fallback( with patch("vllm.attention.selector.current_platform", CpuPlatform()): - backend = get_attn_backend(16, torch.float32, torch.float32, - 16, False) + backend = get_attn_backend(16, torch.float32, None, 16, False) assert backend.get_name() == "TORCH_SDPA_VLLM_V1" elif device == "cuda": with patch("vllm.attention.selector.current_platform", CudaPlatform()): - backend = get_attn_backend(16, torch.float32, torch.float32, - 16, False) + backend = get_attn_backend(16, torch.float32, None, 16, False) assert (backend.get_name() == "FLEX_ATTENTION" if use_v1 else "XFORMERS") @@ -331,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): assert backend.get_name() != STR_FLASH_ATTN_VAL # Attention-free models should bypass env and use PlaceholderAttention - backend = get_attn_backend(16, torch.float16, torch.float16, 16, True) + backend = get_attn_backend(16, torch.float16, None, 16, True) assert backend.get_name() != STR_FLASH_ATTN_VAL diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/attention/test_cutlass_mla_decode.py similarity index 83% rename from tests/kernels/test_cutlass_mla_decode.py rename to 
tests/kernels/attention/test_cutlass_mla_decode.py index 85984324b1967..5078bd730a1a3 100644 --- a/tests/kernels/test_cutlass_mla_decode.py +++ b/tests/kernels/attention/test_cutlass_mla_decode.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math import random +from typing import Optional import pytest import torch @@ -14,14 +15,20 @@ from vllm.triton_utils import triton def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, - use_fp8: bool = False) -> None: + use_fp8: bool = False, + diff_threshold: Optional[float] = None) -> None: x, y = x.double(), y.double() cos_diff = 1 - 2 * (x * y).sum().item() / max( (x * x + y * y).sum().item(), 1e-12) - if (use_fp8): - assert cos_diff < 1e-4 + if diff_threshold is not None: + # directly compare the cos_diff with the threshold + assert cos_diff < diff_threshold else: - assert cos_diff < 1e-5 + # use the default threshold + if (use_fp8): + assert cos_diff < 1e-4 + else: + assert cos_diff < 1e-5 CUTLASS_MLA_UNSUPPORTED_REASON = \ @@ -42,7 +49,13 @@ CUTLASS_MLA_UNSUPPORTED_REASON = \ @pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("varlen", [False, True]) -@pytest.mark.parametrize("torch_dtype", [torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize( + "torch_dtype", + [ + torch.bfloat16, + # fp8 can have occasional precision-related failures. + pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2)) + ]) @torch.inference_mode() def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype): @@ -118,11 +131,13 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, dtype=torch.uint8) out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype) - - ops.sm100_cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache_flat, - cache_seqlens, block_table, workspace, - scale, 1) - return out_ans[:, :h_q].contiguous() + output_lse = torch.empty((b, MAX_HEADS), + dtype=torch.float32, + device=q_nope.device) + ops.sm100_cutlass_mla_decode(out_ans, output_lse, q_nope, q_pe, + kv_cache_flat, cache_seqlens, block_table, + workspace, scale, 1) + return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous() def scaled_dot_product_attention(query, key, value, is_causal=False): query = query.float() @@ -165,11 +180,14 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, lse[i] = lse_i return out, lse - out_cutlass = cutlass_mla() + out_cutlass, lse_cutlass = cutlass_mla() out_torch, lse_torch = ref_mla() # Extract the single token (s_q=1) slice to match cutlass output shape out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv] + lse_torch_slice = lse_torch[:, 0, :] # [b, h_q] cal_diff(out_cutlass, out_torch_slice, "out", use_fp8) + # lse has larger numerical error, so use a larger threshold + cal_diff(lse_cutlass, lse_torch_slice, "lse", use_fp8, diff_threshold=1e-3) t = triton.testing.do_bench(cutlass_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 diff --git a/tests/kernels/attention/test_flashinfer_mla_decode.py b/tests/kernels/attention/test_flashinfer_mla_decode.py new file mode 100644 index 0000000000000..02225432f77fc --- /dev/null +++ b/tests/kernels/attention/test_flashinfer_mla_decode.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn.functional as F +from flashinfer.decode import 
trtllm_batch_decode_with_kv_cache_mla +from torch import Tensor + +from vllm.platforms import current_platform + +FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="FlashInfer MLA Requires compute capability of 10 or above.", + allow_module_level=True) + + +def ref_mla( + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) +): + bs, num_heads, v_head_dim = out.shape + head_dim = query.shape[2] + + for i in range(bs): + # gather and flatten KV-cache + kv = kv_cache[ + block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, + head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) + v = kv[:, :, :v_head_dim] + + q = query[i].view(num_heads, 1, head_dim) + o = F.scaled_dot_product_attention(q, + kv, + v, + scale=scale, + enable_gqa=True) + out[i] = o.view(num_heads, v_head_dim) + + return out + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("bs", [1, 2, 4, 16]) +@pytest.mark.parametrize("block_size", [32, 64]) +def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int): + torch.set_default_device('cuda') + torch.manual_seed(42) + + # Deepseek R1 config + num_heads = 128 + kv_lora_rank = 512 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + qk_head_dim = kv_lora_rank + qk_rope_head_dim + scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5 + + MAX_SEQ_LEN = 1024 + + seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) + + # Generate block tables with random but unique block IDs + # From https://github.com/flashinfer-ai/flashinfer/pull/1222 + blocks_per_seq = (seq_lens_tensor + block_size - 1) // block_size + max_num_blocks_per_seq = max(blocks_per_seq.max().item(), 4) + total_blocks_needed = sum(blocks_per_seq) + # Get random unique IDs for all blocks + all_block_ids = torch.randperm(total_blocks_needed) + + block_id = 0 + block_tables = torch.zeros( + (bs, max_num_blocks_per_seq), + dtype=torch.int32, + ) + + # Populate block tables and track block assignments + block_id = 0 + for i in range(bs): + num_blocks_needed = blocks_per_seq[i] + block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id + + num_blocks_needed] + block_id += num_blocks_needed + + kv_cache = torch.randn(block_tables.numel(), block_size, + qk_head_dim).to(dtype) + q = torch.randn(bs, num_heads, qk_head_dim).to(dtype) + + out_ref = q.new_zeros(bs, num_heads, kv_lora_rank) + ref_mla(out_ref, q, kv_cache, scale, block_tables, seq_lens_tensor) + + workspace_buffer = torch.zeros( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=q.device, + ) + # Flashinfer MLA expects the query to be of shape + # (bs, q_len_per_request, num_heads, qk_head_dim), + # where q_len_per_request is the MTP query length (=1 without MTP) + q = q.unsqueeze(1) + + out_ans = trtllm_batch_decode_with_kv_cache_mla( + query=q, + kv_cache=kv_cache.unsqueeze(1), + workspace_buffer=workspace_buffer, + qk_nope_head_dim=qk_nope_head_dim, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + block_tables=block_tables, + seq_lens=seq_lens_tensor, + max_seq_len=max_seq_len, + bmm1_scale=scale, + ) + out_ans = out_ans.squeeze(1) + 
torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 8d0a11d8eb8ab..bd3ba554b32e2 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -35,6 +35,7 @@ QUANT_DTYPES = [ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] @@ -44,6 +45,7 @@ NUM_HEADS = [(64, 8), (40, 8)] HEAD_SIZE = [128] KV_LAYOUT = ["HND"] # currently only HND is supported BLOCK_SIZE = [16] +WINDOW_LEFT = [-1, 127] SOFT_CAP = [None, 50.0] NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. @@ -57,6 +59,7 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. @pytest.mark.parametrize("head_size", HEAD_SIZE) @pytest.mark.parametrize("kv_layout", KV_LAYOUT) @pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("window_left", WINDOW_LEFT) @pytest.mark.parametrize("soft_cap", SOFT_CAP) @torch.inference_mode def test_flashinfer_trtllm_decode_with_baseline( @@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline( head_size: int, kv_layout: str, block_size: int, + window_left: int, soft_cap: Optional[float], ) -> None: torch.set_default_device("cuda") @@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline( sm_scale=sm_scale, q_data_type=dtype, kv_data_type=dtype, + window_left=window_left, logits_soft_cap=soft_cap) output = torch.empty(ref_query.shape, dtype=dtype) @@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline( max_seq_len=max_seq_len, bmm1_scale=q_scale * k_scale * sm_scale, bmm2_scale=v_scale / o_scale, + window_left=window_left, o_sf_scale=o_sf_scale, out=output_trtllm, ) @@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline( @pytest.mark.parametrize("head_size", HEAD_SIZE) @pytest.mark.parametrize("kv_layout", KV_LAYOUT) @pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("window_left", WINDOW_LEFT) @pytest.mark.parametrize("soft_cap", [None]) @torch.inference_mode def test_flashinfer_trtllm_prefill_with_baseline( @@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( head_size: int, kv_layout: str, block_size: int, + window_left: int, soft_cap: Optional[float], ) -> None: torch.set_default_device("cuda") @@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( sm_scale=sm_scale, q_data_type=dtype, kv_data_type=dtype, + window_left=window_left, logits_soft_cap=soft_cap) output = torch.empty(ref_query.shape, dtype=dtype) @@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + window_left=window_left, o_sf_scale=o_sf_scale, out=output_trtllm, ) @@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline( rtol, atol = 4e-1, 1e0 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: rtol, atol = 5e-2, 7e-2 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype: + rtol, atol = 4e-2, 6e-2 else: rtol, atol = 1e-2, 1e-2 diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index 53c37554b15a3..d37b968ed9792 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -23,6 +23,9 @@ 
def clear_cache(): """Clear lru cache to ensure each test case runs without caching. """ _cached_get_attn_backend.cache_clear() + # Clear xformers availability cache + import vllm.attention.layer as layer_module + layer_module.USE_XFORMERS_OPS = None @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) @@ -33,22 +36,52 @@ def test_mha_attn_platform(device: str): torch.set_default_dtype(torch.float16) if device == "cpu": - with patch("vllm.attention.selector.current_platform", CpuPlatform()): + with patch("vllm.attention.layer.current_platform", CpuPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + CpuPlatform()): attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.TORCH_SDPA elif device == "hip": - with patch("vllm.attention.selector.current_platform", RocmPlatform()): + with patch("vllm.attention.layer.current_platform", RocmPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + RocmPlatform()): attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.TORCH_SDPA else: - with patch("vllm.attention.selector.current_platform", CudaPlatform()): + # Test CUDA with head_size=64 (divisible by 32) + # - should use vLLM's FlashAttention + with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.XFORMERS + assert attn.attn_backend == _Backend.FLASH_ATTN - with patch("vllm.attention.selector.current_platform", CudaPlatform()): + # Test CUDA with head_size=72 (not divisible by 32) + # - with upstream FA not available + # - should use xformers + with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()), \ + patch("vllm.attention.layer.check_upstream_fa_availability", + return_value=False): attn = MultiHeadAttention(16, 72, scale=1) assert attn.attn_backend == _Backend.XFORMERS + # Test CUDA with head_size=72 (not divisible by 32) + # - with upstream FA available + # - should use upstream FA + with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()), \ + patch("vllm.attention.layer.check_upstream_fa_availability", + return_value=True), \ + patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (), + { + 'flash_attn_varlen_func': lambda *args, **kwargs: None + })()}): + attn = MultiHeadAttention(16, 72, scale=1) + assert attn.attn_backend == _Backend.FLASH_ATTN + def ref_attention( query: torch.Tensor, diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py index 02316ceaac735..53e6d793cf2f9 100644 --- a/tests/kernels/core/test_layernorm.py +++ b/tests/kernels/core/test_layernorm.py @@ -6,7 +6,7 @@ import torch from tests.kernels.quant_utils import FP8_DTYPE from tests.kernels.utils import opcheck -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -70,6 +70,37 @@ def test_rms_norm( (out, x, layer.weight.data, layer.variance_epsilon)) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) 
+@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_poly_norm( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + layer = PolyNorm().to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + layer.bias.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + x *= scale + + ref_out = layer.forward_native(x) + out = layer(x) + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + opcheck( + torch.ops._C.poly_norm, + (out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon)) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("add_residual", ADD_RESIDUAL) diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index ab6f1ccf881fd..bf9b1d9b4401a 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from itertools import accumulate, product +from itertools import product from typing import Callable, Optional import pytest @@ -111,151 +111,6 @@ def test_rotary_embedding( "expected returned key to be None" -@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("seq_len", SEQ_LENS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_key", USE_KEY) -@torch.inference_mode() -def test_batched_rotary_embedding( - is_neox_style: bool, - tensor_shape_fn: Callable[[int, int, int, int], tuple[int]], - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: Optional[int], - dtype: torch.dtype, - seed: int, - device: str, - use_key: bool, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "rope_type": "linear", - "factor": (1, ) - }) - rope = rope.to(dtype=dtype, device=torch.get_default_device()) - - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) - query = torch.randn(query_shape, dtype=dtype) - key = torch.randn_like(query) if use_key else None - - # slice tensor if required, noop otherwise - query = query[..., :head_size] - key = key[..., :head_size] if use_key else None - - # NOTE(woosuk): The reference implementation should be executed first - # because the custom kernel is in-place. - ref_query, ref_key = rope.forward_native(positions, query, key) - out_query, out_key = rope.forward(positions, - query, - key, - offsets=torch.zeros(batch_size * seq_len, - dtype=torch.long, - device=device)) - # Compare the results. 
- torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) - if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) - else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" - - -@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("seq_len", SEQ_LENS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_key", USE_KEY) -@torch.inference_mode() -def test_batched_rotary_embedding_multi_lora( - is_neox_style: bool, - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: Optional[int], - dtype: torch.dtype, - seed: int, - device: str, - use_key: bool, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - scaling_factors: list[int] = [1, 2, 4] - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "rope_type": "linear", - "factor": tuple(scaling_factors) - }) - rope = rope.to(dtype=dtype, device=torch.get_default_device()) - - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) - key = torch.randn_like(query) if use_key else None - - offset_map = torch.tensor( - list( - accumulate([0] + [ - max_position * scaling_factor * 2 - for scaling_factor in scaling_factors[:-1] - ]))) - query_types = torch.randint(0, - len(scaling_factors), (batch_size, seq_len), - device=device) - query_offsets = offset_map[query_types] - - # NOTE(woosuk): The reference implementation should be executed first - # because the custom kernel is in-place. - ref_query, ref_key = rope.forward_native(positions, query, key, - query_offsets) - out_query, out_key = rope.forward(positions, query, key, - query_offsets.flatten()) - # Compare the results. - torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) - if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) - else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" - - @torch.inference_mode() def test_rope_module_cache(): MAX_POSITIONS = [123, 1234] diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index d1fd960bf115c..5857dd5ba3fad 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -16,20 +16,14 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding def rotary_embedding_opcheck(rot, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None): + key: Optional[torch.Tensor] = None): cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. 
- if offsets is not None: - opcheck(torch.ops._C.batched_rotary_embedding, - (positions, query, key, rot.head_size, cos_sin_cache, - rot.is_neox_style, rot.rotary_dim, offsets)) - else: - opcheck(torch.ops._C.rotary_embedding, - (positions, query, key, rot.head_size, cos_sin_cache, - rot.is_neox_style)) + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. + opcheck(torch.ops._C.rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, + rot.is_neox_style)) @@ -65,10 +59,6 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, key = key[..., :head_size] if use_key else None rotary_embedding_opcheck(rot, positions, query, key) - offsets = torch.zeros(batch_size * seq_len, - device=device, - dtype=torch.long) - rotary_embedding_opcheck(rot, positions, query, key, offsets) # if we have a contiguous head stride, test the alternate # [..., num_heads * head_dim] shape/layout diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 2c554baaff76c..fc60d5ac82b27 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -115,21 +115,27 @@ def generate_continuous_batched_examples(example_lens_by_batch, n_heads, d_head, itype, - device='cuda'): + device='cuda', + return_naive_ref=True): # this function generates a random examples of certain length # and then cut according to "example_lens_by_batch" and feed - # them in continuous batches to the kernels + # them in continuous batches to the kernels. + # If return_naive_ref=True, the naive torch implementation + # ssd_minimal_discrete will be used to compute and return + # reference output. # generate the full-length example A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads, d_head, itype) - Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), - A * dt, - B, - C, - block_len=full_length // 4) + if return_naive_ref: + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), + A * dt, + B, + C, + block_len=full_length // + 4) # internal function that outputs a cont batch of examples # given a tuple of lengths for each example in the batch @@ -179,7 +185,8 @@ def generate_continuous_batched_examples(example_lens_by_batch, IND_S = [x % full_length for x in IND_E] IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] - yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], + yield ([Y_min[s, IND_S[s]:IND_E[s]] + for s in range(num_examples)] if return_naive_ref else None, cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) @@ -324,3 +331,213 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, if clear: states[i].fill_(0.) exhausted[i] = False + + +@pytest.mark.parametrize("chunk_size", [8, 256]) +@pytest.mark.parametrize("seqlens", [ + (16, 2, 8, 13), + (270, 88, 212, 203), + (16, 20), +]) +def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): + + # This test verifies the correctness of the chunked prefill implementation + # in the mamba2 ssd kernels, by comparing concatenation (in the sequence + # dimension) of chunked results with the full sequence result. + # It is different from test_mamba_chunk_scan_cont_batch by: + # 1. Not using the naive torch implementation (ssd_minimal_discrete) to get + # reference outputs. Instead, it compares chunked kernel outputs to full + # sequence kernel outputs.
This is the most straightforward way to + # assert chunked prefill correctness. + # 2. It focuses on cases where sequences change in the middle of mamba + # chunks, and not necessarily on chunk boundaries. + + max_seqlen = max(seqlens) + # This test can have larger error for longer sequences + if max_seqlen > 256: + atol, rtol = 1e-2, 5e-3 + else: + atol, rtol = 5e-3, 5e-3 + + num_sequences = len(seqlens) + n_heads = 16 + d_head = 64 + itype = torch.float32 + + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken: dict = {} # map: eg -> pointer to last taken sample + exhausted: dict = {} # map: eg -> boolean indicating example is exhausted + _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next( + generate_continuous_batched_examples([seqlens], + num_sequences, + max_seqlen, + last_taken, + exhausted, + n_heads, + d_head, + itype, + return_naive_ref=False)) + seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device) + device = X.device + + ## full seqlen computation + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + cu_seqlens, chunk_size, cu_seqlens[-1]) + Y_ref = torch.empty_like(X) + state_ref = mamba_chunk_scan_combined( + X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=None, + out=Y_ref, + ) + + ## chunked seqlen computation + # first chunk + chunked_seqlens = seqlens // 2 + chunked_cu_seqlens = torch.cat([ + torch.tensor([0], device=device), + torch.cumsum(chunked_seqlens, dim=0) + ], + dim=0) + chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(chunked_seqlens), device=device), + chunked_seqlens, + output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32) + chunked_input_seq_len = chunked_cu_seqlens[-1] + X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...] + dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...] + B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...] + C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...] + for i in range(num_sequences): + # fmt: off + chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501 + + X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501 + dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501 + B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501 + C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] 
= chunk_f(C, i) # noqa: E501 + # fmt: on + + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1]) + Y_partial = torch.empty_like(X_chunked) + partial_state = mamba_chunk_scan_combined( + X_chunked, + dt_chunked, + A, + B_chunked, + C_chunked, + chunk_size, + D=None, + cu_seqlens=chunked_cu_seqlens, + seq_idx=chunked_seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=None, + out=Y_partial, + ) + + # remaining chunk + remaining_chunked_seqlens = seqlens - chunked_seqlens + remaining_chunked_cu_seqlens = torch.cat([ + torch.tensor([0], device=device), + torch.cumsum(remaining_chunked_seqlens, dim=0) + ], + dim=0) + remaining_chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(remaining_chunked_seqlens), device=device), + remaining_chunked_seqlens, + output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to( + torch.int32) + remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1] + # fmt: off + remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + for i in range(num_sequences): + remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501 + + remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501 + remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501 + remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501 + remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] 
= remaining_chunk_f(C, i) # noqa: E501 + + # assert input chunking is correct + concat_chunk_f = lambda pt1, pt2, i: torch.cat([ + pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...], + pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...], + ], + dim=1) + concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501 + # fmt: on + + assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X) + assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt) + assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B) + assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C) + + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + remaining_chunked_cu_seqlens, + chunk_size, + remaining_chunked_cu_seqlens[-1]) + + Y_chunked = torch.empty_like(remaining_X_chunked) + state_chunked = mamba_chunk_scan_combined( + remaining_X_chunked, + remaining_dt_chunked, + A, + remaining_B_chunked, + remaining_C_chunked, + chunk_size, + D=None, + cu_seqlens=remaining_chunked_cu_seqlens, + seq_idx=remaining_chunked_seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=partial_state, + out=Y_chunked, + ) + Y = concat_batch_f(Y_partial, Y_chunked) + + # kernel chunked is same as kernel overall + for i in range(num_sequences): + Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] + Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] + torch.testing.assert_close( + Y_seq[:, :chunked_seqlens[i], ...], + Y_ref_seq[:, :chunked_seqlens[i], ...], + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023 + torch.testing.assert_close( + Y_seq[:, chunked_seqlens[i]:, ...], + Y_ref_seq[:, chunked_seqlens[i]:, ...], + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023 + + state_seq = state_chunked[i] + state_seq_ref = state_ref[i] + torch.testing.assert_close( + state_seq, + state_seq_ref, + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} state " + x) # noqa: B023 diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index c29bed3dd6b32..a3b8f07638d9a 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -11,6 +11,7 @@ import torch from packaging import version from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( "quark") is not None and version.parse( @@ -19,11 +20,17 @@ QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda( ) and current_platform.is_device_capability(100) +HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda() + and current_platform.is_device_capability(90) + and has_flashinfer()) + if TRTLLM_GEN_MXFP4_AVAILABLE: from flashinfer import (fp4_quantize, mxfp8_quantize, next_positive_power_of_2, reorder_rows_for_gated_act_gemm, shuffle_matrix_a, shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) + from flashinfer.fp4_quantization import nvfp4_block_scale_interleave + from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices @dataclass @@ -204,6 +211,7 @@ def tg_mxfp4_moe( alpha, beta, limit, + transpose_optimized: bool = False, ) -> torch.Tensor: sf_block_size = 32 assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts @@ -267,22 +275,85 @@ def tg_mxfp4_moe( gemm1_bias_shuffled = [] gemm2_bias_shuffled = [] 
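+ # The per-expert loops below pre-shuffle weights, scales and biases into the + # row-permuted layout the trtllm-gen fused-MoE kernel consumes.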
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals - for i in range(num_experts): - gemm1_weights_shuffled.append( - shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)) - gemm1_scales_shuffled.append( - shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) + _cache_permute_indices: dict[torch.Size, torch.Tensor] = {} + if transpose_optimized: + for i in range(num_experts): + # w13 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w13_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm1_weights_shuffled.append(w13_weight[i].view( + torch.uint8)[permute_indices.to( + w13_weight.device)].contiguous()) + # w13 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm1_scales_shuffled.append( + nvfp4_block_scale_interleave(w13_weight_scale[i].view( + torch.uint8)[permute_sf_indices.to( + w13_weight_scale.device)].contiguous())) + # w13 bias shuffling + permute_bias_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w13_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm1_bias_shuffled.append(w13_bias[i].clone().reshape( + -1, 1)[permute_bias_indices.to(w13_bias.device)].contiguous()) + # w2 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w2_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm2_weights_shuffled.append(w2_weight[i].view( + torch.uint8)[permute_indices.to( + w2_weight.device)].contiguous()) + # w2 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm2_scales_shuffled.append( + nvfp4_block_scale_interleave(w2_weight_scale[i].view( + torch.uint8)[permute_sf_indices.to( + w2_weight_scale.device)].contiguous())) + # w2 bias shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm2_bias_shuffled.append(w2_bias[i].clone().reshape( + -1, 1)[permute_indices.to(w2_bias.device)].contiguous()) - gemm2_weights_shuffled.append( - shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)) - gemm2_scales_shuffled.append( - shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) - gemm1_bias_shuffled.append( - shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)) - gemm2_bias_shuffled.append( - shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)) + else: + for i in range(num_experts): + gemm1_weights_shuffled.append( + shuffle_matrix_a(w13_weight[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_scales_shuffled.append( + shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + + gemm2_weights_shuffled.append( + shuffle_matrix_a(w2_weight[i].view(torch.uint8), + epilogue_tile_m)) + gemm2_scales_shuffled.append( + shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_bias_shuffled.append( + shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)) + gemm2_bias_shuffled.append( + shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)) w13_weight = torch.stack(gemm1_weights_shuffled) w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape( @@ -356,6 +427,7 @@ def check_accuracy(a, b, atol, rtol, 
percent): @pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)]) @pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) +@pytest.mark.parametrize("transpose_optimized", [False, True]) @pytest.mark.skipif( not TRTLLM_GEN_MXFP4_AVAILABLE, reason="nvidia gpu and compute capability sm100 is required for this test") @@ -369,6 +441,7 @@ def test_trtllm_gen_mxfp4_fused_moe( beta: float, limit: Optional[float], act_type: str, + transpose_optimized: bool, ): seed = 42 torch.manual_seed(seed) @@ -470,6 +543,321 @@ def test_trtllm_gen_mxfp4_fused_moe( act_type, alpha=alpha, beta=beta, - limit=limit) + limit=limit, + transpose_optimized=transpose_optimized) # relatively loose check since the mxfp4 quantization is less accurate check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8) + + +def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor: + """Interleave scales on the last dimension by groups of 4, matching + the transformation in mxfp4.py's BF16 (Hopper) path.""" + s = scales.to(torch.uint8) + s_shape = s.shape + assert s_shape[-1] % 4 == 0 + s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4) + # Move the 4-group dimension before the row dimension + permuted = s.permute(0, 2, 1, 3) + # Merge the row dim with the 4-group dim + return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4) + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("num_tokens", [1, 128]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), + (1.702, 1.0, 7.0)]) +@pytest.mark.skipif( + not HOPPER_MXFP4_BF16_AVAILABLE, + reason="nvidia gpu sm90 and flashinfer are required for this test", +) +def test_flashinfer_cutlass_mxfp4_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + alpha: float, + beta: float, + limit: Optional[float], +): + torch.manual_seed(42) + device = "cuda:0" + + # Inputs + hidden_states = torch.randn(num_tokens, + hidden_size, + device=device, + dtype=torch.bfloat16) + # Random MXFP4 weights and scales (uint8), contiguous [w1; w3] + w13_q = torch.randint( + 0, + 256, (num_experts, 2 * intermediate_size, hidden_size // 2), + device=device, + dtype=torch.uint8) + w13_scale = torch.randint( + 118, + 123, (num_experts, 2 * intermediate_size, hidden_size // 32), + device=device, + dtype=torch.uint8) + + w2_q = torch.randint(0, + 256, + (num_experts, hidden_size, intermediate_size // 2), + device=device, + dtype=torch.uint8) + w2_scale = torch.randint( + 118, + 123, (num_experts, hidden_size, intermediate_size // 32), + device=device, + dtype=torch.uint8) + # Bias contiguous [b1; b3] + bias13 = (torch.randn(num_experts, + 2 * intermediate_size, + device=device, + dtype=torch.bfloat16) * 10) + bias2 = (torch.randn( + num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10) + router_logits = torch.rand(num_tokens, + num_experts, + dtype=torch.float32, + device=device) + + w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape( + num_experts, 2 * intermediate_size, hidden_size) + w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape( + num_experts, hidden_size, intermediate_size) + ref = reference_moe(router_logits.to(torch.float32), topk, num_experts, + hidden_states.to(torch.float32), w13_ref, + bias13.to(torch.float32), w2_ref, + bias2.to(torch.float32), alpha, beta, limit, 'bf16') + + from 
vllm.utils.flashinfer import flashinfer_cutlass_fused_moe + + # Swap halves to arrange as [w3; w1] (kernel expectation) + w1_w, w3_w = torch.chunk(w13_q, 2, dim=1) + w13_q_swapped = torch.cat([w3_w, w1_w], dim=1) + + b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1) + w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) + + w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1) + w13_s = torch.cat([w3_s, w1_s], dim=1) + w13_s_inter = _interleave_scales_lastdim_by4(w13_s) + w2_s_inter = _interleave_scales_lastdim_by4(w2_scale) + + routing_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + token_final_scales, token_selected_experts = torch.topk(routing_weights, + topk, + dim=-1) + token_final_scales = (token_final_scales / + token_final_scales.sum(dim=-1, keepdim=True)) + token_selected_experts = token_selected_experts.to(torch.int).contiguous() + + out = torch.empty_like(hidden_states, dtype=torch.bfloat16) + if alpha is not None: + alpha = torch.full((num_experts, ), alpha, device=hidden_states.device) + if beta is not None: + beta = torch.full((num_experts, ), beta, device=hidden_states.device) + if limit is not None: + limit = torch.full((num_experts, ), limit, device=hidden_states.device) + + _ = flashinfer_cutlass_fused_moe( + input=hidden_states, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + fc1_expert_weights=w13_q_swapped, + fc2_expert_weights=w2_q, + output_dtype=torch.bfloat16, + output=out, + quant_scales=[w13_s_inter.to(torch.uint8), + w2_s_inter.to(torch.uint8)], + fc1_expert_biases=w13_b, + fc2_expert_biases=bias2.to(torch.bfloat16), + swiglu_alpha=alpha, + swiglu_beta=beta, + swiglu_limit=limit, + tp_size=1, + tp_rank=0, + ep_size=1, + ep_rank=0, + use_w4_group_scaling=True, + ) + + # Allow some mismatch due to MXFP4 quantization + check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8) + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("num_tokens", [1, 128]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), + (1.702, 1.0, 7.0)]) +@pytest.mark.skipif( + not (current_platform.is_cuda() + and current_platform.is_device_capability(100) and has_flashinfer()), + reason="NVIDIA GPU sm100 and flashinfer are required for this test", +) +def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + alpha: Optional[float], + beta: Optional[float], + limit: Optional[float], +): + torch.manual_seed(42) + device = "cuda:0" + + # Inputs + hidden_states = torch.randn(num_tokens, + hidden_size, + device=device, + dtype=torch.bfloat16) + # Float weights in w13 format [w1; w3] + w13 = (torch.randn(num_experts, + 2 * intermediate_size, + hidden_size, + device=device, + dtype=torch.bfloat16) / 10) + w2 = (torch.randn(num_experts, + hidden_size, + intermediate_size, + device=device, + dtype=torch.bfloat16) / 10) + # Bias contiguous [b1; b3] + bias13 = (torch.randn(num_experts, + 2 * intermediate_size, + device=device, + dtype=torch.bfloat16) * 10) + bias2 = (torch.randn( + num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10) + router_logits = torch.rand(num_tokens, + num_experts, + dtype=torch.float32, + device=device) + + # Quantize weights to MXFP4 per expert (SM100 path) + from flashinfer import mxfp4_quantize + + def quant_mxfp4_batches(a: 
torch.Tensor, e: int): + qs, sfs = [], [] + for i in range(e): + q, sf = mxfp4_quantize(a[i].cuda()) + qs.append(q) + sfs.append(sf) + return torch.stack(qs), torch.stack(sfs) + + def dequant_mxfp4_batches(mat_fp4: torch.Tensor, + scale_tensor: torch.Tensor): + num_batches = mat_fp4.size(0) + scale_tensor = scale_tensor.view(num_batches, -1) + from flashinfer import mxfp4_dequantize + return torch.stack([ + mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :]) + for b in range(num_batches) + ]) + + w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts) + w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts) + + # Reference result using dequantized tensors and reference_moe + w13_ref = dequant_mxfp4_batches( + w13_q.view(torch.uint8), + w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape( + num_experts, 2 * intermediate_size, hidden_size).to(device) + w2_ref = dequant_mxfp4_batches( + w2_q.view(torch.uint8), + w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape( + num_experts, hidden_size, intermediate_size).to(device) + + # Quantize activations for SM100 path and dequantize for reference + hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32) + # Reference uses BF16 input but quantizes intermediate activation to MXFP8 + ref = reference_moe(router_logits.to(torch.float32), topk, num_experts, + hidden_states.to(torch.float32), w13_ref, + bias13.to(torch.float32), w2_ref, + bias2.to(torch.float32), alpha, beta, limit, 'mxfp8') + + # Prepare inputs for FlashInfer CUTLASS fused MoE + from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe + + # Swap halves to arrange as [w3; w1] (kernel expectation) + w1_w, w3_w = torch.chunk(w13_q, 2, dim=1) + w13_q_swapped = torch.cat([w3_w, w1_w], dim=1) + + # Swap scales halves to match swapped weights + s1, s3 = torch.chunk(w13_scale, 2, dim=1) + w13_scale_swapped = torch.cat([s3, s1], dim=1) + + b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1) + w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) + + # Build routing for kernel + routing_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + token_final_scales, token_selected_experts = torch.topk(routing_weights, + topk, + dim=-1) + token_final_scales = (token_final_scales / + token_final_scales.sum(dim=-1, keepdim=True)) + token_selected_experts = token_selected_experts.to(torch.int).contiguous() + + out = torch.empty_like(hidden_states, dtype=torch.bfloat16) + if alpha is not None: + alpha_t = torch.full((num_experts, ), + alpha, + device=hidden_states.device) + else: + alpha_t = None + if beta is not None: + beta_t = torch.full((num_experts, ), beta, device=hidden_states.device) + else: + beta_t = None + if limit is not None: + limit_t = torch.full((num_experts, ), + limit, + device=hidden_states.device) + else: + limit_t = None + + # Quant scales for SM100 MXFP8+MXFP4 path + fake_input_scale = torch.ones(num_experts, device=device) + quant_scales = [ + w13_scale_swapped.view(torch.int32), + fake_input_scale, + w2_scale.view(torch.int32), + fake_input_scale, + ] + + _ = flashinfer_cutlass_fused_moe( + input=hidden_states_q, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long), + fc2_expert_weights=w2_q.contiguous().view(torch.long), + output_dtype=torch.bfloat16, + output=out, + quant_scales=quant_scales, + fc1_expert_biases=w13_b, + fc2_expert_biases=bias2.to(torch.bfloat16), + swiglu_alpha=alpha_t, 
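+        # Note (added for clarity): alpha/beta/limit were expanded above to
+        # per-expert tensors of shape (num_experts,); a parametrized case
+        # that leaves them as None passes None straight through here.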
+ swiglu_beta=beta_t, + swiglu_limit=limit_t, + tp_size=1, + tp_rank=0, + ep_size=1, + ep_rank=0, + use_mxfp8_act_scaling=True, + input_sf=hidden_states_sf, + ) + + # Allow some mismatch due to MXFP4 quantization + check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8) diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 5a0379dfb4475..383b5ebfba9b7 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -5,28 +5,52 @@ import pytest import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm) + silu_mul_fp8_quant_deep_gemm_cuda) from vllm.platforms import current_platform +from vllm.utils import cdiv + +fp8_dtype = torch.float8_e4m3fn -# (E, T, H, group_size, seed) CASES = [ - (1, 1, 128, 64, 0), - (1, 4, 128, 128, 0), - (2, 4, 256, 128, 0), - (32, 64, 256, 128, 0), - (17, 31, 768, 128, 0), + (1, 1, 128, fp8_dtype), + (1, 4, 128, fp8_dtype), + (2, 4, 256, fp8_dtype), + (32, 64, 256, fp8_dtype), + (17, 31, 768, fp8_dtype), + (1, 1, 128 * 1, fp8_dtype), + (1, 1, 128 * 2, fp8_dtype), + (1, 1, 128 * 3, fp8_dtype), + (1, 1, 128 * 4, fp8_dtype), + (8, 16, 128 * 1, fp8_dtype), + (8, 16, 128 * 2, fp8_dtype), + (8, 16, 128 * 3, fp8_dtype), + (8, 16, 128 * 4, fp8_dtype), + (8, 64, 7168, fp8_dtype), + (8, 128, 7168, fp8_dtype), + (8, 256, 7168, fp8_dtype), + (8, 512, 7168, fp8_dtype), + (8, 1024, 7168, fp8_dtype), + (256, 8, 7168, fp8_dtype), + (256, 16, 7168, fp8_dtype), + (256, 32, 7168, fp8_dtype), + (256, 64, 7168, fp8_dtype), + + # Only add a few fnuz tests to help with long CI times. + (8, 512, 7168, torch.float8_e4m3fnuz), + (8, 1024, 7168, torch.float8_e4m3fnuz), ] -@pytest.mark.parametrize("E,T,H,group_size,seed", CASES) +@pytest.mark.parametrize("E,T,H,fp8_type", CASES) @torch.inference_mode() -def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): - current_platform.seed_everything(seed) +def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): + group_size = 128 + current_platform.seed_everything(42) # Input tensor of shape (E, T, 2*H) y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") tokens_per_expert = torch.randint( - low=0, + low=T // 2, high=T, size=(E, ), dtype=torch.int32, @@ -34,45 +58,59 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): ) # Run the Triton kernel - y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, - tokens_per_expert, - group_size=group_size, - eps=1e-10) + y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(y, + tokens_per_expert, + group_size=group_size) - # Reference implementation - fp8_info = torch.finfo(torch.float8_e4m3fn) + torch.cuda.synchronize() + fp8_info = torch.finfo(fp8_dtype) fp8_max = fp8_info.max fp8_min = fp8_info.min eps = 1e-10 - # Compute silu activation and elementwise multiplication - y1 = y[..., :H] + y1 = y[..., :H].float() y2 = y[..., H:] silu_x = y1 * torch.sigmoid(y1) merged = silu_x * y2 - # Compute reference scales and quantized output, skipping padded tokens for e in range(E): nt = tokens_per_expert[e].item() - ref_s = torch.empty((T, H // group_size), + ref_s = torch.empty((T, cdiv(H, group_size)), dtype=torch.float32, device="cuda") - ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda") + ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda") + for t in range(nt): - data = merged[e, t] - data_grp = data.view(H // group_size, group_size) - amax = 
data_grp.abs().amax(dim=1).clamp(min=eps) - scale = amax / fp8_max + data = merged[e, t].float() + ref_q_row = torch.empty_like(data) - scaled = data / scale.repeat_interleave(group_size) - clamped = scaled.clamp(fp8_min, fp8_max) - q = clamped.to(torch.float8_e4m3fn) + # process full groups + n_full_groups = H // group_size + if n_full_groups > 0: + data_grp = data[:n_full_groups * group_size].view( + n_full_groups, group_size) + amax = data_grp.abs().amax(dim=1).clamp(min=eps) + scale = amax / fp8_max + scaled = data[:n_full_groups * + group_size] / scale.repeat_interleave(group_size) + ref_q_row[:n_full_groups * group_size] = scaled.clamp( + fp8_min, fp8_max).to(fp8_dtype) + ref_s[t, :n_full_groups] = scale - ref_s[t] = scale - ref_q[t] = q + # process remainder group + rem = H % group_size + if rem > 0: + data_rem = data[-rem:] + amax = data_rem.abs().amax().clamp(min=eps) + scale = amax / fp8_max + scaled = data_rem / scale + ref_q_row[-rem:] = scaled.clamp(fp8_min, fp8_max).to(fp8_dtype) + ref_s[t, -1] = scale - y_se = y_s[e] - y_qe = y_q[e] + ref_q[t] = ref_q_row + + y_se = y_s[e].float() + y_qe = y_q[e].float() torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) torch.testing.assert_close( diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index d9154d3fd7f33..c440747316b80 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,8 +11,8 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, native_w8a8_block_matmul) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - get_col_major_tma_aligned_tensor, per_token_group_quant_fp8, - w8a8_block_fp8_matmul) + cutlass_scaled_mm, get_col_major_tma_aligned_tensor, + per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8 @@ -98,6 +98,54 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 +@torch.inference_mode() +def test_w8a8_block_fp8_cutlass_matmul(): + # Test simple case where weight.shape % 128 != 0, + # like in DSV3 kv_a_proj_with_mqa + M = 32 + N = 576 + K = 7168 + block_size = [128, 128] + out_dtype = torch.bfloat16 + seed = 0 + + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + # Hopper requires row-major format for scales + Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability( + 90) else Bs + + A_fp8, As = per_token_group_quant_fp8(A_fp32, + block_size[1], + column_major_scales=False) + # CUTLASS uses column-major format for scales + A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8( + A_fp32, block_size[1], column_major_scales=True) + + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) + out = cutlass_scaled_mm(A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, + 
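+                            # scale layouts per the comments above:
+                            # A-scales column-major, B-scales row-major
+                            # on Hopper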
block_size, out_dtype) + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.001 + + @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index c46db8e307936..c9bf85f6e2a5c 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1247,7 +1247,7 @@ def baseline_scaled_mm(a: torch.Tensor, # then we would expand a to: # a = [[1, 1, 2, 2], # [3, 3, 4, 4]] - # NOTE this function this function does not explicitly broadcast dimensions + # NOTE this function does not explicitly broadcast dimensions # with an extent of 1, since this can be done implicitly by pytorch def group_broadcast(t, shape): for i, s in enumerate(shape): diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 891bc75fcdee0..6735b7cd9e436 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -11,21 +11,21 @@ import pytest import torch import torch.nn.functional as F -from vllm.config import LoRAConfig -from vllm.lora.fully_sharded_layers import ( - ColumnParallelLinearWithShardedLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, - RowParallelLinearWithShardedLoRA) +from vllm.config.lora import LoRAConfig # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, LogitsProcessorWithLoRA, LoRAMapping, MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, VocabParallelEmbeddingWithLoRA) # yapf: enable from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights diff --git a/tests/lora/test_lora_allowed_token_ids.py b/tests/lora/test_lora_allowed_token_ids.py index e77eae70445db..be6409000ae77 100644 --- a/tests/lora/test_lora_allowed_token_ids.py +++ b/tests/lora/test_lora_allowed_token_ids.py @@ -3,8 +3,8 @@ import pytest -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - VllmConfig) +from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig +from vllm.config.lora import LoRAConfig from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index c9ab32edc7f32..a5802c108c6be 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -8,7 +8,7 @@ import torch from safetensors.torch import load_file from torch import nn -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.lora.layers import (ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA, RowParallelLinearWithLoRA) diff --git a/tests/lora/test_peft_helper.py b/tests/lora/test_peft_helper.py index df8696cf58e0f..ffffb5d8eab90 100644 --- a/tests/lora/test_peft_helper.py +++ b/tests/lora/test_peft_helper.py @@ -7,7 +7,7 @@ import shutil import pytest -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from 
vllm.lora.peft_helper import PEFTHelper ERROR_CASES = [ diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index a836ff94ba3ed..9c47abf8f4dce 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -6,9 +6,10 @@ import random import tempfile from unittest.mock import patch -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VllmConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.config.load import LoadConfig +from vllm.config.lora import LoRAConfig from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.v1.worker.gpu_worker import Worker diff --git a/tests/model_executor/model_loader/test_registry.py b/tests/model_executor/model_loader/test_registry.py index 93a3e34835b5a..639ee6db9270f 100644 --- a/tests/model_executor/model_loader/test_registry.py +++ b/tests/model_executor/model_loader/test_registry.py @@ -4,7 +4,8 @@ import pytest from torch import nn -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader import (get_model_loader, register_model_loader) from vllm.model_executor.model_loader.base_loader import BaseModelLoader diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 140f00294765d..86139d598582d 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -13,13 +13,15 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func, vllm_topk_softmax) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) -from vllm.model_executor.layers.layernorm import ( - RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, - rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) +from vllm.model_executor.layers.layernorm import (RMSNorm, + dispatch_rocm_rmsnorm_func, + fused_add_rms_norm, rms_norm) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) from vllm.platforms import current_platform +RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] + # Registered subclass for test @CustomOp.register("relu3") @@ -149,24 +151,27 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): @pytest.mark.parametrize("add_residual", [True, False]) +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) @pytest.mark.skipif(not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm") -def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str, - use_rocm_aiter_norm: str, monkeypatch): +def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype, + use_rocm_aiter: str, use_rocm_aiter_norm: str, + monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm) - rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual) + rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype) - if not add_residual: - if current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_norm): - assert rms_norm_func == 
rocm_aiter_rms_norm - else: - assert rms_norm_func == rms_norm - elif current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_norm): - assert rms_norm_func == rocm_aiter_fused_add_rms_norm - else: + should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \ + and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES + + if add_residual and should_use_rocm_aiter: + assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + elif should_use_rocm_aiter: + assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm + elif add_residual: assert rms_norm_func == fused_add_rms_norm + else: + assert rms_norm_func == rms_norm diff --git a/tests/models/language/generation/test_bart.py b/tests/models/language/generation/test_bart.py index b4c771840196c..22ceb27869ac4 100644 --- a/tests/models/language/generation/test_bart.py +++ b/tests/models/language/generation/test_bart.py @@ -178,6 +178,7 @@ def run_test( @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +@pytest.mark.skip(reason="bart not supported in V1") def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: @@ -201,6 +202,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) +@pytest.mark.skip(reason="bart not supported in V1") def test_models_distributed(hf_runner, vllm_runner, example_encoder_decoder_prompts, distributed_executor_backend, model, dtype, diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index 8a04946b2ffb3..a5aa1e3f49743 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os from typing import Optional import pytest @@ -39,7 +38,7 @@ AITER_MODEL_LIST = [ [ pytest.param( "bigscience/bloom-560m", # bloom - testing alibi slopes - marks=[pytest.mark.core_model], + marks=[pytest.mark.core_model, pytest.mark.slow_test], ), pytest.param( "openai-community/gpt2", # gpt2 @@ -50,7 +49,10 @@ AITER_MODEL_LIST = [ pytest.param("EleutherAI/pythia-70m"), # gpt_neox pytest.param( "google/gemma-1.1-2b-it", # gemma - marks=[pytest.mark.core_model, pytest.mark.cpu_model], + marks=[ + pytest.mark.core_model, pytest.mark.cpu_model, + pytest.mark.slow_test + ], ), pytest.param( "zai-org/chatglm3-6b", # chatglm (text-only) @@ -71,14 +73,17 @@ AITER_MODEL_LIST = [ ), pytest.param( "microsoft/phi-2", # phi - marks=[pytest.mark.core_model], + marks=[pytest.mark.core_model, pytest.mark.slow_test], ), pytest.param( "Qwen/Qwen-7B-Chat", # qwen (text-only) ), pytest.param( "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 - marks=[pytest.mark.core_model, pytest.mark.cpu_model], + marks=[ + pytest.mark.core_model, pytest.mark.cpu_model, + pytest.mark.slow_test + ], ), pytest.param( "Qwen/Qwen3-8B", # qwen (text-only) @@ -93,15 +98,16 @@ AITER_MODEL_LIST = [ "allenai/OLMoE-1B-7B-0924-Instruct", marks=[pytest.mark.cpu_model], ), - pytest.param("swiss-ai/Apertus-8B"), # apertus + pytest.param("swiss-ai/Apertus-8B-2509"), # apertus ]) 
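+# use_prompt_embeds is parametrized below instead of being derived from
+# VLLM_USE_V1 (removed in this diff), so both the token-id and the
+# prompt-embedding input paths are exercised on V1.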
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize(
     "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
+@pytest.mark.parametrize("use_prompt_embeds", [True, False])
 def test_models(hf_runner, vllm_runner, example_prompts, model: str,
                 max_tokens: int, num_logprobs: int, use_rocm_aiter: bool,
-                monkeypatch) -> None:
+                use_prompt_embeds: bool, monkeypatch) -> None:
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
     model_info.check_available_online(on_fail="skip")
@@ -119,7 +125,11 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         # in parts of the operators
         pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
 
-    use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
+    # Note: this check can be removed once
+    # https://github.com/vllm-project/vllm/pull/24278 is finished
+    if current_platform.is_cpu() and use_prompt_embeds:
+        pytest.skip("Skipping use_prompt_embeds=True with "
+                    "V1-only CPU backend.")
 
     with hf_runner(model) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index b44ddc61b6c8c..d0e42062099ec 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -301,7 +301,7 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
     finished_requests_ids is larger than the maximum mamba block capacity.
 
     This could generally happen due to the fact that hybrid does support
-    statelessness mechanism where it can cleanup new incoming requests in
+    a statelessness mechanism where it can clean up new incoming requests in
     a single step.
     """
     try:
@@ -322,7 +322,7 @@ def test_state_cleanup(
     This test is for verifying that the Hybrid state is cleaned up between
     steps.
 
-    If its not cleaned, an error would be expected.
+    If it's not cleaned, an error would be expected.
     """
     try:
         with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
diff --git a/tests/models/language/generation_ppl_test/__init__.py b/tests/models/language/generation_ppl_test/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/models/language/generation_ppl_test/ppl_utils.py b/tests/models/language/generation_ppl_test/ppl_utils.py
new file mode 100644
index 0000000000000..6225bbe3377bd
--- /dev/null
+++ b/tests/models/language/generation_ppl_test/ppl_utils.py
@@ -0,0 +1,140 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Adapted from https://huggingface.co/docs/transformers/perplexity
+from typing import Optional, cast
+
+import pytest
+import torch
+from datasets import load_dataset
+
+import tests.ci_envs as ci_envs
+from tests.models.utils import (GenerateModelInfo,
+                                TokensTextLogprobsPromptLogprobs)
+from vllm.logprobs import Logprob
+
+# See #24485
+PPL_TOL = 0.01
+MAX_LENGTH = 1024
+
+
+@torch.inference_mode
+def wikitext_ppl_test(hf_runner,
+                      vllm_runner,
+                      model_info: GenerateModelInfo,
+                      max_length=MAX_LENGTH,
+                      vllm_extra_kwargs=None,
+                      atol=PPL_TOL):
+
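+    # The evaluation below follows the HF perplexity recipe linked in the
+    # header, but with non-overlapping windows: tokenize the corpus once,
+    # split it into windows of up to max_length tokens (stride ==
+    # max_length), score the window tokens via prompt logprobs (the first
+    # token of a window gets none), and report
+    # PPL = exp(sum(NLL) / n_tokens).
+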
+    # A model family has many models with the same architecture,
+    # and we don't need to test each one.
+    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
+        pytest.skip("Skipping test.")
+
+    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+
+    # Allow vllm to test using the given dtype, such as float32
+    vllm_extra_kwargs = vllm_extra_kwargs or {}
+    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
+
+    # Allow vllm to test using hf_overrides
+    if model_info.hf_overrides is not None:
+        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
+
+    # Allow changing the head dtype used by vllm in tests
+    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
+        if "hf_overrides" not in vllm_extra_kwargs:
+            vllm_extra_kwargs["hf_overrides"] = {}
+        vllm_extra_kwargs["hf_overrides"][
+            "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
+
+    with vllm_runner(model_info.name,
+                     gpu_memory_utilization=0.7,
+                     max_model_len=max_length,
+                     max_num_seqs=1,
+                     enforce_eager=True,
+                     **vllm_extra_kwargs) as vllm_model:
+        # Use max_num_seqs=1 to avoid OOM and to avoid batching
+        # different requests together.
+
+        model_config = vllm_model.llm.llm_engine.model_config
+
+        # Confirm whether vllm is using the correct architecture
+        if model_info.architecture:
+            assert (model_info.architecture in model_config.architectures)
+
+        max_length = min(model_config.max_model_len - 1, max_length)
+        stride = max_length
+
+        tokenizer = vllm_model.llm.get_tokenizer()
+        tokens = tokenizer.encode("\n\n".join(dataset["text"]))
+        n_tokens = len(tokens)
+
+        chunks = []
+        for begin_loc in range(0, n_tokens, stride):
+            end_loc = min(begin_loc + max_length, n_tokens)
+            chunks.append(tokens[begin_loc:end_loc])
+
+        outputs = vllm_model.generate_greedy_logprobs(prompts=chunks,
+                                                      max_tokens=1,
+                                                      num_logprobs=None,
+                                                      num_prompt_logprobs=0,
+                                                      use_tqdm=False)
+        nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
+        n_tokens = 0
+        for output in outputs:
+            output = cast(TokensTextLogprobsPromptLogprobs, output)
+            token_datas = cast(list[Optional[dict[int, Logprob]]], output[3])
+
+            assert token_datas[0] is None
+            token_log_probs = []
+            for token_data in token_datas[1:]:
+                assert token_data is not None
+                assert len(token_data) == 1
+                token_log_prob = list(token_data.values())[0].logprob
+                token_log_probs.append(token_log_prob)
+
+            neg_log_likelihood = -torch.tensor(
+                token_log_probs, dtype=torch.float32, device="cpu").sum()
+            nll_sum += neg_log_likelihood
+            n_tokens += len(token_log_probs)
+        vllm_ppl = float(torch.exp(nll_sum / n_tokens))
+        vllm_dtype = model_config.dtype
+        head_dtype = model_config.head_dtype
+
+    # Speed up the PPL test: when model_info.hf_ppl pins the Transformers
+    # PPL score to a constant, the reference run is skipped entirely.
+    if model_info.hf_ppl is None:
+        with hf_runner(
+                model_info.name,
+                dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype,
+        ) as hf_model:
+            nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
+            n_tokens = 0
+            for chunk in chunks:
+                inputs = hf_model.wrap_device(
+                    {"input_ids": torch.tensor([chunk])})
+                input_ids = inputs["input_ids"]
+                outputs = hf_model.model(input_ids, labels=input_ids)
+                neg_log_likelihood = outputs.loss
+
+                neg_log_likelihood = neg_log_likelihood.to(torch.float32).cpu()
+
+                num_loss_tokens = len(chunk) - 1
+                nll_sum += neg_log_likelihood * num_loss_tokens
+                n_tokens += num_loss_tokens
+
+            hf_ppl = float(torch.exp(nll_sum / n_tokens))
+            hf_dtype = next(hf_model.model.parameters()).dtype
+    else:
+        hf_ppl = model_info.hf_ppl
+        hf_dtype = "Constant"
+
+    differ = (vllm_ppl - hf_ppl) / hf_ppl
+    print("Model:", model_info.name)
+    print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}",
+          vllm_ppl)
+    print("Transformers:", hf_dtype, hf_ppl)
+    print("Difference (%):", differ * 100)
+
+    # The smaller the PPL, the better. We are not concerned if the vLLM
+    # PPL comes out lower than the Transformers reference, so we only
+    # perform a one-sided test.
+    assert differ < atol
diff --git a/tests/models/language/generation_ppl_test/test_gemma.py b/tests/models/language/generation_ppl_test/test_gemma.py
new file mode 100644
index 0000000000000..5324de143d674
--- /dev/null
+++ b/tests/models/language/generation_ppl_test/test_gemma.py
@@ -0,0 +1,18 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+from tests.models.utils import GenerateModelInfo
+
+from .ppl_utils import wikitext_ppl_test
+
+MODELS = [
+    GenerateModelInfo("google/gemma-2b"),
+    GenerateModelInfo("google/gemma-2-2b"),
+    GenerateModelInfo("google/gemma-3-4b-it"),
+]
+
+
+@pytest.mark.parametrize("model_info", MODELS)
+def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo):
+    wikitext_ppl_test(hf_runner, vllm_runner, model_info)
diff --git a/tests/models/language/generation_ppl_test/test_gpt.py b/tests/models/language/generation_ppl_test/test_gpt.py
new file mode 100644
index 0000000000000..f3f9e55a24234
--- /dev/null
+++ b/tests/models/language/generation_ppl_test/test_gpt.py
@@ -0,0 +1,14 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+from tests.models.utils import GenerateModelInfo
+
+from .ppl_utils import wikitext_ppl_test
+
+MODELS = [GenerateModelInfo("openai-community/gpt2-large")]
+
+
+@pytest.mark.parametrize("model_info", MODELS)
+def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo):
+    wikitext_ppl_test(hf_runner, vllm_runner, model_info)
diff --git a/tests/models/language/generation_ppl_test/test_qwen.py b/tests/models/language/generation_ppl_test/test_qwen.py
new file mode 100644
index 0000000000000..0d3127cbaac47
--- /dev/null
+++ b/tests/models/language/generation_ppl_test/test_qwen.py
@@ -0,0 +1,21 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from tests.models.utils import GenerateModelInfo
+
+from .ppl_utils import wikitext_ppl_test
+
+MODELS = [
+    GenerateModelInfo("Qwen/Qwen3-0.6B"),
+    GenerateModelInfo("Qwen/Qwen3-0.6B-FP8"),
+    # transformers:
+    # Loading a GPTQ quantized model requires optimum, gptqmodel
+    # GenerateModelInfo("Qwen/Qwen3-0.6B-GPTQ-Int8"),
+]
+
+
+@pytest.mark.parametrize("model_info", MODELS)
+def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo):
+    wikitext_ppl_test(hf_runner, vllm_runner, model_info)
diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py
index 8f8393c4e16fc..86751e0a4d5f4 100644
--- a/tests/models/language/pooling/embed_utils.py
+++ b/tests/models/language/pooling/embed_utils.py
@@ -59,7 +59,7 @@ def correctness_test_embed_models(hf_runner,
 
     with hf_runner(
             model_info.name,
-            dtype="float32",
+            dtype=model_info.hf_dtype,
             is_sentence_transformer=True,
     ) as hf_model:
diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py
index c71fa96275335..8e398830d39df 100644
--- a/tests/models/language/pooling/test_classification.py
+++ b/tests/models/language/pooling/test_classification.py
@@ -11,7 +11,10 @@ from vllm.platforms import current_platform
     "model",
     [
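+        # pytest.mark.slow_test (added to several models in this diff) is
+        # presumably a marker CI uses to route heavier models to a less
+        # frequent lane.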
pytest.param("jason9693/Qwen2.5-1.5B-apeach", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + marks=[ + pytest.mark.core_model, pytest.mark.cpu_model, + pytest.mark.slow_test + ]), ], ) @pytest.mark.parametrize("dtype", diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 0733ac85c11fc..d61ac08475e3c 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -19,7 +19,7 @@ from ...utils import check_embeddings_close # model code with bidirectional attention. # [Decoder-only] pytest.param("BAAI/bge-multilingual-gemma2", - marks=[pytest.mark.core_model]), + marks=[pytest.mark.core_model, pytest.mark.slow_test]), pytest.param( "intfloat/e5-mistral-7b-instruct", # CPU v1 doesn't support sliding window @@ -27,11 +27,20 @@ from ...utils import check_embeddings_close pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.cpu_model]), # [Encoder-only] - pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]), + pytest.param( + "BAAI/bge-base-en-v1.5", + marks=[ + pytest.mark.core_model, pytest.mark.cpu_model, + pytest.mark.slow_test + ], + ), pytest.param("sentence-transformers/all-MiniLM-L12-v2"), pytest.param("intfloat/multilingual-e5-small"), # [Cross-Encoder] - pytest.param("sentence-transformers/stsb-roberta-base-v2"), + pytest.param( + "sentence-transformers/stsb-roberta-base-v2", + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), ], ) def test_models( diff --git a/tests/models/language/pooling/test_mm_classifier_conversion.py b/tests/models/language/pooling/test_mm_classifier_conversion.py new file mode 100644 index 0000000000000..166b953de43e7 --- /dev/null +++ b/tests/models/language/pooling/test_mm_classifier_conversion.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.platforms import current_platform + + +def test_idefics_multimodal( + vllm_runner, + monkeypatch, +) -> None: + if current_platform.is_rocm(): + # ROCm Triton FA does not currently support sliding window attention + # switch to use ROCm CK FA backend + monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + with vllm_runner(model_name="HuggingFaceM4/Idefics3-8B-Llama3", + runner="pooling", + task="classify", + convert="classify", + load_format="dummy", + max_model_len=512, + enforce_eager=True, + tensor_parallel_size=1, + disable_log_stats=True, + dtype="bfloat16") as vllm_model: + llm = vllm_model.get_llm() + outputs = llm.classify(prompts) + for output in outputs: + assert len(output.outputs.probs) == 2 + + +def update_config(config): + config.text_config.update({ + "architectures": ["Gemma3ForSequenceClassification"], + "classifier_from_token": ["A", "B", "C", "D", "E"], + "method": + "no_post_processing", + "id2label": { + "A": "Chair", + "B": "Couch", + "C": "Table", + "D": "Bed", + "E": "Cupboard" + }, + }) + return config + + +def test_gemma_multimodal( + vllm_runner, + monkeypatch, +) -> None: + if current_platform.is_rocm(): + # ROCm Triton FA does not currently support sliding window attention + # switch to use ROCm CK FA backend + monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") + + messages = [{ + "role": + "system", + "content": + """ + You are a helpful assistant. 
You will be given a product description + which may also include an image. Classify the following product into + one of the categories: + + A = chair + B = couch + C = table + D = bed + E = cupboard + + You'll answer with exactly one letter (A, B, C, D, or E).""" + }, { + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": + "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg" + } + }, { + "type": "text", + "text": "A fine 19th century piece of furniture." + }] + }] + + with vllm_runner(model_name="google/gemma-3-4b-it", + runner="pooling", + task="classify", + convert="classify", + load_format="auto", + hf_overrides=update_config, + override_pooler_config={"pooling_type": "LAST"}, + max_model_len=512, + enforce_eager=True, + tensor_parallel_size=1, + disable_log_stats=True, + dtype="bfloat16") as vllm_model: + + llm = vllm_model.get_llm() + prompts = llm.preprocess_chat(messages) + + result = llm.classify(prompts) + assert result[0].outputs.probs[0] > 0.95 + assert all(c < 0.05 for c in result[0].outputs.probs[1:]) \ No newline at end of file diff --git a/tests/models/language/pooling_mteb_test/__init__.py b/tests/models/language/pooling_mteb_test/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling_mteb_test/mteb_utils.py similarity index 77% rename from tests/models/language/pooling/mteb_utils.py rename to tests/models/language/pooling_mteb_test/mteb_utils.py index 68b1cc80303ad..7b3c02fbbd9f8 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling_mteb_test/mteb_utils.py @@ -9,7 +9,9 @@ import mteb import numpy as np import pytest import requests +import torch +import tests.ci_envs as ci_envs from tests.models.utils import (EmbedModelInfo, RerankModelInfo, check_embeddings_close) @@ -165,19 +167,29 @@ def mteb_test_embed_models(hf_runner, vllm_extra_kwargs=None, hf_model_callback=None, atol=MTEB_EMBED_TOL): - if not model_info.enable_test: - # A model family has many models with the same architecture, - # and we don't need to test each one. + # A model family has many models with the same architecture, + # and we don't need to test each one. + if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: pytest.skip("Skipping test.") - example_prompts = ["The chef prepared a delicious meal."] + # Test embed_dims, isnan and whether to use normalize + example_prompts = ["The chef prepared a delicious meal." 
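+                       # repeated so the prompt is long; together with
+                       # truncate_prompt_tokens=-1 below, this also
+                       # exercises truncation at the model's maximum length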
* 1000] + # Allow vllm to test using the given dtype, such as float32 vllm_extra_kwargs = vllm_extra_kwargs or {} - vllm_extra_kwargs["dtype"] = model_info.dtype + vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype + # Allow vllm to test using hf_overrides if model_info.hf_overrides is not None: vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + # Allow changing the head dtype used by vllm in tests + if ci_envs.VLLM_CI_HEAD_DTYPE is not None: + if "hf_overrides" not in vllm_extra_kwargs: + vllm_extra_kwargs["hf_overrides"] = {} + vllm_extra_kwargs["hf_overrides"][ + "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, @@ -186,21 +198,35 @@ def mteb_test_embed_models(hf_runner, model_config = vllm_model.llm.llm_engine.model_config + # Confirm whether vllm is using the correct architecture if model_info.architecture: assert model_info.architecture in model_config.architectures + + # Confirm whether vllm uses the correct default_pooling_type, which + # relates to whether chunked prefill and prefix caching are enabled assert (model_config._model_info.default_pooling_type == model_info.default_pooling_type) vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS) vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype - vllm_outputs = vllm_model.embed(example_prompts) + head_dtype = model_config.head_dtype + # Test embed_dims, isnan and whether to use normalize + vllm_outputs = vllm_model.embed(example_prompts, + truncate_prompt_tokens=-1) + assert not torch.any(torch.isnan(torch.tensor(vllm_outputs))) + + # Accelerate mteb test by setting + # SentenceTransformers mteb score to a constant if model_info.mteb_score is None: - with hf_runner(model_info.name, - is_sentence_transformer=True, - dtype="float32") as hf_model: + with hf_runner( + model_info.name, + is_sentence_transformer=True, + dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype, + ) as hf_model: + # e.g. setting default parameters for the encode method of hf_runner if hf_model_callback is not None: hf_model_callback(hf_model) @@ -221,7 +247,8 @@ def mteb_test_embed_models(hf_runner, st_dtype = "Constant" print("Model:", model_info.name) - print("VLLM:", vllm_dtype, vllm_main_score) + print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", + vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) @@ -263,9 +290,12 @@ def run_mteb_rerank(cross_encoder, tasks, languages): return main_score -def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None): +def mteb_test_rerank_models_hf(hf_runner, + model_name, + hf_dtype="float32", + hf_model_callback=None): with hf_runner(model_name, is_cross_encoder=True, - dtype="float32") as hf_model: + dtype=hf_dtype) as hf_model: original_predict = hf_model.predict @@ -299,17 +329,26 @@ def mteb_test_rerank_models(hf_runner, hf_model_callback=None, vllm_mteb_encoder=VllmMtebEncoder, atol=MTEB_RERANK_TOL): - if not model_info.enable_test: - # A model family has many models with the same architecture, - # and we don't need to test each one. + # A model family has many models with the same architecture, + # and we don't need to test each one. 
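+    # CI knobs from tests/ci_envs.py used throughout these helpers:
+    # VLLM_CI_NO_SKIP runs every model in the family, and VLLM_CI_DTYPE /
+    # VLLM_CI_HF_DTYPE / VLLM_CI_HEAD_DTYPE override the dtypes under test.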
+ if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: pytest.skip("Skipping test.") + # Allow vllm to test using the given dtype, such as float32 vllm_extra_kwargs = vllm_extra_kwargs or {} - vllm_extra_kwargs["dtype"] = model_info.dtype + vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype + # Allow vllm to test using hf_overrides if model_info.hf_overrides is not None: vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + # Allow changing the head dtype used by vllm in tests + if ci_envs.VLLM_CI_HEAD_DTYPE is not None: + if "hf_overrides" not in vllm_extra_kwargs: + vllm_extra_kwargs["hf_overrides"] = {} + vllm_extra_kwargs["hf_overrides"][ + "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, @@ -319,9 +358,15 @@ def mteb_test_rerank_models(hf_runner, model_config = vllm_model.llm.llm_engine.model_config + # Confirm whether vllm is using the correct architecture if model_info.architecture: assert (model_info.architecture in model_config.architectures) + + # Score API is only enabled for num_labels == 1 assert model_config.hf_config.num_labels == 1 + + # Confirm whether vllm uses the correct default_pooling_type, which + # relates to whether chunked prefill and prefix caching are enabled assert (model_config._model_info.default_pooling_type == model_info.default_pooling_type) @@ -329,16 +374,20 @@ def mteb_test_rerank_models(hf_runner, tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS) vllm_dtype = model_config.dtype + head_dtype = model_config.head_dtype + # Accelerate mteb test by setting + # SentenceTransformers mteb score to a constant if model_info.mteb_score is None: st_main_score, st_dtype = mteb_test_rerank_models_hf( - hf_runner, model_info.name, hf_model_callback) + hf_runner, model_info.name, model_info.hf_dtype, hf_model_callback) else: st_main_score = model_info.mteb_score st_dtype = "Constant" print("Model:", model_info.name) - print("VLLM:", vllm_dtype, vllm_main_score) + print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", + vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling_mteb_test/test_baai.py similarity index 93% rename from tests/models/language/pooling/test_baai.py rename to tests/models/language/pooling_mteb_test/test_baai.py index be8cb6fa76994..e131c9b1038de 100644 --- a/tests/models/language/pooling/test_baai.py +++ b/tests/models/language/pooling_mteb_test/test_baai.py @@ -2,10 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, - EmbedModelInfo, LASTPoolingEmbedModelInfo, - RerankModelInfo) -from .embed_utils import correctness_test_embed_models +from tests.models.language.pooling.embed_utils import ( + correctness_test_embed_models) +from tests.models.utils import (CLSPoolingEmbedModelInfo, + CLSPoolingRerankModelInfo, EmbedModelInfo, + LASTPoolingEmbedModelInfo, RerankModelInfo) + from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models MODELS = [ diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py similarity index 95% rename from tests/models/language/pooling/test_bge_reranker_v2_gemma.py rename to 
tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py index eaa8bfb84ffdd..1eca2a2c0abd9 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py @@ -7,13 +7,14 @@ import pytest import torch from tests.conftest import HfRunner - -from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo -from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models +from tests.models.language.pooling_mteb_test.mteb_utils import ( + VllmMtebEncoder, mteb_test_rerank_models) +from tests.models.utils import LASTPoolingRerankModelInfo, RerankModelInfo RERANK_MODELS = [ LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", architecture="GemmaForSequenceClassification", + mteb_score=0.33757, hf_overrides={ "architectures": ["GemmaForSequenceClassification"], diff --git a/tests/models/language/pooling/test_cross_encoder.py b/tests/models/language/pooling_mteb_test/test_cross_encoder.py similarity index 85% rename from tests/models/language/pooling/test_cross_encoder.py rename to tests/models/language/pooling_mteb_test/test_cross_encoder.py index b49908c9ce6a6..ad320fae0c85a 100644 --- a/tests/models/language/pooling/test_cross_encoder.py +++ b/tests/models/language/pooling_mteb_test/test_cross_encoder.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import (CLSPoolingRerankModelInfo, LASTPoolingRerankModelInfo, - RerankModelInfo) +from tests.models.utils import (CLSPoolingRerankModelInfo, + LASTPoolingRerankModelInfo, RerankModelInfo) + from .mteb_utils import mteb_test_rerank_models RERANK_MODELS = [ diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling_mteb_test/test_gte.py similarity index 94% rename from tests/models/language/pooling/test_gte.py rename to tests/models/language/pooling_mteb_test/test_gte.py index 98d215b0ad25e..9ae43fd05bf78 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling_mteb_test/test_gte.py @@ -3,10 +3,12 @@ import pytest -from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, - EmbedModelInfo, LASTPoolingEmbedModelInfo, - RerankModelInfo) -from .embed_utils import correctness_test_embed_models +from tests.models.language.pooling.embed_utils import ( + correctness_test_embed_models) +from tests.models.utils import (CLSPoolingEmbedModelInfo, + CLSPoolingRerankModelInfo, EmbedModelInfo, + LASTPoolingEmbedModelInfo, RerankModelInfo) + from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models MODELS = [ diff --git a/tests/models/language/pooling/test_intfloat.py b/tests/models/language/pooling_mteb_test/test_intfloat.py similarity index 92% rename from tests/models/language/pooling/test_intfloat.py rename to tests/models/language/pooling_mteb_test/test_intfloat.py index bc95475836e87..0d6026898ad4a 100644 --- a/tests/models/language/pooling/test_intfloat.py +++ b/tests/models/language/pooling_mteb_test/test_intfloat.py @@ -2,8 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo -from .embed_utils import correctness_test_embed_models +from tests.models.language.pooling.embed_utils import ( + correctness_test_embed_models) +from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo + from .mteb_utils import mteb_test_embed_models MODELS = [ diff --git 
a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling_mteb_test/test_jina.py similarity index 92% rename from tests/models/language/pooling/test_jina.py rename to tests/models/language/pooling_mteb_test/test_jina.py index c4e4835556a54..0a77a78bb31b6 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling_mteb_test/test_jina.py @@ -4,12 +4,13 @@ from functools import partial import pytest +from tests.models.language.pooling.embed_utils import ( + check_embeddings_close, correctness_test_embed_models, matryoshka_fy) +from tests.models.utils import (CLSPoolingEmbedModelInfo, + CLSPoolingRerankModelInfo, EmbedModelInfo, + RerankModelInfo) from vllm import PoolingParams -from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, - EmbedModelInfo, RerankModelInfo) -from .embed_utils import (check_embeddings_close, - correctness_test_embed_models, matryoshka_fy) from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models EMBEDDING_MODELS = [ diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py similarity index 97% rename from tests/models/language/pooling/test_mxbai_rerank.py rename to tests/models/language/pooling_mteb_test/test_mxbai_rerank.py index 1731c6ae6fff7..05ebb4ec4d3f5 100644 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py @@ -6,8 +6,8 @@ import pytest import torch from tests.conftest import HfRunner +from tests.models.utils import LASTPoolingRerankModelInfo, RerankModelInfo -from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models mxbai_rerank_hf_overrides = { diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling_mteb_test/test_nomic.py similarity index 90% rename from tests/models/language/pooling/test_nomic.py rename to tests/models/language/pooling_mteb_test/test_nomic.py index 52a8ce6e6671f..61512fd0dff18 100644 --- a/tests/models/language/pooling/test_nomic.py +++ b/tests/models/language/pooling_mteb_test/test_nomic.py @@ -3,8 +3,10 @@ import pytest -from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo -from .embed_utils import correctness_test_embed_models +from tests.models.language.pooling.embed_utils import ( + correctness_test_embed_models) +from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo + from .mteb_utils import mteb_test_embed_models MODELS = [ diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py similarity index 98% rename from tests/models/language/pooling/test_qwen3_reranker.py rename to tests/models/language/pooling_mteb_test/test_qwen3_reranker.py index ebdacf9d0c673..65403081dc0f8 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py @@ -6,9 +6,9 @@ import pytest import torch from tests.conftest import HfRunner +from tests.models.utils import LASTPoolingRerankModelInfo, RerankModelInfo from tests.utils import multi_gpu_test -from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models qwen3_reranker_hf_overrides = { diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py similarity index 94% rename from 
tests/models/language/pooling/test_snowflake_arctic_embed.py rename to tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py index 864f3d75ef5aa..91bad2c4e42fc 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py @@ -3,8 +3,10 @@ import pytest -from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo -from .embed_utils import correctness_test_embed_models +from tests.models.language.pooling.embed_utils import ( + correctness_test_embed_models) +from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo + from .mteb_utils import mteb_test_embed_models MODELS = [ diff --git a/tests/models/language/pooling/test_st_projector.py b/tests/models/language/pooling_mteb_test/test_st_projector.py similarity index 86% rename from tests/models/language/pooling/test_st_projector.py rename to tests/models/language/pooling_mteb_test/test_st_projector.py index 9301e705c4335..bd493e7e2ba09 100644 --- a/tests/models/language/pooling/test_st_projector.py +++ b/tests/models/language/pooling_mteb_test/test_st_projector.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo, - LASTPoolingEmbedModelInfo) +from tests.models.utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo, + LASTPoolingEmbedModelInfo) + from .mteb_utils import mteb_test_embed_models # ST models with projector (Dense) layers diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index f95dbc7547ecc..a4e21aface41f 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -187,27 +187,19 @@ def test_chat(vllm_runner, max_model_len: int, model: str, dtype: str, name_1="output") -@pytest.fixture -def prompt(request, local_asset_server) -> TextPrompt: - names = request.param - urls = [local_asset_server.url_for(n) for n in names] - return _create_engine_inputs_hf(urls) - - @pytest.mark.parametrize( - "prompt,expected_ranges", - [ - pytest.param(IMG_URLS[:1], [PlaceholderRange(offset=11, length=494)]), - pytest.param(IMG_URLS[1:4], [ - PlaceholderRange(offset=11, length=266), - PlaceholderRange(offset=277, length=1056), - PlaceholderRange(offset=1333, length=418) - ]) - ], -) -def test_multi_modal_placeholders(vllm_runner, prompt: TextPrompt, + "image_urls,expected_ranges", + [(IMG_URLS[:1], [PlaceholderRange(offset=11, length=494)]), + (IMG_URLS[1:4], [ + PlaceholderRange(offset=11, length=266), + PlaceholderRange(offset=277, length=1056), + PlaceholderRange(offset=1333, length=418) + ])]) +def test_multi_modal_placeholders(vllm_runner, image_urls: list[str], expected_ranges: list[PlaceholderRange], - monkeypatch) -> None: + local_asset_server, monkeypatch) -> None: + local_image_urls = [local_asset_server.url_for(u) for u in image_urls] + prompt = _create_engine_inputs_hf(local_image_urls) # This placeholder checking test only works with V0 engine # where `multi_modal_placeholders` is returned with `RequestOutput` diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 4a65e8c95204e..e0e9980b88339 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -122,8 +122,7 @@ def run_test( @pytest.mark.core_model -@pytest.mark.parametrize( - "model", 
["openai/whisper-small", "openai/whisper-large-v3-turbo"]) +@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) @create_new_process_for_each_test() def test_models(vllm_runner, model) -> None: run_test( diff --git a/tests/models/multimodal/processing/test_glm4_1v.py b/tests/models/multimodal/processing/test_glm4_1v.py index a49842e1099c2..dfb8d9b2a038d 100644 --- a/tests/models/multimodal/processing/test_glm4_1v.py +++ b/tests/models/multimodal/processing/test_glm4_1v.py @@ -5,6 +5,7 @@ import pytest from vllm.assets.video import VideoAsset from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.video import OpenCVDynamicVideoBackend, OpenCVVideoBackend from ...utils import build_model_context @@ -50,3 +51,49 @@ def test_processor_override( assert grid_t == expected_grid_t assert video_tok_count == expected_toks_per_frame * grid_t + + +@pytest.mark.parametrize("model_id", ["zai-org/GLM-4.1V-9B-Thinking"]) +@pytest.mark.parametrize("fps", [2]) +def test_video_loader_consistency( + model_id: str, + fps: int, +): + """ + Ensure dynamic video loader (pre-sampled by loader) and normal video + loader (post-sampled by processor) produce same video processing outputs. + """ + ctx = build_model_context( + model_id, + mm_processor_kwargs=None, + limit_mm_per_prompt={"video": 1}, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + hf_processor_mm_kwargs = {"fps": fps} + + # Build the image str / prompt based on the number of images we pass + prompt = "<|begin_of_video|><|video|><|end_of_video|>" + + video_path = VideoAsset(name="baby_reading", num_frames=-1).video_path + with open(video_path, "rb") as f: + video_bytes = f.read() + + static_video, static_metadata = OpenCVVideoBackend.load_bytes(video_bytes) + dynamic_video, dynamic_metadata = OpenCVDynamicVideoBackend.load_bytes( + video_bytes, requested_fps=fps) + + # pre-sampled loader shouldn't read all frames + assert len(dynamic_video) < len(static_video) + + static_mm_data = {"video": [(static_video, static_metadata)]} + dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]} + + static_outputs = processor.apply(prompt, static_mm_data, + hf_processor_mm_kwargs) + dynamic_outputs = processor.apply(prompt, dynamic_mm_data, + hf_processor_mm_kwargs) + + assert static_outputs["prompt_token_ids"] == dynamic_outputs[ + "prompt_token_ids"] + assert static_outputs["mm_kwargs"].get_data( + ) == dynamic_outputs["mm_kwargs"].get_data() diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index b678313752d65..3b87b669dbbe3 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -31,6 +31,7 @@ from ...utils import dummy_hf_overrides ARCH_TO_SKIP = { "MolmoForCausalLM": "incompatible requirements", + "Florence2ForConditionalGeneration": "not supported in V1", } ARCH_NEEDS_EXTRAS = [ "InternVLChatModel", diff --git a/tests/models/registry.py b/tests/models/registry.py index c6ff50b5426e1..b268bf12a3f30 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -97,6 +97,12 @@ class _HfExamplesInfo: max_num_seqs: Optional[int] = None """Maximum number of sequences to be processed in a single iteration.""" + use_original_num_layers: bool = False + """ + If True, use the original number of layers from the model config + instead of minimal layers for testing. 
+ """ + def check_transformers_version( self, *, @@ -158,7 +164,7 @@ class _HfExamplesInfo: # yapf: disable _TEXT_GENERATION_EXAMPLE_MODELS = { # [Decoder-only] - "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B", + "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-2509", min_transformers_version="4.56.0", trust_remote_code=True), "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", @@ -285,7 +291,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501 {"tiny": "TitanML/tiny-mixtral"}), # noqa: E501 - "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501 + "MotifForCausalLM": _HfExamplesInfo("Motif-Technologies/Motif-2.6B", + trust_remote_code=True, + v0_only=True), "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"), @@ -293,6 +301,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"), + "Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), "OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"}), @@ -318,6 +327,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), + "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", + min_transformers_version="4.56.2"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 trust_remote_code=True, @@ -383,7 +394,7 @@ _EMBEDDING_EXAMPLE_MODELS = { "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", trust_remote_code=True), "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 - "PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 + "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 dtype=torch.float16, enforce_eager=True, skip_tokenizer_init=True, @@ -391,7 +402,7 @@ _EMBEDDING_EXAMPLE_MODELS = { # going OOM in CI max_num_seqs=32, ), - "Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + "Terratorch": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 dtype=torch.float16, enforce_eager=True, skip_tokenizer_init=True, @@ -516,6 +527,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { trust_remote_code=True), "Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501 trust_remote_code=True), + "NemotronH_Nano_VL": _HfExamplesInfo("nano_vl_dummy", + is_available_online=False, + trust_remote_code=True), "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True, max_transformers_version="4.53", transformers_version_reason="HF model is not compatible", # noqa: E501 @@ -595,19 +609,21 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random", 
speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501 trust_remote_code=True), - "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", + "EagleLlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B-Instruct", # noqa: E501 trust_remote_code=True, speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", - tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 - "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501 + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 + "Eagle3LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.1-8B-Instruct", # noqa: E501 trust_remote_code=True, - speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", - tokenizer="meta-llama/Llama-3.1-8B-Instruct"), - # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 - # "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501 - # trust_remote_code=True, - # speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501 - # tokenizer="Qwen/Qwen3-8B"), + speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501 + tokenizer="meta-llama/Llama-3.1-8B-Instruct", + use_original_num_layers=True, + max_model_len=10240), + "LlamaForCausalLMEagle3": _HfExamplesInfo("Qwen/Qwen3-8B", # noqa: E501 + trust_remote_code=True, + speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501 + tokenizer="Qwen/Qwen3-8B", + use_original_num_layers=True), "EagleLlama4ForCausalLM": _HfExamplesInfo( "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", trust_remote_code=True, @@ -627,7 +643,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { is_available_online=False), "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, - speculative_model="XiaomiMiMo/MiMo-7B-RL") + speculative_model="XiaomiMiMo/MiMo-7B-RL"), + "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", + min_transformers_version="4.56.2"), } _TRANSFORMERS_BACKEND_MODELS = { diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index aaa04f52f7794..c22d94948d249 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -18,6 +18,26 @@ from .registry import (_TRANSFORMERS_BACKEND_MODELS, AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels) from .utils import dummy_hf_overrides +# This minimal list of model architectures is smaller than the total list of +# supported models. The intention is that in the "typical" regression testing +# scenario, we only test initializing these models. This subset was chosen +# to include representative examples of model varieties/workloads (conditional +# generation, sequence classification, causal LM, ranking, chat, reward model, +# multimodal, geospatial, voice, embedding, MTP) +MINIMAL_MODEL_ARCH_LIST = [ + "LlavaForConditionalGeneration", "Llama4ForConditionalGeneration", + "BertForSequenceClassification", "Gemma3nForCausalLM", "JinaVLForRanking", + "InternVLChatModel", "InternLM2ForRewardModel", + "TransformersForMultimodalLM", "PrithviGeoSpatialMAE", "UltravoxModel", + "DeepSeekMTPModel", "XLMRobertaModel" +] + +# This list is the complement of the minimal list above. The intention is that +# this list of models is only tested in a "special case" i.e. 
most PRs should + not test these models +OTHER_MODEL_ARCH_LIST = (set(HF_EXAMPLE_MODELS.get_supported_archs()) - + set(MINIMAL_MODEL_ARCH_LIST)) + @create_new_process_for_each_test() def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, @@ -36,7 +56,10 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, hf_overrides_fn = partial(dummy_hf_overrides, model_arch=model_arch, - exist_overrides=model_info.hf_overrides) + exist_overrides=model_info.hf_overrides, + use_original_num_layers=getattr( + model_info, 'use_original_num_layers', + False)) # Avoid calling model.forward() def _initialize_kv_caches_v0(self) -> None: @@ -60,14 +83,21 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, _initialize_kv_caches_v1), monkeypatch.context() as m): if model_info.v0_only: m.setenv("VLLM_USE_V1", "0") - if model_arch == "Phi4FlashForCausalLM": - # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend + if model_arch in ("Phi4FlashForCausalLM", "MotifForCausalLM"): + # Phi4FlashForCausalLM and MotifForCausalLM + # only support the DIFFERENTIAL_FLASH_ATTN backend m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN") if model_arch == "GptOssForCausalLM": # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when # L4 supports FA3. m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1") + if model_arch == "Florence2ForConditionalGeneration": + # An encoder-decoder model that's V0-only. Just skip it + # since V0 is about to be removed. + pytest.skip("Skipping Florence2ForConditionalGeneration") + if model_arch == "WhisperForConditionalGeneration": + m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") LLM( model_info.default, tokenizer=model_info.tokenizer, @@ -91,8 +121,23 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, max_num_seqs=model_info.max_num_seqs) -@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) -def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): +@pytest.mark.parametrize("model_arch", MINIMAL_MODEL_ARCH_LIST) +def test_can_initialize_small_subset(model_arch: str, + monkeypatch: pytest.MonkeyPatch): + """Test initializing a small subset of supported models""" + if model_arch == "Lfm2ForCausalLM": + pytest.skip("Skipping until test supports V1-only models") + can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) + + +@pytest.mark.parametrize("model_arch", OTHER_MODEL_ARCH_LIST) +def test_can_initialize_large_subset(model_arch: str, + monkeypatch: pytest.MonkeyPatch): + """Test initializing a large subset of supported models + + This test covers the complement of the tests covered in the "small subset" + test.
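The split above partitions `get_supported_archs()` into a hand-picked representative list and its complement via set difference. A toy version of the scheme; note that iterating a raw `set` has no stable order, so sorting the complement (as sketched below) would give pytest a deterministic parametrization order:

```python
ALL_ARCHS = ["ArchA", "ArchB", "ArchC", "ArchD"]  # stand-in for get_supported_archs()
MINIMAL = ["ArchB", "ArchD"]                      # always-tested representatives

# Complement of the minimal list; sorted for a stable parametrize order.
OTHER = sorted(set(ALL_ARCHS) - set(MINIMAL))

assert OTHER == ["ArchA", "ArchC"]
assert set(MINIMAL) | set(OTHER) == set(ALL_ARCHS)  # nothing lost
assert not set(MINIMAL) & set(OTHER)                # nothing tested twice
```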
+ """ if model_arch == "Lfm2ForCausalLM": pytest.skip("Skipping until test supports V1-only models") can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) diff --git a/tests/models/test_terratorch.py b/tests/models/test_terratorch.py index bfa54280dc02d..d6d43ca2f7e15 100644 --- a/tests/models/test_terratorch.py +++ b/tests/models/test_terratorch.py @@ -11,7 +11,7 @@ from vllm.utils import set_default_torch_num_threads @pytest.mark.parametrize( "model", [ - "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", "mgazz/Prithvi_v2_eo_300_tl_unet_agb" ], ) diff --git a/tests/models/utils.py b/tests/models/utils.py index ab0b27af4d697..76c6e4823a12c 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -347,14 +347,15 @@ class ModelInfo: name: str architecture: str = "" dtype: str = "auto" + hf_dtype: str = "float32" hf_overrides: Optional[dict[str, Any]] = None default_pooling_type: str = "" - mteb_score: Optional[float] = None enable_test: bool = True @dataclass class EmbedModelInfo(ModelInfo): + mteb_score: Optional[float] = None is_matryoshka: bool = False matryoshka_dimensions: Optional[list[int]] = None @@ -371,7 +372,7 @@ class LASTPoolingEmbedModelInfo(EmbedModelInfo): @dataclass class RerankModelInfo(ModelInfo): - pass + mteb_score: Optional[float] = None @dataclass @@ -384,11 +385,18 @@ class LASTPoolingRerankModelInfo(RerankModelInfo): default_pooling_type: str = "LAST" +@dataclass +class GenerateModelInfo(ModelInfo): + hf_dtype: str = "auto" + hf_ppl: Optional[float] = None + + def dummy_hf_overrides( hf_config: PretrainedConfig, *, model_arch: str = "", exist_overrides: Optional[dict[str, Any]] = None, + use_original_num_layers: bool = False, ) -> PretrainedConfig: """ Dummy HF overrides function used to create dummy model @@ -405,10 +413,18 @@ def dummy_hf_overrides( # we use three layers for Gemma-3n to check # both normal layer and kv_shared_layer - num_hidden_layers = (3 if model_arch == "Gemma3nForConditionalGeneration" - else 1) + if use_original_num_layers: + # Use the original number of layers from the config + num_layers = getattr(text_config, 'num_layers', 1) + num_hidden_layers = getattr(text_config, 'num_hidden_layers', 1) + else: + # Use minimal layers for testing + num_layers = 1 + num_hidden_layers = (3 if model_arch + == "Gemma3nForConditionalGeneration" else 1) + text_config.update({ - "num_layers": 1, + "num_layers": num_layers, "num_hidden_layers": num_hidden_layers, "num_experts": num_experts, "num_experts_per_tok": 2, diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index 44c05db2278f7..3c61ee26e092e 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -10,8 +10,8 @@ from vllm.config import ModelConfig, ParallelConfig, VllmConfig from vllm.multimodal.cache import (MultiModalCache, MultiModalProcessorCacheItem, MultiModalProcessorCacheItemMetadata, - processor_cache_from_config, - receiver_cache_from_config) + engine_receiver_cache_from_config, + processor_cache_from_config) from vllm.multimodal.hasher import MultiModalHasher from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, MultiModalKwargsItems, @@ -115,9 +115,9 @@ def _compare_caches( ): mm_registry = MultiModalRegistry() cache_0_p0 = processor_cache_from_config(config_0, mm_registry) - cache_0_p1 = receiver_cache_from_config(config_0, mm_registry) + cache_0_p1 = engine_receiver_cache_from_config(config_0, mm_registry) cache_1_p0 = 
processor_cache_from_config(config_1, mm_registry) - cache_1_p1 = receiver_cache_from_config(config_1, mm_registry) + cache_1_p1 = engine_receiver_cache_from_config(config_1, mm_registry) cache_size_gb = max( config_0.model_config.mm_processor_cache_gb, diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 886582a516409..e1e8282dd66d4 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -204,6 +204,32 @@ async def test_fetch_video_http(video_url: str, num_frames: int): assert metadata_sync == metadata_async +@pytest.mark.asyncio +@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) +@pytest.mark.parametrize("max_duration", [1, 60, 1800]) +@pytest.mark.parametrize("requested_fps", [2, 24]) +async def test_fetch_video_http_with_dynamic_loader( + video_url: str, max_duration: int, requested_fps: int, + monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + m.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv_dynamic") + connector = MediaConnector( + media_io_kwargs={ + "video": { + "max_duration": max_duration, + "requested_fps": requested_fps, + } + }) + + video_sync, metadata_sync = connector.fetch_video(video_url) + video_async, metadata_async = await connector.fetch_video_async( + video_url) + + assert np.array_equal(video_sync, video_async) + assert metadata_sync == metadata_async + assert metadata_sync["video_backend"] == "opencv_dynamic" + + # Used for `test_argsort_mm_positions`. class TestCase(NamedTuple): mm_positions: "MultiModalPlaceholderDict" diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py index a750c756c11a2..4bbb79c98a82a 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -def register_prithvi_india(): - return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia" # noqa: E501 -def register_prithvi_valencia(): - return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia" # noqa: E501 +def register_prithvi(): + return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessor" # noqa: E501 diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py index 0ebaafda94dc5..42874f0398f0a 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -234,6 +234,8 @@ def load_image( class PrithviMultimodalDataProcessor(IOProcessor): + indices = [0, 1, 2, 3, 4, 5] + def __init__(self, vllm_config: VllmConfig): super().__init__(vllm_config) @@ -412,21 +414,3 @@ class PrithviMultimodalDataProcessor(IOProcessor): format="tiff", data=out_data, request_id=request_id) - - -class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor): - - def __init__(self, vllm_config: VllmConfig): - - super().__init__(vllm_config) - - self.indices = [1, 2, 3, 8, 11, 12] - - -class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor): - - def __init__(self, vllm_config: VllmConfig): - - super().__init__(vllm_config) - - self.indices = [0, 1, 2, 3, 4, 
5] diff --git a/tests/plugins/prithvi_io_processor_plugin/setup.py b/tests/plugins/prithvi_io_processor_plugin/setup.py index a03b1fbbd4a80..3ddda1a47bbe4 100644 --- a/tests/plugins/prithvi_io_processor_plugin/setup.py +++ b/tests/plugins/prithvi_io_processor_plugin/setup.py @@ -9,8 +9,7 @@ setup( packages=["prithvi_io_processor"], entry_points={ "vllm.io_processor_plugins": [ - "prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india", # noqa: E501 - "prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia", # noqa: E501 + "prithvi_to_tiff = prithvi_io_processor:register_prithvi", # noqa: E501 ] }, ) diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 825165e89b33c..3567a701a3afa 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import IOProcessorResponse from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams -MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11" +MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 @@ -35,7 +35,7 @@ def server(): "--max-num-seqs", "32", "--io-processor-plugin", - "prithvi_to_tiff_valencia", + "prithvi_to_tiff", "--model-impl", "terratorch", ] @@ -107,7 +107,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): # to avoid the model going OOM in CI. max_num_seqs=1, model_impl="terratorch", - io_processor_plugin="prithvi_to_tiff_valencia", + io_processor_plugin="prithvi_to_tiff", ) as llm_runner: pooler_output = llm_runner.get_llm().encode( img_prompt, diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py index eef3568efea12..8e68f6a2e019f 100644 --- a/tests/quantization/test_torchao.py +++ b/tests/quantization/test_torchao.py @@ -75,5 +75,25 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner): print(output) + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +@pytest.mark.skip( + reason="since torchao nightly is only compatible with torch nightly " + "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " + "torchao tests that require newer versions (0.14.0.dev+) for now") +def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner): + torch._dynamo.reset() + model_name = ("torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2" + "-0.14.0.dev") + with vllm_runner(model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0") as llm: + output = llm.generate_greedy(["The capital of France is"], + max_tokens=32) + + assert output + print(output) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py index 84c615b6b8dbc..22bdb3b44eb03 100644 --- a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py +++ b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm import SamplingParams -from vllm.config import LoadConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader import
get_model_loader load_format = "runai_streamer" diff --git a/tests/runai_model_streamer_test/test_runai_utils.py b/tests/runai_model_streamer_test/test_runai_utils.py new file mode 100644 index 0000000000000..bde77ff665063 --- /dev/null +++ b/tests/runai_model_streamer_test/test_runai_utils.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import glob +import os +import tempfile + +import huggingface_hub.constants + +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf) +from vllm.transformers_utils.runai_utils import (is_runai_obj_uri, + list_safetensors) + + +def test_is_runai_obj_uri(): + assert is_runai_obj_uri("gs://some-gcs-bucket/path") + assert is_runai_obj_uri("s3://some-s3-bucket/path") + assert not is_runai_obj_uri("nfs://some-nfs-path") + + +def test_runai_list_safetensors_local(): + with tempfile.TemporaryDirectory() as tmpdir: + huggingface_hub.constants.HF_HUB_OFFLINE = False + download_weights_from_hf("openai-community/gpt2", + allow_patterns=["*.safetensors", "*.json"], + cache_dir=tmpdir) + safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) + assert len(safetensors) > 0 + parentdir = [ + os.path.dirname(safetensor) for safetensor in safetensors + ][0] + files = list_safetensors(parentdir) + assert len(safetensors) == len(files) + + +if __name__ == "__main__": + test_is_runai_obj_uri() + test_runai_list_safetensors_local() diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py index 18aa4c88c0338..571dc2e0eb50f 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/tensorizer_loader/conftest.py @@ -90,6 +90,7 @@ class DummyExecutor(UniProcExecutor): distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) + self.mm_receiver_cache = None self.collective_rpc("init_worker", args=([kwargs], )) self.collective_rpc("init_device") diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index 0fb142a1b6e56..e00d7c2f80c67 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -161,11 +161,11 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref): model = vllm_runner( model_ref, model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) + pytest.fail("Expected RuntimeError for extra config keys") except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert ("ValueError: Model loader extra config " - "is not supported for load " + assert ("ValueError: Unexpected extra config keys for load " "format auto") in combined_output finally: del model @@ -181,11 +181,12 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref, load_format="safetensors", model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) + pytest.fail("Expected RuntimeError for extra config keys") except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert ("ValueError: Model loader extra config is not supported " + assert ("ValueError: Unexpected extra config keys " "for load format safetensors") in combined_output finally: del model diff --git a/tests/test_config.py b/tests/test_config.py index 957771a4226bc..373fbd267539a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,8 +6,9 @@ from dataclasses import MISSING, Field, asdict, dataclass, field import pytest from 
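`test_is_runai_obj_uri` above pins the scheme check: `gs://` and `s3://` count as object-store URIs while `nfs://` does not. An equivalent standalone formulation (the real helper is `vllm.transformers_utils.runai_utils.is_runai_obj_uri`):

```python
from urllib.parse import urlparse

OBJECT_STORE_SCHEMES = {"s3", "gs"}  # assumption: exactly the schemes tested

def is_object_store_uri(uri: str) -> bool:
    # urlparse splits off the scheme before the "://" separator.
    return urlparse(uri).scheme in OBJECT_STORE_SCHEMES

assert is_object_store_uri("gs://some-gcs-bucket/path")
assert is_object_store_uri("s3://some-s3-bucket/path")
assert not is_object_store_uri("nfs://some-nfs-path")
```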
vllm.compilation.backends import VllmBackend -from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig, - get_field, update_config) +from vllm.config import (ModelConfig, PoolerConfig, VllmConfig, get_field, + update_config) +from vllm.config.load import LoadConfig from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index 6cefbae4bdd18..8d9fbd280317c 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -28,7 +28,7 @@ ACCURACY_CONFIGS = [ expected_value=0.76), # no bias # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU, # so only one of these tests can run in a single call to pytest. As - # a follow up, move this into the LM-EVAL section of the CI. + # a follow-up, move this into the LM-EVAL section of the CI. # GSM8KAccuracyTestConfig( # model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", # expected_value=0.66), # bias in QKV layers diff --git a/tests/transformers_utils/__init__.py b/tests/transformers_utils/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/transformers_utils/test_config_parser_registry.py b/tests/transformers_utils/test_config_parser_registry.py new file mode 100644 index 0000000000000..13c654e05d2ac --- /dev/null +++ b/tests/transformers_utils/test_config_parser_registry.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path +from typing import Optional, Union + +import pytest +from transformers import PretrainedConfig + +from vllm.transformers_utils.config import (get_config_parser, + register_config_parser) +from vllm.transformers_utils.config_parser_base import ConfigParserBase + + +@register_config_parser("custom_config_parser") +class CustomConfigParser(ConfigParserBase): + + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + raise NotImplementedError + + +def test_register_config_parser(): + assert isinstance(get_config_parser("custom_config_parser"), + CustomConfigParser) + + +def test_invalid_config_parser(): + with pytest.raises(ValueError): + + @register_config_parser("invalid_config_parser") + class InvalidConfigParser: + pass diff --git a/tests/utils.py b/tests/utils.py index e47235002657d..16e1e60393290 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import contextlib import copy import functools import importlib @@ -13,7 +14,7 @@ import sys import tempfile import time import warnings -from contextlib import contextmanager, suppress +from contextlib import ExitStack, contextmanager, suppress from multiprocessing import Process from pathlib import Path from typing import Any, Callable, Literal, Optional, Union @@ -800,43 +801,106 @@ _P = ParamSpec("_P") def fork_new_process_for_each_test( - f: Callable[_P, None]) -> Callable[_P, None]: + func: Callable[_P, None]) -> Callable[_P, None]: """Decorator to fork a new process for each test function. See https://github.com/vllm-project/vllm/issues/7053 for more details. 
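The new `test_config_parser_registry.py` exercises a decorator-based registry: `@register_config_parser(name)` accepts `ConfigParserBase` subclasses and raises `ValueError` otherwise. A minimal sketch of that shape, with local names standing in for the real API in `vllm.transformers_utils.config`:

```python
_PARSERS: dict[str, type] = {}

class ConfigParserBase:
    def parse(self, model, trust_remote_code, **kwargs):
        raise NotImplementedError

def register_config_parser(name: str):
    def wrap(cls: type) -> type:
        # Reject classes that don't implement the parser interface.
        if not (isinstance(cls, type)
                and issubclass(cls, ConfigParserBase)):
            raise ValueError(
                f"{cls.__name__} must subclass ConfigParserBase")
        _PARSERS[name] = cls
        return cls
    return wrap

def get_config_parser(name: str) -> ConfigParserBase:
    return _PARSERS[name]()  # instantiate on lookup

@register_config_parser("custom")
class CustomParser(ConfigParserBase):
    pass

assert isinstance(get_config_parser("custom"), CustomParser)
```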
""" - @functools.wraps(f) + @functools.wraps(func) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: # Make the process the leader of its own process group # to avoid sending SIGTERM to the parent process os.setpgrp() from _pytest.outcomes import Skipped - pid = os.fork() - print(f"Fork a new process to run a test {pid}") - if pid == 0: - try: - f(*args, **kwargs) - except Skipped as e: - # convert Skipped to exit code 0 - print(str(e)) - os._exit(0) - except Exception: - import traceback - traceback.print_exc() - os._exit(1) + + # Create a unique temporary file to store exception info from child + # process. Use test function name and process ID to avoid collisions. + with tempfile.NamedTemporaryFile( + delete=False, + mode='w+b', + prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", + suffix=".exc") as exc_file, ExitStack() as delete_after: + exc_file_path = exc_file.name + delete_after.callback(os.remove, exc_file_path) + + pid = os.fork() + print(f"Fork a new process to run a test {pid}") + if pid == 0: + # Parent process responsible for deleting, don't delete + # in child. + delete_after.pop_all() + try: + func(*args, **kwargs) + except Skipped as e: + # convert Skipped to exit code 0 + print(str(e)) + os._exit(0) + except Exception as e: + import traceback + tb_string = traceback.format_exc() + + # Try to serialize the exception object first + exc_to_serialize: dict[str, Any] + try: + # First, try to pickle the actual exception with + # its traceback. + exc_to_serialize = {'pickled_exception': e} + # Test if it can be pickled + cloudpickle.dumps(exc_to_serialize) + except (Exception, KeyboardInterrupt): + # Fall back to string-based approach. + exc_to_serialize = { + 'exception_type': type(e).__name__, + 'exception_msg': str(e), + 'traceback': tb_string, + } + try: + with open(exc_file_path, 'wb') as f: + cloudpickle.dump(exc_to_serialize, f) + except Exception: + # Fallback: just print the traceback. + print(tb_string) + os._exit(1) + else: + os._exit(0) else: - os._exit(0) - else: - pgid = os.getpgid(pid) - _pid, _exitcode = os.waitpid(pid, 0) - # ignore SIGTERM signal itself - old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) - # kill all child processes - os.killpg(pgid, signal.SIGTERM) - # restore the signal handler - signal.signal(signal.SIGTERM, old_signal_handler) - assert _exitcode == 0, (f"function {f} failed when called with" - f" args {args} and kwargs {kwargs}") + pgid = os.getpgid(pid) + _pid, _exitcode = os.waitpid(pid, 0) + # ignore SIGTERM signal itself + old_signal_handler = signal.signal(signal.SIGTERM, + signal.SIG_IGN) + # kill all child processes + os.killpg(pgid, signal.SIGTERM) + # restore the signal handler + signal.signal(signal.SIGTERM, old_signal_handler) + if _exitcode != 0: + # Try to read the exception from the child process + exc_info = {} + if os.path.exists(exc_file_path): + with contextlib.suppress(Exception), \ + open(exc_file_path, 'rb') as f: + exc_info = cloudpickle.load(f) + + if (original_exception := + exc_info.get('pickled_exception')) is not None: + # Re-raise the actual exception object if it was + # successfully pickled. 
+ assert isinstance(original_exception, Exception) + raise original_exception + + if (original_tb := exc_info.get("traceback")) is not None: + # Use string-based traceback for fallback case + raise AssertionError( + f"Test {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode}):\n{original_tb}" + ) from None + + # Fallback to the original generic error + raise AssertionError( + f"function {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode})") from None return wrapper diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 66124dd854ee0..6dbba18b4dcfa 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file, @pytest.mark.parametrize("input", [(), ("abc", ), (None, ), (None, bool, [1, 2, 3])]) -@pytest.mark.parametrize("output", [0, 1, 2]) -def test_sha256(input: tuple, output: int): - hash = sha256(input) - assert hash is not None - assert isinstance(hash, int) - assert hash != 0 +def test_sha256(input: tuple): + digest = sha256(input) + assert digest is not None + assert isinstance(digest, bytes) + assert digest != b"" - bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), - byteorder="big") + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + assert digest == hashlib.sha256(input_bytes).digest() # hashing again, returns the same value - assert hash == sha256(input) + assert digest == sha256(input) # hashing different input, returns different value - assert hash != sha256(input + (1, )) + assert digest != sha256(input + (1, )) @pytest.mark.parametrize( diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 1ae8b91c347a2..0b7e103beca63 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -178,6 +178,7 @@ class MockAttentionLayer: self._k_scale = torch.tensor(1.0, device=device) self._v_scale = torch.tensor(1.0, device=device) # Add float versions for flashinfer + self._q_scale_float = 1.0 self._k_scale_float = 1.0 self._v_scale_float = 1.0 diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index e7cd116fdc834..a62993950affe 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -73,7 +73,6 @@ def create_and_prepopulate_kv_cache( kv_c_contexts: list[torch.Tensor], k_pe_contexts: list[torch.Tensor], block_size: int, - num_kv_heads: int, head_size: int, dtype: torch.dtype, device: torch.device, @@ -87,7 +86,6 @@ def create_and_prepopulate_kv_cache( k_pe_contexts: List of key positional embedding context tensors for each sequence block_size: Size of each block - num_kv_heads: Number of KV heads (should be 1 for MLA) head_size: Size of each head (latent dimension) dtype: Data type for the cache device: Device to create the cache on @@ -285,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): query_lens = batch_spec.query_lens num_q_heads = vllm_config.model_config.get_num_attention_heads( vllm_config.parallel_config) - num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) head_size = vllm_config.model_config.get_head_size() dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) block_size = 
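The `test_sha256` update above pins the new contract of `vllm.utils.sha256`: it returns the raw 32-byte digest of the pickled input rather than a big-endian integer. The equivalent computation, standalone:

```python
import hashlib
import pickle

def sha256_bytes(obj) -> bytes:
    # Mirrors the updated test's expectation: pickle, then raw digest.
    data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return hashlib.sha256(data).digest()

d = sha256_bytes(("abc", 1))
assert isinstance(d, bytes) and len(d) == 32
assert d == sha256_bytes(("abc", 1))        # deterministic
assert d != sha256_bytes(("abc", 1, None))  # input-sensitive
```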
vllm_config.cache_config.block_size @@ -476,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): kv_c_contexts=kv_c_contexts, k_pe_contexts=k_pe_contexts, block_size=block_size, - num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, device=device, diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 5c49566240df4..f07c6eb0ea4da 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -141,6 +141,8 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", + _Backend.FLASHINFER_MLA: + "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", _Backend.TRITON_MLA_VLLM_V1: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", } diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py index ae5b751f45a4b..4e3cace86be6a 100644 --- a/tests/v1/core/test_encoder_cache_manager.py +++ b/tests/v1/core/test_encoder_cache_manager.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange from vllm.v1.core.encoder_cache_manager import EncoderCacheManager @@ -9,8 +10,17 @@ class MockRequest: def __init__(self, request_id, mm_hashes, token_counts): self.request_id = request_id - self.mm_hashes = mm_hashes self._token_counts = token_counts + self.mm_features = [] + for i, mm_hash in enumerate(mm_hashes): + feature = MultiModalFeatureSpec( + data=None, + modality="image", + identifier=mm_hash, + mm_position=PlaceholderRange(offset=0, + length=self._token_counts[i]), + ) + self.mm_features.append(feature) def get_num_encoder_tokens(self, input_id: int) -> int: return self._token_counts[input_id] diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 4d0a26f76e98e..44e479098ad5d 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -6,20 +6,22 @@ from typing import Callable, Optional import pytest import torch +import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams -from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit +from vllm.utils import GiB_bytes, sha256, sha256_cbor from vllm.v1.core.kv_cache_manager import KVCacheManager # disable yapf here as it formats differently than isort such that both fail # yapf: disable from vllm.v1.core.kv_cache_utils import ( - FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, + BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, get_kv_cache_config, get_max_concurrency_for_kv_cache_config, get_request_block_hasher, hash_block_tokens, init_none_hash, - is_kv_cache_type_uniform, unify_kv_cache_configs) + is_kv_cache_type_uniform, make_block_hash_with_group_id, + unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor, SlidingWindowSpec) @@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16, sliding_window=sliding_window) -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, 
hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_none_hash(monkeypatch, hash_fn): import vllm.v1.core.kv_cache_utils @@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn): reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None - assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) - assert reloaded_kv_cache_utils.NONE_HASH != 0 + assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes) + assert reloaded_kv_cache_utils.NONE_HASH != b"" # case 2: PYTHONHASHSEED is set, use the seed and hash_fn with monkeypatch.context() as m: @@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn): reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None - assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) + assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes) assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH def test_kv_cache_block(): - import vllm.v1.core.kv_cache_utils # Test KVCacheBlock initialization block = KVCacheBlock(block_id=0) @@ -127,8 +128,7 @@ def test_kv_cache_block(): assert block.ref_cnt == 0 # Test block hash setting and resetting - block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123, - token_ids=(1, 2, 3)) + block_hash = make_block_hash_with_group_id(BlockHash(b"abc"), 0) block.block_hash = block_hash assert block.block_hash == block_hash @@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt(): assert next_mm_idx == 1 -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_block_tokens(hash_fn): - import vllm.v1.core.kv_cache_utils init_none_hash(hash_fn) - parent_block_hash = 123 + parent_block_hash = BlockHash(b"123") curr_block_token_ids = (1, 2, 3) extra_keys = ("key1", "key2") block_hash = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids, extra_keys) - assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash) - assert block_hash.hash_value == hash_fn( - (parent_block_hash, curr_block_token_ids, extra_keys)) - assert block_hash.token_ids == curr_block_token_ids - assert block_hash.extra_keys == extra_keys + expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys)) + assert block_hash == expected -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_request_block_hasher(hash_fn): - import vllm.v1.core.kv_cache_utils - init_none_hash(hash_fn) + kv_cache_utils.init_none_hash(hash_fn) + request = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], @@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash) - assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash) - - # Check the first block - assert block_hashes[0].token_ids == (0, 1, 2) - assert block_hashes[0].extra_keys == ("hash1", ) - - # Check the second block - assert block_hashes[1].token_ids == (3, 4, 5) - assert block_hashes[1].extra_keys == ("hash2", ) + assert block_hashes[0] == hash_fn( + (kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", ))) + assert block_hashes[1] == hash_fn( + 
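With `BlockHash` now plain bytes, the assertions above spell out the chaining directly: each block hash commits to the parent hash, the block's token ids, and any extra keys (mm hashes, cache salt). A toy version of the chain, where `h` and `NONE_HASH` are local stand-ins for the parametrized hash_fn and `kv_cache_utils.NONE_HASH`:

```python
import hashlib
import pickle

def h(obj) -> bytes:  # stand-in for the parametrized hash_fn
    return hashlib.sha256(
        pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)).digest()

NONE_HASH = h("seed")  # stand-in for kv_cache_utils.NONE_HASH

def hash_blocks(token_blocks, extra_keys):
    hashes, parent = [], NONE_HASH
    for tokens, extra in zip(token_blocks, extra_keys):
        parent = h((parent, tokens, extra))  # chain through the parent
        hashes.append(parent)
    return hashes

a = hash_blocks([(0, 1, 2), (3, 4, 5)], [("hash1", ), ("hash2", )])
assert a == hash_blocks([(0, 1, 2), (3, 4, 5)], [("hash1", ), ("hash2", )])
# Changing an extra key (e.g. a different mm hash) changes every later hash:
b = hash_blocks([(0, 1, 2), (3, 4, 5)], [("other", ), ("hash2", )])
assert a[0] != b[0] and a[1] != b[1]
```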
(block_hashes[0], (3, 4, 5), ("hash2", ))) -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_tokens_different_mm_input(hash_fn): init_none_hash(hash_fn) @@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn): assert block_hashes1[1] != block_hashes2[1] -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_request_tokens_no_mm_inputs(hash_fn): - init_none_hash(hash_fn) + kv_cache_utils.init_none_hash(hash_fn) request = make_request( request_id="0", @@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert block_hashes[0].token_ids == (0, 1, 2) - assert block_hashes[0].extra_keys is None - assert block_hashes[1].token_ids == (3, 4, 5) - assert block_hashes[1].extra_keys is None + assert block_hashes[0] == hash_fn( + (kv_cache_utils.NONE_HASH, (0, 1, 2), None)) + assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None)) def test_metrics(): diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index e7a8f63702b30..659d768bcf2e9 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -8,17 +8,19 @@ from typing import Callable, Optional import pytest import torch +import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams -from vllm.utils import sha256, sha256_cbor_64bit +from vllm.utils import sha256, sha256_cbor from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock, +from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, + get_block_hash, get_group_id, get_request_block_hasher, - hash_block_tokens, init_none_hash) + hash_block_tokens, init_none_hash, + make_block_hash_with_group_id) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) @@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int, ) -@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"]) -def test_prefill(hash_algo): +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_prefill(hash_fn): + init_none_hash(hash_fn) + block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -110,10 +114,6 @@ def test_prefill(hash_algo): enable_caching=True, ) - # choose the hash function according to the parameter - hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else - sha256 if hash_algo == "sha256" else hash) - # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(16)] @@ -137,10 +137,12 @@ def test_prefill(hash_algo): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - assert manager.block_pool.blocks[ - block_id].block_hash.block_hash == block_hash + blk_hash = manager.block_pool.blocks[block_id].block_hash + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == 0 assert 
manager.block_pool.blocks[block_id].ref_cnt == 1 - parent_block_hash = block_hash.hash_value + parent_block_hash = block_hash # Check partial block metadata for block_id in (4, ): @@ -233,7 +235,7 @@ def test_prefill_hybrid_model(): enable_caching=True, ) - hash_fn = hash + hash_fn = sha256 # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(block_size)] @@ -260,11 +262,13 @@ def test_prefill_hybrid_model(): block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - for block_id in block_ids: - assert manager.block_pool.blocks[ - block_id].block_hash.block_hash == block_hash + for group_id, block_id in enumerate(block_ids): + blk_hash = manager.block_pool.blocks[block_id].block_hash + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == group_id assert manager.block_pool.blocks[block_id].ref_cnt == 1 - parent_block_hash = block_hash.hash_value + parent_block_hash = block_hash # Check partial block metadata for block_id in (4, 8, 12): @@ -298,11 +302,10 @@ def test_prefill_hybrid_model(): cached_block_hash_to_block_bak = copy.copy( manager.block_pool.cached_block_hash_to_block) - def test_partial_request_hit(request_id: str, - hash_to_evict: list[BlockHashWithGroupId], + def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes], expect_hit_length: int): req = make_request(request_id, common_token_ids + unique_token_ids, - block_size, hash) + block_size, sha256) for hash_with_group_id in hash_to_evict: manager.block_pool.cached_block_hash_to_block.pop( hash_with_group_id) @@ -319,33 +322,32 @@ def test_prefill_hybrid_model(): # Evict the blocks outside sliding window, does not affect the hit length. test_partial_request_hit("2", [ - BlockHashWithGroupId(block_hashes[0], 1), - BlockHashWithGroupId(block_hashes[0], 2) + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2) ], 3) # Evict the first block of full attention, makes total cache miss. - test_partial_request_hit("3", [ - BlockHashWithGroupId(block_hashes[0], 0), - ], 0) + test_partial_request_hit( + "3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0) # Evict the last block of all layers, reduces the hit length to 2. test_partial_request_hit("4", [ - BlockHashWithGroupId(block_hashes[2], 0), - BlockHashWithGroupId(block_hashes[2], 1), - BlockHashWithGroupId(block_hashes[2], 2), + make_block_hash_with_group_id(block_hashes[2], 0), + make_block_hash_with_group_id(block_hashes[2], 1), + make_block_hash_with_group_id(block_hashes[2], 2), ], 2) # Evict the last block of full attention, reduces the hit length to 2. - test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)], - 2) + test_partial_request_hit( + "5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2) # Evict the last block of sliding window, reduces the hit length to 2. - test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)], - 2) + test_partial_request_hit( + "6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2) # Evict the last block of sliding window, reduces the hit length to 2. - test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)], - 2) + test_partial_request_hit( + "7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2) # Evict different set of blocks for full attention and sliding window makes # total cache miss. 
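The hybrid-model assertions above now go through `make_block_hash_with_group_id` / `get_block_hash` / `get_group_id` instead of a `BlockHashWithGroupId` named tuple, so the same `BlockHash` can be cached once per KV-cache group. One plausible bytes-level encoding with that interface (the real implementation lives in `vllm.v1.core.kv_cache_utils` and may differ):

```python
def make_block_hash_with_group_id(block_hash: bytes, group_id: int) -> bytes:
    # Append a fixed-width group id to the digest so keys stay plain bytes.
    return block_hash + group_id.to_bytes(4, "big")

def get_block_hash(key: bytes) -> bytes:
    return key[:-4]

def get_group_id(key: bytes) -> int:
    return int.from_bytes(key[-4:], "big")

key = make_block_hash_with_group_id(b"\x01" * 32, group_id=2)
assert get_block_hash(key) == b"\x01" * 32
assert get_group_id(key) == 2
```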
@@ -353,9 +355,9 @@ def test_prefill_hybrid_model(): # The cache hit length of sliding window is 2 * block_size. # Then it is cache miss as the two type of layers have different hit length. test_partial_request_hit("8", [ - BlockHashWithGroupId(block_hashes[2], 0), - BlockHashWithGroupId(block_hashes[0], 1), - BlockHashWithGroupId(block_hashes[0], 2), + make_block_hash_with_group_id(block_hashes[2], 0), + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2), ], 0) @@ -372,8 +374,8 @@ def test_prefill_plp(): max_model_len=8192, enable_caching=True, ) - # the default hash function is hash - hash_fn = hash + # the default hash function is sha256 + hash_fn = sha256 # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(16)] @@ -404,10 +406,12 @@ def test_prefill_plp(): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - assert manager.block_pool.blocks[ - block_id].block_hash.block_hash == block_hash + blk_hash = (manager.block_pool.blocks[block_id].block_hash) + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == 0 assert manager.block_pool.blocks[block_id].ref_cnt == 1 - parent_block_hash = block_hash.hash_value + parent_block_hash = block_hash # Check partial block metadata for block_id in (4, ): @@ -493,7 +497,7 @@ def test_decode(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids, block_size, - hash) + sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -538,7 +542,7 @@ def test_evict(): ) last_token_id = 5 * 16 + 7 - req0 = make_request("0", list(range(last_token_id)), block_size, hash) + req0 = make_request("0", list(range(last_token_id)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -550,7 +554,7 @@ def test_evict(): # 3 blocks. req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16)), block_size, - hash) + sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -572,7 +576,7 @@ def test_evict(): ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. - req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash) + req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert computed_blocks.get_block_ids() == ([1, 2], ) assert num_computed_tokens == 2 * 16 @@ -597,7 +601,7 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. num_tokens = block_size * 1 - req = make_request("0", list(range(num_tokens)), block_size, hash) + req = make_request("0", list(range(num_tokens)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -611,7 +615,7 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. 
- req = make_request("1", list(range(num_tokens - 1)), block_size, hash) + req = make_request("1", list(range(num_tokens - 1)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 - req0 = make_request("0", list(range(num_tokens)), block_size, hash) + req0 = make_request("0", list(range(num_tokens)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted(): # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2)), - block_size, hash) + block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. - req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash) + req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 assert computed_blocks.blocks[0][0].block_id == 1 @@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled(): ) req1 = make_request("1", list(range(10)), block_size, - hash) # 2 blocks and some more + sha256) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] @@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16)), block_size, - hash) # shared prefix + sha256) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled(): assert len(blocks.blocks[0]) == 4 # New requests should not have any blocks. - req3 = make_request("3", list(range(4)), block_size, hash) + req3 = make_request("3", list(range(4)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled(): assert not blocks -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_cache_blocks(hash_fn): """ This is a unit test that tests the correctness of the _cache_full_blocks @@ -787,7 +791,7 @@ def test_cache_blocks_multi_group(): # Block 1/5: [4, 5, 6, 7] # Block 2/6: [8, 9, 10, 11] # Block 3/7: [12, 13] - req = make_request("0", list(range(14)), block_size, hash) + req = make_request("0", list(range(14)), block_size, sha256) # Cache the blocks for group 0. blocks = [KVCacheBlock(block_id=i) for i in range(2)] @@ -845,6 +849,8 @@ def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. 
""" + kv_cache_utils.init_none_hash(sha256) + block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -874,23 +880,30 @@ def test_mm_prefix_caching(): req0 = make_request("0", all_token_ids, block_size, - hash, + sha256, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - # Completed block should have hashes with extra keys. + # Completed block should have hashes assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = req0.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("aaa", ) - assert block_hashes[1].extra_keys == ("aaa", "bbb") - assert block_hashes[2].extra_keys == ("bbb", ) + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), + ("aaa", ))) + assert block_hashes[1] == sha256( + (block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]), + ("aaa", "bbb"))) + assert block_hashes[2] == sha256( + (block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]), + ("bbb", ))) blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks) + assert blocks is not None assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0.num_computed_tokens = 59 @@ -901,10 +914,10 @@ def test_mm_prefix_caching(): len(computed_blocks.blocks[0]) * 16, computed_blocks) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 - - # The just completed block should have hashes with extra keys. assert len(block_hashes) == 4 - assert block_hashes[3].extra_keys == ("ccc", ) + assert block_hashes[3] == sha256( + (block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5), + ("ccc", ))) # Cache hit. unique_token_ids = [-1] * 7 + [200] * 5 @@ -916,7 +929,7 @@ def test_mm_prefix_caching(): req1 = make_request("1", all_token_ids, block_size, - hash, + sha256, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) @@ -929,6 +942,8 @@ def test_cache_key_salting(): This tests that cache salts are applied during hashing and the cache is separated cache as expected. """ + kv_cache_utils.init_none_hash(sha256) + block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -939,21 +954,26 @@ def test_cache_key_salting(): # 3 complete blocks and an incomplete block with 11 tokens. common_token_ids = [i for i in range(3) for _ in range(block_size)] token_ids = common_token_ids + [3] * 11 - req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1") + req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - # Completed block should have hashes with extra keys. 
+ # Completed block should have hashes assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = req0.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("salt1", ) - assert block_hashes[1].extra_keys is None - assert block_hashes[2].extra_keys is None + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", ))) + assert block_hashes[1] == sha256( + (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + assert block_hashes[2] == sha256( + (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), + None)) blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks) + assert blocks is not None assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0.num_computed_tokens = 59 @@ -964,14 +984,13 @@ def test_cache_key_salting(): len(computed_blocks.blocks[0]) * 16, computed_blocks) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 - - # Now one more block that should not have extra keys. assert len(block_hashes) == 4 - assert block_hashes[3].extra_keys is None + assert block_hashes[3] == sha256( + (block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None)) # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 - req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1") + req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. assert len(computed_blocks.blocks[0]) == 3 @@ -979,13 +998,19 @@ def test_cache_key_salting(): # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 - req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2") + req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 block_hashes = req2.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("salt2", ) + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", ))) + assert block_hashes[1] == sha256( + (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + assert block_hashes[2] == sha256( + (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), + None)) def test_prefill_not_enough_free_blocks_with_computed_blocks(): @@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] - req0 = make_request("0", common_token_ids, block_size, hash) + req0 = make_request("0", common_token_ids, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... 
| - req1 = make_request("1", common_token_ids * 2, block_size, hash) + req1 = make_request("1", common_token_ids * 2, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 @@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | - req2 = make_request("2", [7] * block_size * 2, block_size, hash) + req2 = make_request("2", [7] * block_size * 2, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # but it cannot be allocated due to insufficient free blocks (2). # In this case, the ref_cnt of the computed blocks should not be changed. assert manager.block_pool.free_block_queue.num_free_blocks == 5 - req3 = make_request("3", common_token_ids * 3, block_size, hash) + req3 = make_request("3", common_token_ids * 3, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 @@ -1069,13 +1094,13 @@ def test_reset_prefix_cache(): full_block_token_ids = [i for i in range(3) for _ in range(16)] unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids, block_size, hash) + req0 = make_request("0", all_token_ids, block_size, sha256) blocks = manager.allocate_slots(req0, 55) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req1 = make_request("1", all_token_ids, block_size, hash) + req1 = make_request("1", all_token_ids, block_size, sha256) computed_blocks, _ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 @@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled(): assert manager.prefix_cache_stats is None # Call all functions that check whether log_stats is disabled. 
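# [Editor's note] The hunks below replace the BlockHashWithGroupId NamedTuple
# with a make_block_hash_with_group_id helper plus get_block_hash /
# get_group_id accessors. A sketch of one plausible encoding (assumed, not
# necessarily vLLM's): with BlockHash a plain bytes digest, the group id
# packs on as a fixed-width suffix and the accessors slice it back apart.
def make_block_hash_with_group_id_sketch(block_hash: bytes,
                                         group_id: int) -> bytes:
    return block_hash + group_id.to_bytes(4, "little")

def get_block_hash_sketch(key: bytes) -> bytes:
    return key[:-4]

def get_group_id_sketch(key: bytes) -> int:
    return int.from_bytes(key[-4:], "little")

key = make_block_hash_with_group_id_sketch(b"10", 1000)
assert get_block_hash_sketch(key) == b"10"
assert get_group_id_sketch(key) == 1000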
- req = make_request("0", list(range(16)), block_size, hash) + req = make_request("0", list(range(16)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled(): def test_maybe_evict_cached_block(): pool = BlockPool(num_gpu_blocks=4, enable_caching=True) - block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10, - token_ids=(100, )), - group_id=1000) - block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20, - token_ids=(200, )), - group_id=2000) - block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30, - token_ids=(300, )), - group_id=3000) + block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000) + block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000) + block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000) block_hashes = [ block_hash0, block_hash1, @@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int): ) num_tokens = block_size * blocks_to_cache - req0 = make_request("0", list(range(num_tokens)), block_size, hash) + req0 = make_request("0", list(range(num_tokens)), block_size, sha256) _ = manager.allocate_slots(req0, num_tokens) events = manager.take_events() @@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int): # Should see block_to_cache number of removed block events and a new block # stored event manager.free(req0) - req1 = make_request("1", list(range(num_tokens)), block_size, hash) + req1 = make_request("1", list(range(num_tokens)), block_size, sha256) _ = manager.allocate_slots(req1, num_tokens) events = manager.take_events() @@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block(): # Request with 3 full blocks (48 tokens) token_ids = [0] * (3 * block_size) - req = make_request("divisible_request", token_ids, block_size, hash) + req = make_request("divisible_request", token_ids, block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block(): manager.free(req) # New request with same tokens + Eagle enabled - req_eagle = make_request("eagle_divisible", token_ids, block_size, hash) + req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Should retain 1 block: @@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks(): ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids, block_size, hash) + req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks(): manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids, block_size, hash) + req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window(): # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids, block_size, hash) + req = make_request("partial_block_test", token_ids, 
block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window(): manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids, block_size, hash) + req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window(): assert manager.block_pool.get_cached_block( block_hash_first_block, kv_cache_group_ids=[0]) is not None manager.block_pool.cached_block_hash_to_block.pop( - BlockHashWithGroupId(block_hash_first_block, 0)) + make_block_hash_with_group_id(block_hash_first_block, 0)) # New request req_after_evict = make_request("partial_eagle_after_evict", token_ids, - block_size, hash) + block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index 7dcebba491fab..b70850a9bcff9 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -6,8 +6,8 @@ import random import torch from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock) +from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, + make_block_hash_with_group_id) from vllm.v1.core.single_type_kv_cache_manager import ( ChunkedLocalAttentionManager, SlidingWindowManager) from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, @@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix(): def run_one_case(block_is_cached, tail_token, expect_length): block_hash_list = [ - BlockHash(i, ()) for i in range(len(block_is_cached)) + BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] block_pool.cached_block_hash_to_block.clear() @@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix(): for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[BlockHashWithGroupId( - block_hash, 0)] = { + block_pool.cached_block_hash_to_block[ + make_block_hash_with_group_id(block_hash, 0)] = { i: block_pool.blocks[i + 10], } @@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix(): def run_one_case(block_is_cached, expect_length): block_hash_list = [ - BlockHash(i, ()) for i in range(len(block_is_cached)) + BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] block_pool.cached_block_hash_to_block.clear() @@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix(): for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[BlockHashWithGroupId( - block_hash, 0)] = { + block_pool.cached_block_hash_to_block[ + make_block_hash_with_group_id(block_hash, 0)] = { i: block_pool.blocks[i + 10], } diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index e392c2c336e9b..d343141cdf4cb 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, 
KVTransferConfig, ModelConfig, from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams +from vllm.utils import sha256 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, init_none_hash) from vllm.v1.core.sched.async_scheduler import AsyncScheduler @@ -130,10 +131,10 @@ def create_requests( ) -> list[Request]: global _none_hash_initialized if not _none_hash_initialized: - init_none_hash(hash) + init_none_hash(sha256) _none_hash_initialized = True - block_hasher = get_request_block_hasher(block_size, hash) + block_hasher = get_request_block_hasher(block_size, sha256) sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens, stop_token_ids=stop_token_ids, diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index 81655e4175006..25e01806f4956 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -62,6 +62,16 @@ backend_configs = { "cudagraph_mode": "FULL_AND_PIECEWISE", }, specific_gpu_arch=(9, 0)), + # FlashAttention MLA on Hopper + "FlashAttentionMLA": + BackendConfig(name="FlashAttentionMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + }, + comp_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + specific_gpu_arch=(9, 0)), # FA2 "FA2": BackendConfig(name="FA2", diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index cd1d34fc6c3ec..bf90f50b10828 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -117,45 +117,38 @@ def test_ngram_correctness( print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") - # Heuristic: expect at least 70% of the prompts to match exactly + # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. 
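# [Editor's note] For reference, the acceptance heuristic relaxed just below,
# written out as a sketch (assumes the paired ref/spec output text lists built
# earlier in this test): the bound moves from strictly more than 70% exact
# matches to at least 66%.
def passes_heuristic(ref_texts: list[str], spec_texts: list[str]) -> bool:
    matches = sum(r == s for r, s in zip(ref_texts, spec_texts))
    return matches >= int(0.66 * len(ref_texts))

assert passes_heuristic(["a", "b", "c"], ["a", "b", "x"])  # 2/3 passes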
- assert matches > int(0.7 * len(ref_outputs)) + assert matches >= int(0.66 * len(ref_outputs)) del spec_llm torch.cuda.empty_cache() cleanup_dist_env_and_memory() -@pytest.mark.parametrize( - ["model_setup", "mm_enabled"], - [ - # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 - # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), - (("eagle", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), - (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - False, - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - True, - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - (("eagle", "eagle618/deepseek-v3-random", - "eagle618/eagle-deepseek-v3-random", 1), False), - ], - ids=[ - # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 - # "qwen3_eagle3", - "llama3_eagle", - "llama3_eagle3", - "llama4_eagle", - "llama4_eagle_mm", - "deepseek_eagle" - ]) +@pytest.mark.parametrize(["model_setup", "mm_enabled"], [ + (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), + (("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), + (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + False, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + True, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + (("eagle", "eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", 1), False), +], + ids=[ + "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", + "llama4_eagle", "llama4_eagle_mm", + "deepseek_eagle" + ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) def test_eagle_correctness( @@ -169,7 +162,7 @@ def test_eagle_correctness( # TODO: Fix this flaky test pytest.skip( "TREE_ATTN is flaky in the test disable for now until it can be " - "reolved (see https://github.com/vllm-project/vllm/issues/22922)") + "resolved (see https://github.com/vllm-project/vllm/issues/22922)") # Generate test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index f70a3ce147ff2..23ec3673b10b4 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -36,18 +36,19 @@ def test_prefix_caching_from_cli(): assert vllm_config.cache_config.enable_prefix_caching # default hash algorithm is "builtin" - assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin" + assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256" + + # set hash algorithm to sha256_cbor + args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"]) + vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() + assert 
vllm_config.cache_config.prefix_caching_hash_algo == \ + "sha256_cbor" # set hash algorithm to sha256 args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"]) vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256" - # set hash algorithm to builtin - args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"]) - vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() - assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin" - # an invalid hash algorithm raises an error parser.exit_on_error = False with pytest.raises(ArgumentError): diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 98265c6349578..17b136aa42731 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -257,9 +257,13 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): def execute_model( self, scheduler_output, + non_block=False, ) -> Future[ModelRunnerOutput]: """Make execute_model non-blocking.""" + # DummyExecutor used only for testing async case. + assert non_block + def _execute(): output = self.collective_rpc("execute_model", args=(scheduler_output, )) diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py index 970a59eca8ece..955c74d262a09 100644 --- a/tests/v1/engine/test_processor_multi_modal_uuids.py +++ b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -152,8 +152,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through( *, tokenization_kwargs=None, lora_request=None, - mm_hash_overrides=None): - captured["mm_hash_overrides"] = mm_hash_overrides + mm_uuids=None): + captured["mm_uuids"] = mm_uuids # Minimal processed inputs for decoder-only flow return {"type": "token", "prompt_token_ids": [1]} @@ -180,7 +180,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through( params=SamplingParams(), ) - assert captured["mm_hash_overrides"] == mm_uuids + assert captured["mm_uuids"] == mm_uuids def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): @@ -196,8 +196,8 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): *, tokenization_kwargs=None, lora_request=None, - mm_hash_overrides=None): - captured["mm_hash_overrides"] = mm_hash_overrides + mm_uuids=None): + captured["mm_uuids"] = mm_uuids return {"type": "token", "prompt_token_ids": [1]} monkeypatch.setattr(processor.input_preprocessor, @@ -223,7 +223,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): ) # Expect request-id-based overrides are passed through - assert captured["mm_hash_overrides"] == { + assert captured["mm_uuids"] == { "image": [f"{request_id}-image-0", f"{request_id}-image-1"], "video": [f"{request_id}-video-0"], } diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index c10b1abb2b3b7..126d8ce8c8e00 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -46,12 +46,12 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), ("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None), - ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), - ("mistralai/Ministral-8B-Instruct-2410", "outlines", 
"mistral", None), + #FIXME: This tests are flaky on CI thus disabled. Tracking in Issue #24402 + # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), + # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), + #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", NGRAM_SPEC_CONFIG), - #FIXME: This test is flaky on CI thus disabled - #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", NGRAM_SPEC_CONFIG), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG), diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 3a65583fab8d3..3114d7639f045 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -686,7 +686,7 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): async def test_completion_with_empty_prompt_embeds( client: openai.AsyncOpenAI) -> None: """Test completion with empty prompt embeds.""" - payload: dict[str, list] = {"prompt_embeds": []} + payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []} headers: dict[str, str] = {"Content-Type": "application/json"} # base_url = http://localhost:8000/v1/completions response = requests.post(f"{client.base_url}completions", diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 3f068d5e8c7eb..0cae1c7bc0518 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa SharedStorageConnector) +from vllm.utils import sha256 from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, init_none_hash) @@ -127,11 +128,11 @@ def create_request(request_id: int, use_all_1s_for_prompt_tokens: bool = False, num_remote_blocks: int = 3, block_size: int = 16, - hash_fn: Callable = hash) -> Request: + hash_fn: Callable = sha256) -> Request: """Make dummy request for testing.""" global _none_hash_initialized if not _none_hash_initialized: - init_none_hash(hash) + init_none_hash(hash_fn) _none_hash_initialized = True kv_transfer_params: Optional[dict[str, Any]] = None diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index e835c029634ce..570e330208a39 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -430,7 +430,7 @@ def test_zero_logprobs(vllm_model, example_prompts, def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): - """Engine should return all vocabulary logprobs + """Engine should return all vocabulary logprobs and prompt logprobs Args: example_prompts: list of example prompts (test fixture) @@ -444,16 +444,24 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): # 2 other llms alive during whole session gpu_memory_utilization=0.15, max_model_len=256) + sampling_params_logprobs_all = SamplingParams(max_tokens=5, - logprobs=-1) + logprobs=-1, + prompt_logprobs=-1) results_logprobs_all = runner.llm.generate( example_prompts, sampling_params=sampling_params_logprobs_all) vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size() + for i in range(len(results_logprobs_all)): logprobs = 
results_logprobs_all[i].outputs[0].logprobs + prompt_logprobs = results_logprobs_all[i].prompt_logprobs assert logprobs is not None for logprob in logprobs: assert len(logprob) == vocab_size + assert prompt_logprobs is not None + assert prompt_logprobs[0] is None + for prompt_logprob in prompt_logprobs[1:]: + assert len(prompt_logprob) == vocab_size @pytest.mark.parametrize("logprobs_mode", list(LogprobsMode)) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 46e3a611c6d26..ddedc61aae296 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -12,9 +12,10 @@ from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata, create_standard_kv_cache_spec, get_attention_backend) -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) +from vllm.config.load import LoadConfig from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform from vllm.v1.spec_decode.eagle import EagleProposer diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 1f16e92f657e0..efa604dd6b5a8 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -10,7 +10,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine UNSUPPORTED_MODELS_V1 = [ - "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder ] diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py index bfba3af57f715..1bc8dff317a74 100644 --- a/tests/v1/tpu/test_pallas.py +++ b/tests/v1/tpu/test_pallas.py @@ -33,10 +33,12 @@ def test_ragged_paged_attention(): ) class FakeAttentionLayer: + _q_scale_float: float _k_scale_float: float _v_scale_float: float layer = FakeAttentionLayer() + layer._q_scale_float = 1.0 layer._k_scale_float = 1.0 layer._v_scale_float = 1.0 diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 941aa0a77692c..c719e44acc9c2 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -64,9 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - mm_kwargs=[], - mm_hashes=[], - mm_positions=[], + mm_features=[], sampling_params=SamplingParams(), pooling_params=PoolingParams(), block_ids=([0], ), # block_ids should be tuple[list[int]] diff --git a/tests/v1/tracing/test_tracing.py b/tests/v1/tracing/test_tracing.py new file mode 100644 index 0000000000000..da8655f95e195 --- /dev/null +++ b/tests/v1/tracing/test_tracing.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa +# type: ignore +from __future__ import annotations + +import threading +from collections.abc import Iterable +from concurrent import futures +from typing import Callable, Generator, Literal + +import grpc +import pytest +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( + ExportTraceServiceResponse) +from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import ( + TraceServiceServicer, add_TraceServiceServicer_to_server) +from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue +from opentelemetry.sdk.environment_variables import ( + 
OTEL_EXPORTER_OTLP_TRACES_INSECURE) + +from vllm import LLM, SamplingParams +from vllm.tracing import SpanAttributes + +FAKE_TRACE_SERVER_ADDRESS = "localhost:4317" + +FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value', + 'array_value'] + + +def decode_value(value: AnyValue): + field_decoders: dict[FieldName, Callable] = { + "bool_value": (lambda v: v.bool_value), + "string_value": (lambda v: v.string_value), + "int_value": (lambda v: v.int_value), + "double_value": (lambda v: v.double_value), + "array_value": + (lambda v: [decode_value(item) for item in v.array_value.values]), + } + for field, decoder in field_decoders.items(): + if value.HasField(field): + return decoder(value) + raise ValueError(f"Couldn't decode value: {value}") + + +def decode_attributes(attributes: Iterable[KeyValue]): + return {kv.key: decode_value(kv.value) for kv in attributes} + + +class FakeTraceService(TraceServiceServicer): + + def __init__(self): + self.request = None + self.evt = threading.Event() + + def Export(self, request, context): + self.request = request + self.evt.set() + return ExportTraceServiceResponse() + + +@pytest.fixture +def trace_service() -> Generator[FakeTraceService, None, None]: + """Fixture to set up a fake gRPC trace service""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) + service = FakeTraceService() + add_TraceServiceServicer_to_server(service, server) + server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS) + server.start() + + yield service + + server.stop(None) + + +def test_traces( + monkeypatch: pytest.MonkeyPatch, + trace_service: FakeTraceService, +): + with monkeypatch.context() as m: + m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true") + m.setenv("VLLM_USE_V1", "1") + sampling_params = SamplingParams( + temperature=0.01, + top_p=0.1, + max_tokens=256, + ) + model = "facebook/opt-125m" + llm = LLM(model=model, + otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, + gpu_memory_utilization=0.3, + disable_log_stats=False) + prompts = ["This is a short prompt"] + outputs = llm.generate(prompts, sampling_params=sampling_params) + print(f"test_traces outputs is : {outputs}") + + timeout = 10 + if not trace_service.evt.wait(timeout): + raise TimeoutError( + f"The fake trace service didn't receive a trace within " + f"the {timeout} seconds timeout") + + request = trace_service.request + assert len(request.resource_spans) == 1, ( + f"Expected 1 resource span, " + f"but got {len(request.resource_spans)}") + assert len(request.resource_spans[0].scope_spans) == 1, ( + f"Expected 1 scope span, " + f"but got {len(request.resource_spans[0].scope_spans)}") + assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( + f"Expected 1 span, " + f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") + + attributes = decode_attributes( + request.resource_spans[0].scope_spans[0].spans[0].attributes) + # assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model + assert attributes.get( + SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id + assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE + ) == sampling_params.temperature + assert attributes.get( + SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p + assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS + ) == sampling_params.max_tokens + assert attributes.get( + SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n + assert attributes.get( + SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( + outputs[0].prompt_token_ids) + 
completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) + assert attributes.get( + SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens + + assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) > 0 + assert attributes.get( + SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0 + assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) > 0 diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 7031859078264..38f543c784866 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -203,9 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int): prompt_token_ids=prompt_token_ids, sampling_params=_create_sampling_params(), pooling_params=None, - mm_kwargs=[], - mm_positions=[], - mm_hashes=[], + mm_features=[], block_ids=([], ), generator=None, num_computed_tokens=len(output_token_ids), diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 6d99029e404ef..5ebc00d573030 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import random - import numpy as np import pytest import torch @@ -120,9 +118,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - mm_kwargs=[], - mm_hashes=[], - mm_positions=[], + mm_features=[], sampling_params=SamplingParams(), pooling_params=None, block_ids=([0], ), @@ -409,29 +405,30 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): model_runner.model_config.get_head_size() ] # TODO mla test - default_stride = list(range(5)) + default_stride = tuple(range(5)) # Permutation that gets you back to expected kv shape - rnd_stride = tuple(random.sample(default_stride, len(default_stride))) + for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)): - def rnd_stride_order(): - return rnd_stride + def rnd_stride_order(test_stride=test_stride): + return test_stride - # Patch the attention backend class and re-trigger the KV cache creation. 
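# [Editor's note] Sketch of what a KV-cache "stride order" means in the
# rewritten test below, in plain torch with hypothetical shapes: the backend
# reports a dim permutation, the cache is allocated in that physical order,
# and permuting back to the logical shape leaves a non-contiguous view for
# any non-identity order, which is exactly what the assertions check.
import torch

logical_shape = (2, 4, 16, 8, 64)
for order in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)):
    physical = torch.empty([logical_shape[i] for i in order])
    inverse = [order.index(i) for i in range(len(order))]
    logical = physical.permute(inverse)
    assert tuple(logical.shape) == logical_shape
    assert logical.is_contiguous() == (order == (0, 1, 2, 3, 4))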
- for attn_group in model_runner._attn_group_iterator(): - attn_backend = attn_group.backend - monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", - rnd_stride_order) + # Patch the attention backend class and re-trigger the KV cache creation + for attn_group in model_runner._attn_group_iterator(): + attn_backend = attn_group.backend + monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", + rnd_stride_order) - model_runner.attn_groups = [] - model_runner.initialize_kv_cache(model_runner.kv_cache_config) + model_runner.attn_groups = [] + model_runner.kv_caches = [] + model_runner.initialize_kv_cache(model_runner.kv_cache_config) - # Shape is unchanged, but layout may differ - kv_cache_shape = model_runner.kv_caches[0].shape - assert list(kv_cache_shape) == expected_kv_cache_shape - if default_stride == rnd_stride: - assert all(kv.is_contiguous() for kv in model_runner.kv_caches) - else: - assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) + # Shape is unchanged, but layout may differ + kv_cache_shape = model_runner.kv_caches[0].shape + assert list(kv_cache_shape) == expected_kv_cache_shape + if default_stride == test_stride: + assert all(kv.is_contiguous() for kv in model_runner.kv_caches) + else: + assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) def test_update_config(model_runner): diff --git a/tools/check_pickle_imports.py b/tools/check_pickle_imports.py index ad0ae45d1d465..fe717121db40d 100644 --- a/tools/check_pickle_imports.py +++ b/tools/check_pickle_imports.py @@ -39,6 +39,7 @@ ALLOWED_FILES = set([ 'vllm/engine/multiprocessing/client.py', 'vllm/distributed/device_communicators/all_reduce_utils.py', 'vllm/distributed/device_communicators/shm_broadcast.py', + 'vllm/distributed/device_communicators/shm_object_storage.py', 'vllm/engine/multiprocessing/engine.py', 'benchmarks/kernels/graph_machete_bench.py', 'benchmarks/kernels/benchmark_lora.py', diff --git a/tools/mypy.sh b/tools/mypy.sh index 781d8fc02884b..63e3b9a916634 100755 --- a/tools/mypy.sh +++ b/tools/mypy.sh @@ -29,7 +29,7 @@ run_mypy vllm/engine run_mypy vllm/executor run_mypy vllm/inputs run_mypy vllm/lora -run_mypy vllm/model_executor +run_mypy --exclude 'vllm/model_executor/layers/fla/ops' vllm/model_executor run_mypy vllm/plugins run_mypy vllm/worker run_mypy vllm/v1 diff --git a/use_existing_torch.py b/use_existing_torch.py index a9f79e16981c4..b5aafdde16c28 100644 --- a/use_existing_torch.py +++ b/use_existing_torch.py @@ -1,21 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import glob - -requires_files = glob.glob('requirements/*.txt') -requires_files += ["pyproject.toml"] -for file in requires_files: - print(f">>> cleaning {file}") - with open(file) as f: - lines = f.readlines() - if "torch" in "".join(lines).lower(): - print("removed:") - with open(file, 'w') as f: - for line in lines: - if 'torch' not in line.lower(): - f.write(line) - else: - print(line.strip()) - print(f"<<< done cleaning {file}") - print() +print("vLLM is now using 'uv' to disable build isolation for 'torch'.") +print("Please instead install vLLM with 'uv pip install -e .' 
(must use 'uv')") diff --git a/vllm/__init__.py b/vllm/__init__.py index 7b90fd3a241bd..3a5c1b1ce0daf 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -14,6 +14,8 @@ import typing import vllm.env_override # noqa: F401 MODULE_ATTRS = { + "bc_linter_skip": "._bc_linter:bc_linter_skip", + "bc_linter_include": "._bc_linter:bc_linter_include", "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs", "EngineArgs": ".engine.arg_utils:EngineArgs", "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine", @@ -54,6 +56,8 @@ if typing.TYPE_CHECKING: ScoringRequestOutput) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams + + from ._bc_linter import bc_linter_include, bc_linter_skip else: def __getattr__(name: str) -> typing.Any: @@ -70,6 +74,8 @@ else: __all__ = [ "__version__", + "bc_linter_skip", + "bc_linter_include", "__version_tuple__", "LLM", "ModelRegistry", diff --git a/vllm/_bc_linter.py b/vllm/_bc_linter.py new file mode 100644 index 0000000000000..52a95dbee1866 --- /dev/null +++ b/vllm/_bc_linter.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# vllm/_bc_linter.py +from __future__ import annotations + +from typing import Any, Callable, TypeVar, overload + +T = TypeVar("T") + + +@overload +def bc_linter_skip(obj: T) -> T: + ... + + +@overload +def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: + ... + + +def bc_linter_skip(obj: Any = None, *, reason: str | None = None): + """ + No-op decorator to mark symbols/files for BC-linter suppression. + + Usage: + @bc_linter_skip + def legacy_api(...): ... + """ + + def _wrap(x: T) -> T: + return x + + return _wrap if obj is None else obj + + +@overload +def bc_linter_include(obj: T) -> T: + ... + + +@overload +def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: + ... + + +def bc_linter_include(obj: Any = None, *, reason: str | None = None): + """ + Usage: + @bc_linter_include + def public_api(...): ... 
+ """ + + def _wrap(x: T) -> T: + return x + + return _wrap if obj is None else obj + + +__all__ = ["bc_linter_skip", "bc_linter_include"] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 545f4cb48bf47..93b4f87ed260c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -257,16 +257,6 @@ def rotary_embedding( cos_sin_cache, is_neox) -def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor], head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor) -> None: - torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox, rot_dim, - cos_sin_cache_offsets) - - # layer norm ops def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float) -> None: @@ -280,6 +270,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def poly_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + bias: torch.Tensor, epsilon: float) -> None: + # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input + input_contiguous = input.contiguous() + torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon) + + def apply_repetition_penalties_torch( logits: torch.Tensor, prompt_mask: torch.Tensor, output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: @@ -710,6 +707,7 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) + def cutlass_sparse_compress(a: torch.Tensor) \ -> tuple[torch.Tensor, torch.Tensor]: """ @@ -1833,13 +1831,13 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, return out -def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, - q_pe: torch.Tensor, +def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor, + q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, seq_lens: torch.Tensor, page_table: torch.Tensor, workspace: torch.Tensor, scale: float, num_kv_splits: int) -> torch.Tensor: - torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe, + torch.ops._C.sm100_cutlass_mla_decode(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, scale, num_kv_splits) diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index c2868c040aa16..59b0aed321502 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -148,17 +148,6 @@ class ipex_ops: head_size, cos_sin_cache, is_neox, rot_dim) - @staticmethod - def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor) -> None: - ipex.llm.functional.rotary_embedding_batched(positions, query, key, - head_size, cos_sin_cache, - is_neox, rot_dim, - cos_sin_cache_offsets) - @staticmethod def rms_norm(input: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 8ab0e9760be87..983e9114cccfb 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -110,22 +110,23 @@ class VideoAsset: def filename(self) -> str: return self._NAME_TO_FILE[self.name] + @property + def video_path(self) -> str: + return download_video_asset(self.filename) + 
@property def pil_images(self) -> list[Image.Image]: - video_path = download_video_asset(self.filename) - ret = video_to_pil_images_list(video_path, self.num_frames) + ret = video_to_pil_images_list(self.video_path, self.num_frames) return ret @property def np_ndarrays(self) -> npt.NDArray: - video_path = download_video_asset(self.filename) - ret = video_to_ndarrays(video_path, self.num_frames) + ret = video_to_ndarrays(self.video_path, self.num_frames) return ret @property def metadata(self) -> dict[str, Any]: - video_path = download_video_asset(self.filename) - ret = video_get_metadata(video_path) + ret = video_get_metadata(self.video_path) return ret def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: @@ -134,5 +135,4 @@ class VideoAsset: See also: examples/offline_inference/qwen2_5_omni/only_thinker.py """ - video_path = download_video_asset(self.filename) - return librosa.load(video_path, sr=sampling_rate)[0] + return librosa.load(self.video_path, sr=sampling_rate)[0] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 0217bff6adafa..75bcdc4bbcf0d 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -240,6 +240,7 @@ class AttentionLayer(Protocol): _q_scale: torch.Tensor _k_scale: torch.Tensor _v_scale: torch.Tensor + _q_scale_float: float _k_scale_float: float _v_scale_float: float _prob_scale: torch.Tensor diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index caa02530d2fd6..a7d0e3afb517f 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -734,6 +734,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl): window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, ) assert prefill_output.shape == output[: num_prefill_tokens].shape @@ -755,6 +756,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl): window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, ).squeeze(1) except Exception as e: logger.error("Error in PagedAttention.forward_decode: %s", @@ -787,6 +789,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl): window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, ).squeeze(1) return output diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index d8cb208c4f2ea..78c768f92d3c2 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -901,7 +901,7 @@ def _get_query_key_seq_metadata( attn_metadata.encoder_seq_start_loc, attn_metadata.max_encoder_seq_len) elif attn_type == AttentionType.ENCODER: - # For encoder attention both the query and the key are same i.e the + # For encoder attention both the query and the key are same i.e. the # encoder sequence. 
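# [Editor's note] Aside on the _q_scale_float additions above: AttentionLayer
# is a typing.Protocol, so the pallas test's FakeAttentionLayer satisfies it
# purely structurally. It only declares and assigns the expected attributes,
# with no inheritance. Minimal runnable sketch:
from typing import Protocol

class ScaleCarrier(Protocol):
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float

class FakeLayer:
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float

layer = FakeLayer()
layer._q_scale_float = 1.0
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0

def combined_scale(layer: ScaleCarrier) -> float:
    return layer._q_scale_float * layer._k_scale_float * layer._v_scale_float

assert combined_scale(layer) == 1.0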
return (attn_metadata.encoder_seq_start_loc, attn_metadata.max_encoder_seq_len, diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index f23c096952ce0..411eb5413f53c 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -17,6 +17,7 @@ from vllm.attention.backends.mla.common import (MLACommonBackend, from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) +from vllm.platforms.cuda import CudaPlatform class FlashMLABackend(MLACommonBackend): @@ -181,6 +182,16 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): assert is_flashmla_supported(), \ "FlashMLA is not supported on this device" + # disallow FlashMLA on NVIDIA Blackwell (SM 10.0+) GPUs + # context: + # https://github.com/deepseek-ai/FlashMLA/issues/83 + # https://github.com/vllm-project/vllm/issues/24513 + if CudaPlatform.has_device_capability(100): + raise NotImplementedError( + "FlashMLA is temporarily disabled on Blackwell (SM 10.0). " + "Please use CUTLASS_MLA or TRITON_MLA instead. " + "Example: `export VLLM_ATTENTION_BACKEND=CUTLASS_MLA`") + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 3b9037521168e..789393eb39a73 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1052,7 +1052,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): return layer.weight # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 237802afccde9..44cb2c7c6b642 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -23,6 +23,7 @@ from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import _Backend, current_platform from vllm.utils import direct_register_custom_op @@ -55,6 +56,14 @@ def check_xformers_availability(): return USE_XFORMERS_OPS +def check_upstream_fa_availability(dtype: torch.dtype): + if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda( + ) and current_platform.has_device_capability(80): + from transformers.utils import is_flash_attn_2_available + return is_flash_attn_2_available() + return False + + class Attention(nn.Module, AttentionLayerBase): """Attention layer. @@ -349,29 +358,55 @@ class MultiHeadAttention(nn.Module): f"divisible by num_kv_heads ({self.num_kv_heads})" self.num_queries_per_kv = self.num_heads // self.num_kv_heads + # During model initialization, the default dtype is set as the model + # weight and activation dtype. 
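# [Editor's note] Sketch of the cu_seqlens construction used by the varlen
# flash-attention path added further below: for a batch of equal-length
# sequences, the varlen entry points take cumulative sequence boundaries
# [0, L, 2L, ..., B*L] plus tensors flattened over the batch dim.
import torch

bsz, q_len = 3, 5
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                            step=q_len,
                            dtype=torch.int32)
assert cu_seqlens_q.tolist() == [0, 5, 10, 15]
# a query of shape (bsz, q_len, heads, dim) is then passed as
# query.flatten(0, 1) -> (bsz * q_len, heads, dim)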
dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, - dtype, - kv_cache_dtype=None, - block_size=16, - is_attention_free=False) - backend = backend_name_to_enum(attn_backend.get_name()) + + # Determine the attention backend + backend = get_vit_attn_backend(head_size=head_size, dtype=dtype) + + # Some auto-selected backends can be upgraded + # to upstream flash attention if available. + # If vllm native fa is selected, we use it directly. + use_upstream_fa = False + if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + dtype): + backend = _Backend.FLASH_ATTN + use_upstream_fa = True + if current_platform.is_rocm(): # currently, only torch_sdpa is supported on rocm self.attn_backend = _Backend.TORCH_SDPA else: - if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1, - _Backend.FLEX_ATTENTION): - backend = _Backend.XFORMERS self.attn_backend = backend if backend in { - _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 + _Backend.TORCH_SDPA, + _Backend.TORCH_SDPA_VLLM_V1, + _Backend.XFORMERS, + _Backend.PALLAS_VLLM_V1, + _Backend.ROCM_AITER_FA, + _Backend.FLASH_ATTN, + _Backend.FLASH_ATTN_VLLM_V1, } else _Backend.TORCH_SDPA if (self.attn_backend == _Backend.XFORMERS and not check_xformers_availability()): self.attn_backend = _Backend.TORCH_SDPA + if self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1 + }: + if use_upstream_fa: + from flash_attn import flash_attn_varlen_func + self._flash_attn_varlen_func = flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func + self._flash_attn_varlen_func = flash_attn_varlen_func + + logger.info_once( + f"MultiHeadAttention attn_backend: {self.attn_backend}, " + f"use_upstream_fa: {use_upstream_fa}") + def forward( self, query: torch.Tensor, @@ -392,14 +427,39 @@ class MultiHeadAttention(nn.Module): key = torch.repeat_interleave(key, num_repeat, dim=2) value = torch.repeat_interleave(value, num_repeat, dim=2) - if self.attn_backend == _Backend.XFORMERS: + if self.attn_backend in { + _Backend.FLASH_ATTN, + _Backend.FLASH_ATTN_VLLM_V1, + }: + + cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, + step=q_len, + dtype=torch.int32, + device=query.device) + cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, + step=kv_len, + dtype=torch.int32, + device=key.device) + + out = self._flash_attn_varlen_func( + query.flatten(0, 1), + key.flatten(0, 1), + value.flatten(0, 1), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=q_len, + max_seqlen_k=kv_len, + softmax_scale=self.scale, + ) + elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward(query, key, value, scale=self.scale) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif (self.attn_backend == _Backend.TORCH_SDPA + or self.attn_backend == _Backend.TORCH_SDPA_VLLM_V1): query, key, value = (x.transpose(1, 2) for x in (query, key, value)) out = F.scaled_dot_product_attention(query, @@ -413,6 +473,19 @@ class MultiHeadAttention(nn.Module): from torch_xla.experimental.custom_kernel import flash_attention out = flash_attention(query, key, value, sm_scale=self.scale) out = out.transpose(1, 2) + elif self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + + # ROCm Flash Attention expects (batch, seq, heads, head_dim) + out = flash_attn_varlen_func(query, + key, + value, + softmax_scale=self.scale) + else: + # ViT attention hasn't supported this backend yet + raise 
NotImplementedError( + f"ViT attention hasn't supported {self.attn_backend} " + f"backend yet.") return out.reshape(bsz, q_len, -1) diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py new file mode 100644 index 0000000000000..c24fa4e15f679 --- /dev/null +++ b/vllm/attention/layers/cross_attention.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from copy import copy +from typing import Optional + +import numpy as np +import torch + +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig, VllmConfig +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.utils import cdiv +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + subclass_attention_backend) +from vllm.v1.kv_cache_interface import CrossAttentionSpec + +logger = init_logger(__name__) + + +def _get_max_encoder_len(vllm_config: VllmConfig) -> int: + return MULTIMODAL_REGISTRY.get_encdec_max_encoder_len( + vllm_config.model_config) + + +def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray, + block_table_tensor: torch.Tensor, + kv_cache_spec: CrossAttentionSpec, + device: torch.device) -> torch.Tensor: + """Get cross-attention slot mappings.""" + + block_size = kv_cache_spec.block_size + slot_mappings = [] + + # Find indices with non-zero encoder sequence lengths + # The majority of parallel requests will be running the + # decoder, so this list should be relatively small. + active_indices = np.nonzero(encoder_seq_lens)[0] + + for req_index in active_indices: + encoder_seq_len = encoder_seq_lens[req_index].item() + + # Calculate the number of blocks needed for this request + num_blocks_needed = cdiv(encoder_seq_len, block_size) + + # Get the block IDs for this request from the tensor + req_block_ids = block_table_tensor[req_index] + + # Get only the blocks we need (first num_blocks_needed blocks) + needed_block_ids = req_block_ids[:num_blocks_needed] + + # All needed blocks are allocated + i_values = torch.arange(encoder_seq_len, + dtype=torch.int64, + device=device) + block_indices = i_values // block_size + block_offsets = i_values % block_size + block_numbers = needed_block_ids[block_indices] + slot_mapping = block_numbers * block_size + block_offsets + + slot_mappings.append(slot_mapping) + + if slot_mappings: + return torch.cat(slot_mappings) + else: + return torch.empty(0, dtype=torch.int64, device=device) + + +@functools.lru_cache +def create_cross_attention_backend( + underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: + prefix = "CrossAttention_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class CrossAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + new_metadata = copy(common_attn_metadata) + new_metadata.causal = False + max_encoder_len = _get_max_encoder_len(self.vllm_config) + new_metadata.max_seq_len = max_encoder_len + + new_metadata.seq_lens = torch.full( + (new_metadata.num_reqs, ), + max_encoder_len, + dtype=torch.int32, + device=self.device, + ) + new_metadata.seq_lens_cpu = torch.full( + (new_metadata.num_reqs, ), + 
max_encoder_len, + dtype=torch.int32, + device="cpu", + ) + new_metadata.slot_mapping = _get_cross_slot_mapping( + new_metadata.encoder_seq_lens, new_metadata.block_table_tensor, + self.kv_cache_spec, self.device) + return super().build(common_prefix_len, new_metadata, fast_build) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=CrossAttentionBuilder) + + return attn_backend + + +class CrossAttention(Attention): + """ + Cross-attention for encoder-decoder models. + Handles attention between decoder queries and encoder keys/values. + """ + + def __init__(self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs): + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if envs.VLLM_USE_V1: + underlying_attn_backend = get_attn_backend(head_size, dtype, + kv_cache_dtype, + block_size) + + attn_backend = create_cross_attention_backend( + underlying_attn_backend) + else: + # in v0 cross attention is handled inside the backends + attn_backend = None + + if attn_type is not None: + assert attn_type == AttentionType.ENCODER_DECODER, ( + "CrossAttention only supports AttentionType.ENCODER_DECODER") + + super().__init__(num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_DECODER, + **kwargs) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index e5b90a8b27558..bf4b06512a3c1 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -15,6 +15,8 @@ from vllm.triton_utils import tl, triton from .prefix_prefill import context_attention_fwd +float8_info = torch.finfo(current_platform.fp8_dtype()) + @triton.jit def cdiv_fn(x, y): @@ -34,6 +36,7 @@ def kernel_paged_attention_2d( scale, # float32 k_scale, # float32 v_scale, # float32 + out_scale_inv, num_query_heads: tl.constexpr, # int num_queries_per_kv: tl.constexpr, # int num_queries_per_kv_padded: tl.constexpr, # int @@ -60,7 +63,9 @@ def kernel_paged_attention_2d( filter_by_query_len: tl.constexpr, # bool query_start_len_ptr, # [num_seqs+1] USE_SINKS: tl.constexpr, # bool -): + USE_FP8: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max): seq_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -204,6 +209,9 @@ def kernel_paged_attention_2d( # epilogue acc = acc / L[:, None] + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) output_offset = (cur_batch_in_all_start_index * output_stride_0 + query_head_idx * output_stride_1) @@ -234,6 +242,7 @@ def chunked_prefill_paged_decode( alibi_slopes=None, sliding_window=None, sm_scale=None, + output_scale=None, # Optional tensor for sinks sinks=None, ): @@ -266,6 +275,7 @@ def chunked_prefill_paged_decode( sliding_window=sliding_window, sm_scale=sm_scale, skip_decode=True, + fp8_out_scale=output_scale, sinks=sinks, ) @@ -316,7 +326,7 @@ def chunked_prefill_paged_decode( tmp_output = torch.empty( size=(total_num_seq, num_query_heads, max_num_partitions, head_size), - dtype=output.dtype, + dtype=query.dtype, device=output.device, ) exp_sums = torch.empty( @@ -345,6 +355,7 @@ 
def chunked_prefill_paged_decode( kv_cache_dtype=kv_cache_dtype, k_scale=k_scale, v_scale=v_scale, + fp8_out_scale=output_scale, ) else: kernel_paged_attention_2d[( @@ -362,6 +373,8 @@ def chunked_prefill_paged_decode( scale=sm_scale, k_scale=k_scale, v_scale=v_scale, + out_scale_inv=1.0 / + output_scale if output_scale is not None else 1.0, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, num_queries_per_kv_padded=num_queries_per_kv_padded, @@ -388,4 +401,5 @@ def chunked_prefill_paged_decode( filter_by_query_len=True, query_start_len_ptr=query_start_loc, USE_SINKS=sinks is not None, + USE_FP8=output_scale is not None, ) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index a70db89cdb76e..7e5c2b6c62e9b 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -15,6 +15,7 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8 # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) +float8_info = torch.finfo(current_platform.fp8_dtype()) # Here's an example autotuner config for this kernel. This config does provide @@ -43,6 +44,7 @@ def _fwd_kernel(Q, sm_scale, k_scale, v_scale, + out_scale_inv, B_Start_Loc, B_Seqlen, x: tl.constexpr, @@ -82,8 +84,11 @@ def _fwd_kernel(Q, num_unroll_request: tl.constexpr, SKIP_DECODE: tl.constexpr, USE_SINKS: tl.constexpr, + USE_FP8: tl.constexpr, MAX_Q_LEN: tl.constexpr = 0, - MAX_CTX_LEN: tl.constexpr = 0): + MAX_CTX_LEN: tl.constexpr = 0, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -284,6 +289,9 @@ def _fwd_kernel(Q, off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od) out_ptrs = Out + off_o + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) tl.store(out_ptrs, acc, mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) @@ -743,6 +751,7 @@ def context_attention_fwd(q, sliding_window=None, sm_scale=None, skip_decode=False, + fp8_out_scale=None, sinks=None): q_dtype_is_f32 = q.dtype is torch.float32 @@ -793,6 +802,7 @@ def context_attention_fwd(q, if alibi_slopes is not None: assert sinks is None, "Sinks arg is not supported with alibi" + assert fp8_out_scale is None, "FP8 output not supported with alibi" # need to reduce num. 
blocks when using fp32 # due to increased use of GPU shared memory # if q.dtype is torch.float32: @@ -870,6 +880,7 @@ def context_attention_fwd(q, sm_scale, k_scale, v_scale, + 1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0, b_start_loc, b_seq_len, k_cache.shape[4], @@ -905,6 +916,7 @@ def context_attention_fwd(q, BLOCK_DMODEL_PADDED=Lk_padded, SLIDING_WINDOW=sliding_window, SKIP_DECODE=skip_decode, + USE_FP8=fp8_out_scale is not None, BLOCK_M=128, BLOCK_N=64, num_unroll_cache=4, diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 250e9b3890444..d2ad2f7e8d2aa 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -10,9 +10,11 @@ import torch from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton logger = init_logger(__name__) +float8_info = torch.finfo(current_platform.fp8_dtype()) @triton.jit @@ -48,47 +50,51 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, @triton.jit def kernel_unified_attention_2d( - output_ptr, # [num_tokens, num_query_heads, head_size] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - sink_ptr, # [num_query_heads] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - qq_bias_ptr, # [num_query_tokens, num_query_tokens] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - qq_bias_stride_0: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_QQ_BIAS: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - USE_SINKS: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.constexpr, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, - BLOCK_M: tl.constexpr, # int + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + out_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: 
tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -281,6 +287,9 @@ def kernel_unified_attention_2d( # epilogue acc = acc / L[:, None] + if USE_FP8: + acc = acc * tl.load(out_scale) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) output_offset = (query_offset_0[:, None] * output_stride_0 + query_offset_1[:, None] * output_stride_1 + @@ -552,23 +561,27 @@ def kernel_unified_attention_3d( @triton.jit def reduce_segments( - output_ptr, # [num_tokens, num_query_heads, head_size] - segm_output_ptr, - #[num_tokens, num_query_heads, max_num_segments, head_size] - segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] - segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] - seq_lens_ptr, # [num_seqs] - num_seqs, # int - num_query_heads: tl.constexpr, # int - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - block_table_stride: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int, must be power of 2 - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + #[num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + out_scale_inv, # float32 + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): query_token_idx = tl.program_id(0) query_head_idx = tl.program_id(1) @@ -624,6 +637,10 @@ def reduce_segments( # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + # write result output_offset = (query_token_idx * output_stride_0 + query_head_idx * output_stride_1 + @@ -649,6 +666,7 @@ def 
unified_attention( k_descale, v_descale, alibi_slopes=None, + output_scale=None, qq_bias=None, # Optional tensor for sinks sinks=None, @@ -707,6 +725,7 @@ def unified_attention( scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, + out_scale=1 / output_scale if output_scale is not None else 1.0, softcap=softcap, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, @@ -736,6 +755,7 @@ def unified_attention( BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, BLOCK_M=BLOCK_M, + USE_FP8=output_scale is not None, ) else: # for initial version, NUM_SEGMENTS = 16 is chosen as a default @@ -819,6 +839,8 @@ def unified_attention( seq_lens_ptr=seqused_k, num_seqs=num_seqs, num_query_heads=num_query_heads, + out_scale_inv=1 / + output_scale if output_scale is not None else 1.0, output_stride_0=out.stride(0), output_stride_1=out.stride(1), block_table_stride=block_table.stride(0), @@ -828,4 +850,5 @@ def unified_attention( query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + USE_FP8=output_scale is not None, ) diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 882b68ac9e2fd..bf9e87198bcf1 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -198,8 +198,9 @@ class BenchmarkDataset(ABC): @abstractmethod def sample(self, tokenizer: PreTrainedTokenizerBase, - num_requests: int, - request_id_prefix: str = "") -> list[SampleRequest]: + num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -224,6 +225,7 @@ class BenchmarkDataset(ABC): requests: list[SampleRequest], num_requests: int, request_id_prefix: str = "", + no_oversample: bool = False, ) -> None: """ Oversamples the list of requests if its size is less than the desired @@ -236,6 +238,11 @@ class BenchmarkDataset(ABC): request_id_prefix (str) The prefix of the request ids. """ + if no_oversample: + logger.info("Skipping oversampling. " \ + "Total samples: %d.", len(requests)) + return + if len(requests) < num_requests: random.seed(self.random_seed) additional = deepcopy( @@ -405,6 +412,7 @@ class RandomDataset(BenchmarkDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, request_id_prefix: str = "", + no_oversample: bool = False, prefix_len: int = DEFAULT_PREFIX_LEN, range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, @@ -543,7 +551,7 @@ class RandomDataset(BenchmarkDataset): [6880, 6881] -> ['Ġcalls', 'here'] -> [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] To avoid uncontrolled change of the prompt length, - the encoded sequence is truncated before being decode again. + the encoded sequence is truncated before being decoded again. 
""" # Build the inner sequence by sampling sequentially from the vocab inner_seq = ((offset + index + np.arange(input_len)) @@ -832,6 +840,7 @@ class RandomMultiModalDataset(RandomDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, request_id_prefix: str = "", + no_oversample: bool = False, prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN, range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, input_len: int = RandomDataset.DEFAULT_INPUT_LEN, @@ -959,6 +968,7 @@ class ShareGPTDataset(BenchmarkDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: samples: list = [] @@ -1002,7 +1012,10 @@ class ShareGPTDataset(BenchmarkDataset): request_id=request_id_prefix + str(ind), )) ind += 1 - self.maybe_oversample_requests(samples, num_requests, request_id_prefix) + self.maybe_oversample_requests(samples, + num_requests, + request_id_prefix, + no_oversample) return samples @@ -1020,7 +1033,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): default="random", choices=[ "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", - "custom", "prefix_repetition" + "custom", "prefix_repetition", "spec_bench" ], help="Name of the dataset to benchmark on.", ) @@ -1036,6 +1049,12 @@ def add_dataset_parser(parser: FlexibleArgumentParser): help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.", ) + parser.add_argument( + "--no-oversample", + action="store_true", + help="Do not oversample if the dataset has " \ + "fewer samples than num-prompts.", + ) # group for dataset specific arguments custom_group = parser.add_argument_group("custom dataset options") @@ -1053,6 +1072,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "Skip applying chat template to prompt, used only for custom dataset.", ) + spec_bench_group = parser.add_argument_group("spec bench dataset options") + spec_bench_group.add_argument( + "--spec-bench-output-len", + type=int, + default=256, + help= + "Num of output tokens per request, used only for spec bench dataset.", + ) + spec_bench_group.add_argument( + "--spec-bench-category", + type=str, + default=None, + help= + "Category for spec bench dataset. If None, use all categories.", + ) + sonnet_group = parser.add_argument_group("sonnet dataset options") sonnet_group.add_argument( "--sonnet-input-len", @@ -1085,6 +1120,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "from the ShareGPT dataset.", ) + blazedit_group = parser.add_argument_group("blazedit dataset options") + blazedit_group.add_argument( + "--blazedit-min-distance", + type=float, + default=0.0, + help= + "Minimum distance for blazedit dataset. Min: 0, Max: 1.0", + ) + blazedit_group.add_argument( + "--blazedit-max-distance", + type=float, + default=1.0, + help= + "Maximum distance for blazedit dataset. 
Min: 0, Max: 1.0", + ) + random_group = parser.add_argument_group("random dataset options") random_group.add_argument( "--random-input-len", @@ -1278,6 +1329,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser): def get_samples(args, tokenizer) -> list[SampleRequest]: + + if not hasattr(args, "request_id_prefix"): + args.request_id_prefix = "" + if args.dataset_name == "custom": dataset = CustomDataset(dataset_path=args.dataset_path) input_requests = dataset.sample( @@ -1286,6 +1341,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: output_len=args.custom_output_len, skip_chat_template=args.custom_skip_chat_template, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ) elif args.dataset_name == "sonnet": @@ -1300,6 +1356,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, return_prompt_formatted=False, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( @@ -1312,11 +1369,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, return_prompt_formatted=True, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ) elif args.dataset_name == "hf": # all following datasets are implemented from the # HuggingFaceDataset base class + hf_kwargs = {} if ( args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS @@ -1360,6 +1419,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ): dataset_class = ASRDataset args.hf_split = "train" + elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS: + dataset_class = BlazeditDataset + args.hf_split = "train" + hf_kwargs = { + "min_distance": args.blazedit_min_distance, + "max_distance": args.blazedit_max_distance, + } elif ( args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS @@ -1399,11 +1465,22 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, output_len=args.hf_output_len, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + **hf_kwargs ) else: # For datasets that follow a similar structure, use a mapping. 
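As an aside before the dispatch table below: the `no_oversample` flag threaded through every `sample()` call in this file gates one shared oversampling step. A minimal, standalone sketch of that policy (the helper name and seed handling here are illustrative, not the vLLM API):

```python
import random
from copy import deepcopy

def maybe_oversample(requests: list, num_requests: int,
                     no_oversample: bool = False, seed: int = 0) -> list:
    # With --no-oversample, a short sample list is returned as-is,
    # even if it has fewer entries than requested.
    if no_oversample or len(requests) >= num_requests:
        return requests
    random.seed(seed)
    extra = deepcopy(random.choices(requests, k=num_requests - len(requests)))
    return requests + extra

assert len(maybe_oversample(list(range(3)), 10)) == 10
assert len(maybe_oversample(list(range(3)), 10, no_oversample=True)) == 3
```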
dataset_mapping = { + "spec_bench": + lambda: SpecBench(dataset_path=args.dataset_path, + category=args.spec_bench_category).sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.spec_bench_output_len, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ), "sharegpt": lambda: ShareGPTDataset( random_seed=args.seed, dataset_path=args.dataset_path ).sample( @@ -1411,6 +1488,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: num_requests=args.num_prompts, output_len=args.sharegpt_output_len, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ), "burstgpt": lambda: BurstGPTDataset( random_seed=args.seed, dataset_path=args.dataset_path @@ -1418,6 +1496,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, num_requests=args.num_prompts, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ), "random": lambda: RandomDataset( random_seed=args.seed, dataset_path=args.dataset_path @@ -1430,6 +1509,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: range_ratio=args.random_range_ratio, request_id_prefix=args.request_id_prefix, batchsize=args.random_batch_size, + no_oversample=args.no_oversample, ), "random-mm": lambda: RandomMultiModalDataset( @@ -1446,6 +1526,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio, bucket_config=args.random_mm_bucket_config, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ), "prefix_repetition": lambda: PrefixRepetitionRandomDataset( @@ -1458,6 +1539,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: num_prefixes=args.prefix_repetition_num_prefixes, output_len=args.prefix_repetition_output_len, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ), } @@ -1539,8 +1621,17 @@ class CustomDataset(BenchmarkDataset): enable_multimodal_chat: bool = False, skip_chat_template: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: + # load all data if needed + self.num_available_samples = len(self.data) + if num_requests <= 0: + num_requests = self.num_available_samples + logger.info("num_requests is set to 0 or negative, " + "so using all available samples: %d", + num_requests) + sampled_requests = [] for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: @@ -1567,11 +1658,57 @@ class CustomDataset(BenchmarkDataset): request_id=request_id_prefix + str(i), )) self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + request_id_prefix, no_oversample) return sampled_requests +# ----------------------------------------------------------------------------- +# Spec Bench Dataset Implementation +# ----------------------------------------------------------------------------- + + +class SpecBench(CustomDataset): + """ + Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench + Download the dataset using: + wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl + """ # noqa: E501 + + def __init__(self, **kwargs) -> None: + self.category = kwargs.pop("category", None) + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + self.data = [] + + # Load the JSONL file + jsonl_data = 
pd.read_json(path_or_buf=self.dataset_path, + lines=True) + + # check if the JSONL file has a 'turns' column + if "turns" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'turns' column.") + + for _, row in jsonl_data.iterrows(): + # sample only from a specific category if specified + if (not self.category) or (self.category == row['category']): + prompt = row["turns"][0] + self.data.append({"prompt": prompt}) + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample(self, **kwargs) -> list: + # leverage CustomDataset sample + kwargs["skip_chat_template"] = False + return super().sample(**kwargs) + + # ----------------------------------------------------------------------------- # Sonnet Dataset Implementation # ----------------------------------------------------------------------------- @@ -1612,6 +1749,7 @@ class SonnetDataset(BenchmarkDataset): output_len: int = DEFAULT_OUTPUT_LEN, return_prompt_formatted: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: # Calculate average token length for a poem line. @@ -1707,6 +1845,7 @@ class BurstGPTDataset(BenchmarkDataset): max_loras: Optional[int] = None, lora_path: Optional[str] = None, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list[SampleRequest]: samples = [] @@ -1786,6 +1925,7 @@ class ConversationDataset(HuggingFaceDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs) -> list: # Filter examples with at least 2 conversations filtered_data = self.data.filter( @@ -1827,7 +1967,7 @@ class ConversationDataset(HuggingFaceDataset): )) ind += 1 self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + request_id_prefix, no_oversample) return sampled_requests @@ -1857,6 +1997,7 @@ class VisionArenaDataset(HuggingFaceDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: output_len = (output_len @@ -1886,7 +2027,7 @@ class VisionArenaDataset(HuggingFaceDataset): request_id=request_id_prefix + str(i), )) self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + request_id_prefix, no_oversample) return sampled_requests @@ -1916,6 +2057,7 @@ class InstructCoderDataset(HuggingFaceDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) @@ -1947,7 +2089,7 @@ class InstructCoderDataset(HuggingFaceDataset): request_id=request_id_prefix + str(i), )) self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + request_id_prefix, no_oversample) return sampled_requests @@ -1978,6 +2120,7 @@ class MTBenchDataset(HuggingFaceDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: output_len = (output_len @@ -2008,7 +2151,96 @@ class MTBenchDataset(HuggingFaceDataset): request_id=request_id_prefix + str(i), )) self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + request_id_prefix, no_oversample) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Blazedit Dataset Implementation +# 
----------------------------------------------------------------------------- + + +class BlazeditDataset(HuggingFaceDataset): + """ + Blazedit Dataset. + https://github.com/ise-uiuc/blazedit + + 5k char version: vdaita/edit_5k_char + 10k char version: vdaita/edit_10k_char + """ # noqa: E501 + + # 5k char version will have output as ~5k chars + # 10k char version will have output as ~10k chars + # Assuming 3 char per token, 10k chars will be 3333 tokens + # We set default to 4000 to be safe + DEFAULT_OUTPUT_LEN = 4000 + SUPPORTED_DATASET_PATHS = { + "vdaita/edit_5k_char", + "vdaita/edit_10k_char", + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + request_id_prefix: str = "", + no_oversample: bool = False, + min_distance: float = 0.0, + max_distance: float = 1.0, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + + for i, item in enumerate(self.data): + if len(sampled_requests) >= num_requests: + break + code = item["code"] + change_request = item["change_request"] + norm_distance = item["norm_distance"] + + # compare the levenshtein distance normalized by code length + if norm_distance < min_distance or norm_distance > max_distance: + continue + + # template copied from + # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501 + instruction = f"""Given a code file, please apply the change requests and generate the new file. + +Original file: +```python +{code} +``` + +Change request: +{change_request} + +Please generate the new code file in the "New file" section below.""" # noqa: E501 + + # apply template + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": instruction + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + request_id=request_id_prefix + str(i), + )) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix, no_oversample) + return sampled_requests @@ -2031,6 +2263,7 @@ class AIMODataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs) -> list: sampled_requests = [] ind = 0 @@ -2063,7 +2296,7 @@ class AIMODataset(HuggingFaceDataset): )) ind += 1 self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + request_id_prefix, no_oversample) return sampled_requests @@ -2135,6 +2368,7 @@ class NextEditPredictionDataset(HuggingFaceDataset): def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs): formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.hf_name) if formatting_prompt_func is None: @@ -2152,7 +2386,10 @@ class NextEditPredictionDataset(HuggingFaceDataset): )) if len(samples) >= num_requests: break - self.maybe_oversample_requests(samples, num_requests, request_id_prefix) + self.maybe_oversample_requests(samples, + num_requests, + request_id_prefix, + no_oversample) return samples @@ -2203,6 +2440,7 @@ class ASRDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: output_len = (output_len 
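As an aside, the `norm_distance` filter added in `BlazeditDataset.sample()` above is easy to exercise on its own. A minimal sketch with made-up rows (not real dataset entries):

```python
# Standalone illustration of the norm_distance filter: keep only edits
# whose normalized Levenshtein distance lies in [min_distance, max_distance].
rows = [
    {"change_request": "rename variable a to b", "norm_distance": 0.05},
    {"change_request": "rewrite the whole module", "norm_distance": 0.92},
]

def in_range(row: dict, min_distance: float = 0.0,
             max_distance: float = 0.5) -> bool:
    return min_distance <= row["norm_distance"] <= max_distance

print([r["change_request"] for r in rows if in_range(r)])
# -> ['rename variable a to b']
```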
@@ -2241,7 +2479,7 @@ class ASRDataset(HuggingFaceDataset): skipped, ) self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + request_id_prefix, no_oversample) return sampled_requests @@ -2279,6 +2517,7 @@ class MLPerfDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list[SampleRequest]: # Force dynamic output length based on reference completion. @@ -2325,7 +2564,7 @@ class MLPerfDataset(HuggingFaceDataset): ind += 1 self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + request_id_prefix, no_oversample) return sampled_requests @@ -2359,6 +2598,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset): num_prefixes: int = DEFAULT_NUM_PREFIXES, output_len: int = DEFAULT_OUTPUT_LEN, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list[SampleRequest]: vocab_size = tokenizer.vocab_size diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py index 6bb2a497119e9..e640630476630 100644 --- a/vllm/benchmarks/lib/endpoint_request_func.py +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -17,6 +17,47 @@ from tqdm.asyncio import tqdm AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) +class StreamedResponseHandler: + """Handles streaming HTTP responses by accumulating chunks until complete + messages are available.""" + + def __init__(self): + self.buffer = "" + + def add_chunk(self, chunk_bytes: bytes) -> list[str]: + """Add a chunk of bytes to the buffer and return any complete + messages.""" + chunk_str = chunk_bytes.decode("utf-8") + self.buffer += chunk_str + + messages = [] + + # Split by double newlines (SSE message separator) + while "\n\n" in self.buffer: + message, self.buffer = self.buffer.split("\n\n", 1) + message = message.strip() + if message: + messages.append(message) + + # if self.buffer is not empty, check if it is a complete message + # by removing data: prefix and check if it is a valid JSON + if self.buffer.startswith("data: "): + message_content = self.buffer.removeprefix("data: ").strip() + if message_content == "[DONE]": + messages.append(self.buffer.strip()) + self.buffer = "" + elif message_content: + try: + json.loads(message_content) + messages.append(self.buffer.strip()) + self.buffer = "" + except json.JSONDecodeError: + # Incomplete JSON, wait for more chunks. 
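An editorial aside on the buffering above: an SSE event split across two network reads is held in the buffer until its "\n\n" separator (or a parseable JSON tail) arrives. A minimal standalone sketch with hypothetical chunk values:

```python
# Mirrors the core of add_chunk() above: accumulate bytes, emit only
# complete "\n\n"-terminated SSE messages.
def add_chunk(buffer: str, chunk: bytes) -> tuple[str, list[str]]:
    buffer += chunk.decode("utf-8")
    messages = []
    while "\n\n" in buffer:
        message, buffer = buffer.split("\n\n", 1)
        if message.strip():
            messages.append(message.strip())
    return buffer, messages

buf, msgs = add_chunk("", b'data: {"choices": [')
assert msgs == []  # incomplete event stays buffered
buf, msgs = add_chunk(buf, b'{"text": "hi"}]}\n\n')
assert msgs == ['data: {"choices": [{"text": "hi"}]}']
```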
+ pass + + return messages + + @dataclass class RequestFuncInput: """The input for the request function.""" @@ -27,6 +68,7 @@ class RequestFuncInput: model: str model_name: Optional[str] = None logprobs: Optional[int] = None + extra_headers: Optional[dict] = None extra_body: Optional[dict] = None multi_modal_content: Optional[Union[dict, list[dict]]] = None ignore_eos: bool = False @@ -88,6 +130,8 @@ async def async_request_openai_completions( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers if request_func_input.request_id: headers["x-request-id"] = request_func_input.request_id @@ -102,46 +146,50 @@ async def async_request_openai_completions( headers=headers) as response: if response.status == 200: first_chunk_received = False - async for chunk_bytes in response.content: + handler = StreamedResponseHandler() + + async for chunk_bytes in response.content.iter_any(): chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk_bytes = chunk_bytes.decode("utf-8") - # NOTE: SSE comments (often used as pings) start with - # a colon. These are not JSON data payload and should - # be skipped. - if chunk_bytes.startswith(":"): - continue - chunk = chunk_bytes.removeprefix("data: ") + messages = handler.add_chunk(chunk_bytes) + for message in messages: + # NOTE: SSE comments (often used as pings) start with + # a colon. These are not JSON data payload and should + # be skipped. + if message.startswith(":"): + continue - if chunk != "[DONE]": - data = json.loads(chunk) + chunk = message.removeprefix("data: ") - # NOTE: Some completion API might have a last - # usage summary response without a token so we - # want to check a token was generated - if choices := data.get("choices"): - # Note that text could be empty here - # e.g. for special tokens - text = choices[0].get("text") - timestamp = time.perf_counter() - # First token - if not first_chunk_received: - first_chunk_received = True - ttft = time.perf_counter() - st - output.ttft = ttft + if chunk != "[DONE]": + data = json.loads(chunk) - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. 
for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft - most_recent_timestamp = timestamp - generated_text += text or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") if first_chunk_received: output.success = True else: @@ -213,6 +261,8 @@ async def async_request_openai_chat_completions( "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers if request_func_input.request_id: headers["x-request-id"] = request_func_input.request_id @@ -227,41 +277,44 @@ async def async_request_openai_chat_completions( async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: - async for chunk_bytes in response.content: + handler = StreamedResponseHandler() + async for chunk_bytes in response.content.iter_any(): chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk_bytes = chunk_bytes.decode("utf-8") - # NOTE: SSE comments (often used as pings) start with - # a colon. These are not JSON data payload and should - # be skipped. - if chunk_bytes.startswith(":"): - continue - chunk = chunk_bytes.removeprefix("data: ") + messages = handler.add_chunk(chunk_bytes) + for message in messages: + # NOTE: SSE comments (often used as pings) start with + # a colon. These are not JSON data payload and should + # be skipped. 
+ if message.startswith(":"): + continue - if chunk != "[DONE]": - timestamp = time.perf_counter() - data = json.loads(chunk) + chunk = message.removeprefix("data: ") - if choices := data.get("choices"): - content = choices[0]["delta"].get("content") - # First token - if ttft == 0.0: - ttft = timestamp - st - output.ttft = ttft + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft - generated_text += content or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) - most_recent_timestamp = timestamp + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp output.generated_text = generated_text output.success = True @@ -316,6 +369,8 @@ async def async_request_openai_audio( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers if request_func_input.request_id: headers["x-request-id"] = request_func_input.request_id @@ -347,36 +402,40 @@ data=form, headers=headers) as response: if response.status == 200: - async for chunk_bytes in response.content: + handler = StreamedResponseHandler() + + async for chunk_bytes in response.content.iter_any(): chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") - if chunk != "[DONE]": - timestamp = time.perf_counter() - data = json.loads(chunk) + messages = handler.add_chunk(chunk_bytes) + for message in messages: + chunk = message.removeprefix( + "data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) - if choices := data.get("choices"): - content = choices[0]["delta"].get( - "content") - # First token - if ttft == 0.0: - ttft = timestamp - st - output.ttft = ttft + if choices := data.get("choices"): + content = choices[0]["delta"].get( + "content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft - # Decoding phase - else: - output.itl.append( - timestamp - most_recent_timestamp) + # Decoding phase + else: + output.itl.append( + timestamp - most_recent_timestamp) - generated_text += content or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") - most_recent_timestamp = timestamp + most_recent_timestamp = timestamp output.generated_text = generated_text output.success = True diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index a98eb2a78f103..33e831e54bbc9 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -389,6 +389,7 @@ async def benchmark( goodput_config_dict: dict[str, float], max_concurrency: Optional[int], lora_modules: Optional[Iterable[str]], + extra_headers: Optional[dict], extra_body: Optional[dict], ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, ramp_up_start_rps: Optional[int] = None, @@ -452,6 +453,7 @@ async 
def benchmark( logprobs=logprobs, multi_modal_content=test_mm_content, ignore_eos=ignore_eos, + extra_headers=extra_headers, extra_body=extra_body, ) @@ -484,6 +486,7 @@ async def benchmark( logprobs=logprobs, multi_modal_content=test_mm_content, ignore_eos=ignore_eos, + extra_headers=extra_headers, extra_body=extra_body) profile_output = await request_func( request_func_input=profile_input, session=session) @@ -568,6 +571,7 @@ async def benchmark( logprobs=logprobs, multi_modal_content=mm_content, ignore_eos=ignore_eos, + extra_headers=extra_headers, extra_body=extra_body, request_id=request_id,) tasks.append( @@ -815,6 +819,15 @@ def add_cli_args(parser: argparse.ArgumentParser): default="/v1/completions", help="API endpoint.", ) + parser.add_argument( + "--header", + metavar="KEY=VALUE", + nargs="*", + help="Key-value pairs (e.g., --header x-additional-info=0.3.3) " + "for headers to be passed with each request. These headers override " \ + "per-backend constants and values set via environment variables, and " \ + "will be overridden by other arguments (such as request IDs)." + ) parser.add_argument( "--max-concurrency", type=int, @@ -1104,6 +1117,19 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: api_url = f"http://{args.host}:{args.port}{args.endpoint}" base_url = f"http://{args.host}:{args.port}" + # Headers + headers = None + if args.header: + headers = {} + for item in args.header: + if "=" in item: + kvstring = item.split("=", 1) + headers[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid header format. Please use KEY=VALUE format." + ) + tokenizer = get_tokenizer(tokenizer_id, tokenizer_mode=tokenizer_mode, trust_remote_code=args.trust_remote_code) @@ -1161,6 +1187,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, + extra_headers=headers, extra_body=sampling_params, ramp_up_strategy=args.ramp_up_strategy, ramp_up_start_rps=args.ramp_up_start_rps, @@ -1184,7 +1211,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if args.metadata: for item in args.metadata: if "=" in item: - kvstring = item.split("=") + kvstring = item.split("=", 1) result_json[kvstring[0].strip()] = kvstring[1].strip() else: raise ValueError( diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index f022a55e625f5..96e39fd92eba0 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -37,6 +37,7 @@ def run_vllm( requests: list[SampleRequest], n: int, engine_args: EngineArgs, + do_profile: bool, disable_detokenize: bool = False, ) -> tuple[float, Optional[list[RequestOutput]]]: from vllm import LLM, SamplingParams @@ -75,10 +76,14 @@ def run_vllm( outputs = None if not use_beam_search: start = time.perf_counter() + if do_profile: + llm.start_profile() outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests, use_tqdm=True) + if do_profile: + llm.stop_profile() end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" @@ -88,6 +93,8 @@ def run_vllm( for request in requests: assert request.expected_output_len == output_len start = time.perf_counter() + if do_profile: + llm.start_profile() llm.beam_search( prompts, BeamSearchParams( beam_width=n, max_tokens=output_len, ignore_eos=True, )) + if do_profile: + llm.stop_profile() end = time.perf_counter() return end - start, outputs @@ -103,6 
+112,7 @@ def run_vllm_chat( requests: list[SampleRequest], n: int, engine_args: EngineArgs, + do_profile: bool, disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: """ Run vLLM chat benchmark. This function is recommended ONLY for benchmarking @@ -133,7 +143,11 @@ def run_vllm_chat( detokenize=not disable_detokenize, )) start = time.perf_counter() + if do_profile: + llm.start_profile() outputs = llm.chat(prompts, sampling_params, use_tqdm=True) + if do_profile: + llm.stop_profile() end = time.perf_counter() return end - start, outputs @@ -142,6 +156,7 @@ async def run_vllm_async( requests: list[SampleRequest], n: int, engine_args: AsyncEngineArgs, + do_profile: bool, disable_frontend_multiprocessing: bool = False, disable_detokenize: bool = False, ) -> float: @@ -185,6 +200,8 @@ async def run_vllm_async( generators = [] start = time.perf_counter() + if do_profile: + await llm.start_profile() for i, (prompt, sp, lr) in enumerate(zip(prompts, sampling_params, lora_requests)): generator = llm.generate(prompt, @@ -195,6 +212,8 @@ async def run_vllm_async( all_gens = merge_async_iterators(*generators) async for i, res in all_gens: pass + if do_profile: + await llm.stop_profile() end = time.perf_counter() return end - start @@ -543,6 +562,12 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="Split of the HF dataset.") + parser.add_argument( + "--profile", + action="store_true", + default=False, + help="Use Torch Profiler. The env variable " + "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.") # prefix repetition dataset prefix_repetition_group = parser.add_argument_group( @@ -600,22 +625,27 @@ def main(args: argparse.Namespace): requests, args.n, AsyncEngineArgs.from_cli_args(args), - args.disable_frontend_multiprocessing, - args.disable_detokenize, + disable_frontend_multiprocessing=args.disable_frontend_multiprocessing, + disable_detokenize=args.disable_detokenize, + do_profile=args.profile, )) else: elapsed_time, request_outputs = run_vllm( requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + disable_detokenize=args.disable_detokenize, + do_profile=args.profile) elif args.backend == "hf": assert args.tensor_parallel_size == 1 + if args.profile: + raise NotImplementedError( + "Profiling not implemented yet for backend='hf'.") elapsed_time = run_hf(requests, args.model, tokenizer, args.n, args.hf_max_batch_size, args.trust_remote_code, args.disable_detokenize) elif args.backend == "vllm-chat": elapsed_time, request_outputs = run_vllm_chat( requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + disable_detokenize=args.disable_detokenize, do_profile=args.profile) else: raise ValueError(f"Unknown backend: {args.backend}") diff --git a/vllm/collect_env.py b/vllm/collect_env.py index 0291f64e84f0a..fb9d3657790cf 100644 --- a/vllm/collect_env.py +++ b/vllm/collect_env.py @@ -489,6 +489,16 @@ def get_libc_version(): return '-'.join(platform.libc_ver()) +def is_uv_venv(): + if os.environ.get("UV"): + return True + pyvenv_cfg_path = os.path.join(sys.prefix, 'pyvenv.cfg') + if os.path.exists(pyvenv_cfg_path): + with open(pyvenv_cfg_path, 'r') as f: + return any(line.startswith('uv = ') for line in f) + return False + + def get_pip_packages(run_lambda, patterns=None): """Return `pip list` output. 
Note: will also find conda-installed pytorch and numpy packages.""" if patterns is None: @@ -504,7 +514,7 @@ def get_pip_packages(run_lambda, patterns=None): if pip_available: cmd = [sys.executable, '-mpip', 'list', '--format=freeze'] - elif os.environ.get("UV") is not None: + elif is_uv_venv(): print("uv is set") cmd = ["uv", "pip", "list", "--format=freeze"] else: diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 3361b65a9b885..3cc0fc3106f5a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -454,11 +454,12 @@ class VllmBackend: inductor_config = config.inductor_compile_config PASS_KEY = "post_grad_custom_post_pass" if PASS_KEY in inductor_config: - # Config should automatically wrap all inductor passes if isinstance(inductor_config[PASS_KEY], PostGradPassManager): + # PassManager already added to config, make sure it's correct assert (inductor_config[PASS_KEY].uuid() == self.post_grad_pass_manager.uuid()) else: + # Config should automatically wrap all inductor passes assert isinstance(inductor_config[PASS_KEY], InductorPass) self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) inductor_config[PASS_KEY] = self.post_grad_pass_manager diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 3095f17110fde..e3677b3dd62d8 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -39,6 +39,7 @@ class AttentionQuantPattern(ABC): self, layer: Attention, quant_key: QuantKey, + dtype: torch.dtype, ): self.layer = layer self.layer_name = layer.layer_name @@ -46,11 +47,16 @@ class AttentionQuantPattern(ABC): self.head_size = layer.head_size self.quant_key = quant_key self.quant_dtype = quant_key.dtype + self.dtype = dtype assert self.quant_key in QUANT_OPS, \ f"unsupported quantization scheme {self.quant_key}" self.QUANT_OP = QUANT_OPS[self.quant_key] + def empty(self, *args, **kwargs): + kwargs = {'dtype': self.dtype, 'device': "cuda", **kwargs} + return torch.empty(*args, **kwargs) + def empty_quant(self, *args, **kwargs): kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} return torch.empty(*args, **kwargs) @@ -91,12 +97,13 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): def __init__( self, layer: Attention, + dtype: torch.dtype, symmetric: bool = True, ): quant_key = QuantKey(dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric) - super().__init__(layer, quant_key) + super().__init__(layer, quant_key, dtype) def _register(self, pm_pass: PatternMatcherPass): @@ -139,10 +146,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) inputs = [ - empty_bf16(5, self.num_heads, self.head_size), # q - empty_bf16(5, self.num_heads, self.head_size), # k - empty_bf16(5, self.num_heads, self.head_size), # v - empty_bf16(5, self.num_heads, self.head_size), # attn_output + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # q + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # k + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # v + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # attn_output self.empty_quant(5, self.num_heads * self.head_size), # quant_output empty_fp32(1, 1) # scale @@ -165,8 +176,8 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): will be passed into Attention op as the `output_scale` argument. 
""" - def __init__(self, layer: Attention): - super().__init__(layer, kNvfp4Quant) + def __init__(self, layer: Attention, dtype: torch.dtype): + super().__init__(layer, kNvfp4Quant, dtype) def _register(self, pm_pass: PatternMatcherPass): @@ -255,11 +266,15 @@ class AttnFusionPass(VllmInductorPass): attn_layers = get_layers_from_vllm_config(config, Attention) for layer_name, layer in attn_layers.items(): - pattern_fp8 = AttentionFp8StaticQuantPattern(layer) + pattern_fp8 = AttentionFp8StaticQuantPattern( + layer, config.model_config.dtype) pattern_fp8.register_if_supported(self.patterns) - pattern_nvfp4 = AttentionNvfp4QuantPattern(layer) - pattern_nvfp4.register_if_supported(self.patterns) + if current_platform.is_cuda() and hasattr(torch.ops._C, + "scaled_fp4_quant"): + pattern_nvfp4 = AttentionNvfp4QuantPattern( + layer, config.model_config.dtype) + pattern_nvfp4.register_if_supported(self.patterns) if len(attn_layers) == 0: logger.warning( diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 063af69f41dad..85e58c290b792 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -8,8 +8,8 @@ import enum import hashlib import inspect import json +import os import textwrap -import uuid import warnings from collections.abc import Mapping from contextlib import contextmanager @@ -33,12 +33,17 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, PrefixCachingHashAlgo) from vllm.config.compilation import (CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig) +from vllm.config.kv_events import KVEventsConfig +from vllm.config.kv_transfer import KVTransferConfig +from vllm.config.load import LoadConfig +from vllm.config.lora import LoRAConfig from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, ParallelConfig) from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy from vllm.config.utils import ConfigType, config from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, @@ -47,8 +52,9 @@ from vllm.transformers_utils.config import ( is_interleaved, maybe_override_with_speculators_target_model, try_get_generation_config, try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope) -from vllm.transformers_utils.s3_utils import S3Model -from vllm.transformers_utils.utils import is_s3, maybe_model_redirect +from vllm.transformers_utils.runai_utils import (ObjectStorageModel, + is_runai_obj_uri) +from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, STR_DUAL_CHUNK_FLASH_ATTN_VAL, LayerBlockType, LazyLoader, common_broadcastable_dtype, random_uuid) @@ -62,8 +68,6 @@ if TYPE_CHECKING: from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) - from vllm.model_executor.model_loader import LoadFormats - from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.sample.logits_processor import LogitsProcessor HfOverrides = Union[dict, Callable[[type], type]] @@ -73,8 +77,6 @@ else: QuantizationConfig = Any QuantizationMethods = Any BaseModelLoader = Any - LoadFormats = Any - TensorizerConfig = Any LogitsProcessor = Any HfOverrides = Union[dict[str, Any], Callable[[type], 
type]] @@ -260,6 +262,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool: TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] MMEncoderTPMode = Literal["weights", "data"] +MMCacheType = Literal["shm", "lru"] class LogprobsMode(enum.Enum): @@ -420,7 +423,7 @@ class ModelConfig: `--media-io-kwargs '{"video": {"num_frames": 40} }'` """ use_async_output_proc: bool = True """Whether to use async output processor.""" - config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value + config_format: Union[str, ConfigFormat] = "auto" """The format of the model config to load:\n - "auto" will try to load the config in hf format if available else it will try to load in mistral format.\n @@ -448,6 +451,13 @@ class ModelConfig: `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. Set to `0` to disable this cache completely (not recommended).""" + mm_processor_cache_type: MMCacheType = "lru" + """Type of cache to use for the multi-modal preprocessor/mapper. If `shm`, + use shared memory FIFO cache. If `lru`, use mirrored LRU cache.""" + mm_shm_cache_max_object_size_mb: int = 128 + """Size limit (in MiB) for each object stored in the multi-modal processor + shared memory cache. Only effective when `mm_processor_cache_type` is + `"shm"`.""" mm_encoder_tp_mode: MMEncoderTPMode = "weights" """Indicates how to optimize multi-modal encoder inference using tensor parallelism (TP). @@ -555,15 +565,6 @@ class ModelConfig: "affect the random state of the Python process that " "launched vLLM.", self.seed) - if self.runner != "draft": - # If we're not running the draft model, check for speculators config - # If speculators config, set model / tokenizer to be target model - self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501 - model=self.model, - tokenizer=self.tokenizer, - revision=self.revision, - trust_remote_code=self.trust_remote_code) - # Keep set served_model_name before maybe_model_redirect(self.model) self.served_model_name = get_served_model_name(self.model, self.served_model_name) @@ -602,7 +603,16 @@ class ModelConfig: f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") warnings.warn(DeprecationWarning(msg), stacklevel=2) - self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer) + self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) + + if self.runner != "draft": + # If we're not running the draft model, check for speculators config + # If speculators config, set model / tokenizer to be target model + self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501 + model=self.model, + tokenizer=self.tokenizer, + revision=self.revision, + trust_remote_code=self.trust_remote_code) if (backend := envs.VLLM_ATTENTION_BACKEND ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: @@ -625,9 +635,6 @@ class ModelConfig: raise ValueError( "Sleep mode is not supported on current platform.") - if isinstance(self.config_format, str): - self.config_format = ConfigFormat(self.config_format) - hf_config = get_config(self.hf_config_path or self.model, self.trust_remote_code, self.revision, @@ -744,7 +751,7 @@ class ModelConfig: self.pooler_config = self._init_pooler_config() - self.dtype = _get_and_verify_dtype( + self.dtype: torch.dtype = _get_and_verify_dtype( self.model, self.hf_config, self.dtype, @@ -831,41 +838,42 @@ class ModelConfig: """The architecture vllm actually used.""" return 
self._architecture - def maybe_pull_model_tokenizer_for_s3(self, model: str, - tokenizer: str) -> None: - """Pull model/tokenizer from S3 to temporary directory when needed. + def maybe_pull_model_tokenizer_for_runai(self, model: str, + tokenizer: str) -> None: + """Pull model/tokenizer from Object Storage to temporary + directory when needed. Args: model: Model name or path tokenizer: Tokenizer name or path """ - if not (is_s3(model) or is_s3(tokenizer)): + if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)): return - if is_s3(model): - s3_model = S3Model() - s3_model.pull_files(model, - allow_pattern=["*.model", "*.py", "*.json"]) + if is_runai_obj_uri(model): + object_storage_model = ObjectStorageModel() + object_storage_model.pull_files( + model, allow_pattern=["*.model", "*.py", "*.json"]) self.model_weights = model - self.model = s3_model.dir + self.model = object_storage_model.dir # If tokenizer is same as model, download to same directory if model == tokenizer: - s3_model.pull_files(model, - ignore_pattern=[ - "*.pt", "*.safetensors", "*.bin", - "*.tensors" - ]) - self.tokenizer = s3_model.dir + object_storage_model.pull_files(model, + ignore_pattern=[ + "*.pt", "*.safetensors", + "*.bin", "*.tensors" + ]) + self.tokenizer = object_storage_model.dir return # Only download tokenizer if needed and not already handled - if is_s3(tokenizer): - s3_tokenizer = S3Model() - s3_tokenizer.pull_files( + if is_runai_obj_uri(tokenizer): + object_storage_tokenizer = ObjectStorageModel() + object_storage_tokenizer.pull_files( model, ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"]) - self.tokenizer = s3_tokenizer.dir + self.tokenizer = object_storage_tokenizer.dir def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: if self._model_info.supports_multimodal: @@ -881,6 +889,9 @@ class ModelConfig: media_io_kwargs=self.media_io_kwargs, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_processor_cache_type=self.mm_processor_cache_type, + mm_shm_cache_max_object_size_mb=self. + mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=self.mm_encoder_tp_mode, interleave_mm_strings=self.interleave_mm_strings, skip_mm_profiling=self.skip_mm_profiling, @@ -1092,11 +1103,11 @@ class ModelConfig: assert_never(runner_type) - def _parse_quant_hf_config(self): - quant_cfg = getattr(self.hf_config, "quantization_config", None) + def _parse_quant_hf_config(self, hf_config: PretrainedConfig): + quant_cfg = getattr(hf_config, "quantization_config", None) if quant_cfg is None: # compressed-tensors uses a "compression_config" key - quant_cfg = getattr(self.hf_config, "compression_config", None) + quant_cfg = getattr(hf_config, "compression_config", None) else: # Set quant_method for ModelOpt models. @@ -1137,7 +1148,11 @@ class ModelConfig: self.quantization) # Parse quantization method from the HF model config, if available. - quant_cfg = self._parse_quant_hf_config() + quant_cfg = self._parse_quant_hf_config(self.hf_config) + if quant_cfg is None and (text_config := getattr( + self.hf_config, "text_config", None)): + # Check the text config as well for multi-modal models. + quant_cfg = self._parse_quant_hf_config(text_config) if quant_cfg is not None: # Use the community standard 'quant_method' @@ -1169,7 +1184,7 @@ class ModelConfig: ] # Any custom overrides will be in quantization_methods so we place # them at the start of the list so custom overrides have preference - # over the built in ones. + # over the built-in ones. 
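The `text_config` fallback added to `_parse_quant_hf_config` above matters for multi-modal checkpoints, where HF nests the quantization metadata one level down. A minimal sketch of the two-level lookup, using `SimpleNamespace` as a stand-in for a HF `PretrainedConfig` (illustrative only, not the exact vLLM code path):

```python
from types import SimpleNamespace
from typing import Any, Optional


def parse_quant_cfg(cfg: Any) -> Optional[dict]:
    # Primary key used by most HF checkpoints.
    quant = getattr(cfg, "quantization_config", None)
    if quant is None:
        # compressed-tensors uses a "compression_config" key instead.
        quant = getattr(cfg, "compression_config", None)
    return quant


def detect_quant_cfg(hf_config: Any) -> Optional[dict]:
    quant = parse_quant_cfg(hf_config)
    if quant is None and (text := getattr(hf_config, "text_config", None)):
        # Multi-modal models may nest the info under text_config.
        quant = parse_quant_cfg(text)
    return quant


# A multi-modal config whose quantization info lives on the text config:
mm_cfg = SimpleNamespace(text_config=SimpleNamespace(
    quantization_config={"quant_method": "fp8"}))
assert detect_quant_cfg(mm_cfg) == {"quant_method": "fp8"}
```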
quantization_methods = quantization_methods + overrides # Detect which checkpoint is it @@ -1505,7 +1520,8 @@ class ModelConfig: if (self.hf_text_config.model_type == "deepseek_mtp" or self.hf_config.model_type == "mimo_mtp" or self.hf_config.model_type == "glm4_moe_mtp" - or self.hf_config.model_type == "ernie_mtp"): + or self.hf_config.model_type == "ernie_mtp" + or self.hf_config.model_type == "qwen3_next_mtp"): total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) else: @@ -1549,7 +1565,7 @@ class ModelConfig: for bc in block_configs[start:end]) else: # Hybrid model Jamba - layers_block_type_value = getattr(self.hf_config, + layers_block_type_value = getattr(self.hf_text_config, "layers_block_type", None) if layers_block_type_value is not None: if hasattr(self.hf_text_config, @@ -1568,15 +1584,28 @@ class ModelConfig: if attn_type_list: return sum(t == 1 for t in attn_type_list[start:end]) - if layers_block_type_value is None and attn_type_list is None: + # Hybrid model Qwen3Next + layer_types_value = getattr(self.hf_config, "layer_types", None) + if layer_types_value is not None: + if getattr(block_type, "value", block_type) == "attention": + return sum(t == "full_attention" + for t in layer_types_value[start:end]) + elif getattr(block_type, "value", + block_type) == "linear_attention": + return sum(t == "linear_attention" + for t in layer_types_value[start:end]) + else: + return sum(t == getattr(block_type, "value", block_type) + for t in layer_types_value[start:end]) + + if (layers_block_type_value is None and attn_type_list is None + and layer_types_value is None): raise ValueError( "The model is an hybrid without a" - "layers_block_type or an attn_type_list in the hf_config," - "cannot determine the num of " + " layers_block_type, an attn_type_list, or a layer_types " + "in the hf_config; cannot determine the number of " f"{block_type.value} layers") - return sum(t == 1 for t in attn_type_list[start:end]) - def get_mamba_chunk_size(self) -> Optional[int]: """ Returns the mamba chunk size if it exists @@ -1750,6 +1779,37 @@ class ModelConfig: # `llm as reranker` models defaults to not using pad_token. return getattr(self.hf_config, "use_pad_token", True) + @property + def head_dtype(self) -> torch.dtype: + """ + "head" refers to the last Linear layer(s) of an LLM, + such as the lm_head in a generation model, + or the score or classifier in a classification model. + + `head_dtype` currently only supports pooling models.\n + - Pooling models default to using an fp32 head; + you can use --hf-overrides '{"head_dtype": "model"}' to disable it. + """ + + head_dtype = _get_head_dtype(config=self.hf_config, + dtype=self.dtype, + runner_type=self.runner_type) + + if self.runner_type != "pooling" and head_dtype != self.dtype: + logger.warning_once( + "`head_dtype` currently only supports pooling models; " + "falling back to model dtype [%s].", self.dtype) + return self.dtype + + if head_dtype not in current_platform.supported_dtypes: + logger.warning_once( + "The current platform does not support [%s] head dtype, " + "falling back to model dtype [%s].", head_dtype, self.dtype) + return self.dtype + + logger.debug_once("head dtype: %s", head_dtype) + return head_dtype + def get_and_verify_max_len(self, max_model_len: int): # Consider max_model_len in tokenizer_config only when # pooling models use absolute position_embedding.
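The new `layer_types` branch above covers hybrid models such as Qwen3Next, whose HF config labels every layer as `full_attention` or `linear_attention`. A self-contained sketch of the counting rule with a toy five-layer config (values are illustrative, not from a real checkpoint):

```python
layer_types = [
    "linear_attention", "linear_attention", "full_attention",
    "linear_attention", "full_attention",
]


def count_layers(kind: str, start: int, end: int) -> int:
    if kind == "attention":
        # Plain attention blocks are recorded as "full_attention".
        return sum(t == "full_attention" for t in layer_types[start:end])
    return sum(t == kind for t in layer_types[start:end])


assert count_layers("attention", 0, 5) == 2
assert count_layers("linear_attention", 0, 5) == 3
```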
@@ -1772,90 +1832,6 @@ class ModelConfig: return max_model_len -@config -@dataclass -class LoadConfig: - """Configuration for loading the model weights.""" - - load_format: Union[str, LoadFormats] = "auto" - """The format of the model weights to load:\n - - "auto" will try to load the weights in the safetensors format and fall - back to the pytorch bin format if safetensors format is not available.\n - - "pt" will load the weights in the pytorch bin format.\n - - "safetensors" will load the weights in the safetensors format.\n - - "npcache" will load the weights in pytorch format and store a numpy cache - to speed up the loading.\n - - "dummy" will initialize the weights with random values, which is mainly - for profiling.\n - - "tensorizer" will use CoreWeave's tensorizer library for fast weight - loading. See the Tensorize vLLM Model script in the Examples section for - more information.\n - - "runai_streamer" will load the Safetensors weights using Run:ai Model - Streamer.\n - - "bitsandbytes" will load the weights using bitsandbytes quantization.\n - - "sharded_state" will load weights from pre-sharded checkpoint files, - supporting efficient loading of tensor-parallel models.\n - - "gguf" will load weights from GGUF format files (details specified in - https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n - - "mistral" will load weights from consolidated safetensors files used by - Mistral models. - - Other custom values can be supported via plugins.""" - download_dir: Optional[str] = None - """Directory to download and load the weights, default to the default - cache directory of Hugging Face.""" - model_loader_extra_config: Union[dict, TensorizerConfig] = field( - default_factory=dict) - """Extra config for model loader. This will be passed to the model loader - corresponding to the chosen load_format.""" - device: Optional[str] = None - """Device to which model weights will be loaded, default to - device_config.device""" - ignore_patterns: Optional[Union[list[str], str]] = None - """The list of patterns to ignore when loading the model. Default to - "original/**/*" to avoid repeated loading of llama's checkpoints.""" - use_tqdm_on_load: bool = True - """Whether to enable tqdm for showing progress bar when loading model - weights.""" - pt_load_map_location: Union[str, dict[str, str]] = "cpu" - """ - pt_load_map_location: the map location for loading pytorch checkpoint, to - support loading checkpoints can only be loaded on certain devices like - "cuda", this is equivalent to {"": "cuda"}. Another supported format is - mapping from different devices like from GPU 1 to GPU 0: - {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings - in dictionary needs to be double quoted for json parsing. For more details, - see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html - """ - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. 
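The `LoadConfig` block removed here is a move rather than a deletion; the class reappears in the new `vllm/config/load.py` file later in this diff. Its `compute_hash` follows the idiom shared by all of these config classes: hash only the factors that can change the compiled graph, so configs differing only in non-graph fields reuse the same compilation cache. A minimal sketch of that idiom with toy fields (assumes Python 3.9+ for the `usedforsecurity` keyword):

```python
import hashlib
from dataclasses import dataclass


@dataclass
class ToyConfig:
    rank: int = 16
    dtype: str = "auto"

    def compute_hash(self) -> str:
        # Hash only graph-affecting fields; unrelated knobs (download
        # dirs, progress bars, ...) are deliberately left out.
        factors = [self.rank, self.dtype]
        return hashlib.md5(str(factors).encode(),
                           usedforsecurity=False).hexdigest()


assert ToyConfig().compute_hash() == ToyConfig().compute_hash()
assert ToyConfig(rank=32).compute_hash() != ToyConfig().compute_hash()
```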
- factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - self.load_format = self.load_format.lower() - if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: - logger.info( - "Ignoring the following patterns when downloading weights: %s", - self.ignore_patterns) - else: - self.ignore_patterns = ["original/**/*"] - - Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] @@ -1921,7 +1897,7 @@ class DeviceConfig: SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", "mlp_speculator", "draft_model", "deepseek_mtp", - "ernie_mtp"] + "ernie_mtp", "qwen3_next_mtp"] @config @@ -2062,7 +2038,15 @@ class SpeculativeConfig: "n_predict": n_predict, "architectures": ["ErnieMTPModel"] }) - return hf_config + + if hf_config.model_type == "qwen3_next": + hf_config.model_type = "qwen3_next_mtp" + if hf_config.model_type == "qwen3_next_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["Qwen3NextMTP"] + }) return hf_config @@ -2083,9 +2067,13 @@ class SpeculativeConfig: (self.target_model_config.hf_text_config.model_type \ == "deepseek_v3" or self.target_model_config.hf_text_config.model_type in - ("mimo","ernie4_5_moe")): + ("mimo","ernie4_5_moe", "qwen3_next")): # use the draft model from the same model: self.model = self.target_model_config.model + # Align the quantization of draft model for cases such as + # --quantization fp8 with a bf16 checkpoint. + if not self.quantization: + self.quantization = self.target_model_config.quantization elif self.method in ("ngram", "[ngram]"): self.model = "ngram" else: @@ -2164,9 +2152,14 @@ class SpeculativeConfig: # Automatically detect the method if self.method in ('eagle', 'eagle3'): pass - elif "eagle-" in self.draft_model_config.model.lower() or \ - "eagle3-" in self.draft_model_config.model.lower(): + # examples: + # yuhuili/EAGLE-LLaMA3-Instruct-8B + # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B + # AngelSlim/Qwen3-8B_eagle3 + elif "eagle-" in self.draft_model_config.model.lower(): self.method = "eagle" + elif "eagle3" in self.draft_model_config.model.lower(): + self.method = "eagle3" elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" elif (self.draft_model_config.hf_config.model_type == @@ -2190,6 +2183,15 @@ class SpeculativeConfig: "one layer. Might need some code changes " \ "to support multiple layers." ) + elif (self.draft_model_config.hf_config.model_type == + "qwen3_next_mtp"): + self.method = "qwen3_next_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Qwen3Next MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." 
+ ) else: self.method = "draft_model" raise NotImplementedError( @@ -2405,7 +2407,8 @@ class SpeculativeConfig: return self.num_speculative_tokens def use_eagle(self) -> bool: - return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp") + return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp", + "qwen3_next_mtp") def __repr__(self) -> str: method = self.method @@ -2414,116 +2417,6 @@ class SpeculativeConfig: return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" -LoRADType = Literal["auto", "float16", "bfloat16"] - - -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class LoRAConfig: - """Configuration for LoRA.""" - - max_lora_rank: int = 16 - """Max LoRA rank.""" - max_loras: int = 1 - """Max number of LoRAs in a single batch.""" - fully_sharded_loras: bool = False - """By default, only half of the LoRA computation is sharded with tensor - parallelism. Enabling this will use the fully sharded layers. At high - sequence length, max rank or tensor parallel size, this is likely faster. - """ - max_cpu_loras: Optional[int] = None - """Maximum number of LoRAs to store in CPU memory. Must be >= than - `max_loras`.""" - lora_dtype: Union[torch.dtype, LoRADType] = "auto" - """Data type for LoRA. If auto, will default to base model dtype.""" - lora_extra_vocab_size: int = 256 - """(Deprecated) Maximum size of extra vocabulary that can be present in a - LoRA adapter. Will be removed in v0.12.0.""" - lora_vocab_padding_size: ClassVar[int] = current_platform\ - .get_lora_vocab_padding_size() - default_mm_loras: Optional[dict[str, str]] = None - """Dictionary mapping specific modalities to LoRA model paths; this field - is only applicable to multimodal models and should be leveraged when a - model always expects a LoRA to be active when a given modality is present. - Note that currently, if a request provides multiple additional - modalities, each of which have their own LoRA, we do NOT apply - default_mm_loras because we currently only support one lora adapter - per prompt. When run in offline mode, the lora IDs for n modalities - will be automatically assigned to 1-n with the names of the modalities - in alphabetic order.""" - bias_enabled: bool = False - """[DEPRECATED] Enable bias for LoRA adapters. This option will be - removed in v0.12.0.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - factors: list[Any] = [] - factors.append(self.max_lora_rank) - factors.append(self.max_loras) - factors.append(self.fully_sharded_loras) - factors.append(self.lora_dtype) - factors.append(self.lora_extra_vocab_size) - factors.append(self.lora_vocab_padding_size) - factors.append(self.bias_enabled) - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - # Deprecation warning for lora_extra_vocab_size - logger.warning( - "`lora_extra_vocab_size` is deprecated and will be removed " - "in v0.12.0. 
Additional vocabulary support for " - "LoRA adapters is being phased out.") - - # Deprecation warning for enable_lora_bias - if self.bias_enabled: - logger.warning("`enable_lora_bias` is deprecated " - "and will be removed in v0.12.0.") - - # Setting the maximum rank to 512 should be able to satisfy the vast - # majority of applications. - possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) - possible_lora_extra_vocab_size = (256, 512) - if self.max_lora_rank not in possible_max_ranks: - raise ValueError( - f"max_lora_rank ({self.max_lora_rank}) must be one of " - f"{possible_max_ranks}.") - if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: - raise ValueError( - f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " - f"must be one of {possible_lora_extra_vocab_size}.") - if self.max_loras < 1: - raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") - if self.max_cpu_loras is None: - self.max_cpu_loras = self.max_loras - elif self.max_cpu_loras < self.max_loras: - raise ValueError( - f"max_cpu_loras ({self.max_cpu_loras}) must be >= " - f"max_loras ({self.max_loras})") - - def verify_with_cache_config(self, cache_config: CacheConfig): - if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1: - raise ValueError( - "V0 LoRA does not support CPU offload, please use V1.") - - def verify_with_model_config(self, model_config: ModelConfig): - if self.lora_dtype in (None, "auto"): - self.lora_dtype = model_config.dtype - elif isinstance(self.lora_dtype, str): - self.lora_dtype = getattr(torch, self.lora_dtype) - - @config @dataclass class MultiModalConfig: @@ -2566,6 +2459,15 @@ class MultiModalConfig: Set to `0` to disable this cache completely (not recommended). """ + mm_processor_cache_type: MMCacheType = "lru" + """Type of cache to use for the multi-modal preprocessor/mapper. If `shm`, + use shared memory FIFO cache. If `lru`, use mirrored LRU cache.""" + + mm_shm_cache_max_object_size_mb: int = 128 + """Size limit (in MiB) for each object stored in the multi-modal processor + shared memory cache. 
Only effective when `mm_processor_cache_type` is + `"shm"`.""" + mm_encoder_tp_mode: MMEncoderTPMode = "weights" """ Indicates how to optimize multi-modal encoder inference using @@ -2892,6 +2794,31 @@ def _get_and_verify_dtype( return torch_dtype +def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype, + runner_type: str) -> torch.dtype: + head_dtype: Optional[Union[str, + torch.dtype]] = getattr(config, "head_dtype", + None) + + if head_dtype == "model": + return dtype + elif isinstance(head_dtype, str): + head_dtype = head_dtype.lower() + if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {head_dtype!r}") + return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype] + elif isinstance(head_dtype, torch.dtype): + return head_dtype + elif head_dtype is None: + if torch.float32 not in current_platform.supported_dtypes: + return dtype + if runner_type == "pooling": + return torch.float32 + return dtype + else: + raise ValueError(f"Unknown dtype: {head_dtype}") + + def _get_and_verify_max_len( hf_config: PretrainedConfig, tokenizer_config: Optional[dict], @@ -3209,149 +3136,6 @@ class ObservabilityConfig: self.collect_detailed_traces[0].split(",")) -KVProducer = Literal["kv_producer", "kv_both"] -KVConsumer = Literal["kv_consumer", "kv_both"] -KVRole = Literal[KVProducer, KVConsumer] - - -@config -@dataclass -class KVTransferConfig: - """Configuration for distributed KV cache transfer.""" - - kv_connector: Optional[str] = None - """The KV connector for vLLM to transmit KV caches between vLLM instances. - """ - - engine_id: Optional[str] = None - """The engine id for KV transfers.""" - - kv_buffer_device: Optional[str] = "cuda" - """The device used by kv connector to buffer the KV cache. - Currently only support 'cuda'.""" - - kv_buffer_size: float = 1e9 - """The buffer size for TorchDistributedConnector. Measured in number of - bytes. Recommended value: 1e9 (about 1GB).""" - - kv_role: Optional[KVRole] = None - """Whether this vLLM instance produces, consumes KV cache, or both. Choices - are 'kv_producer', 'kv_consumer', and 'kv_both'.""" - - kv_rank: Optional[int] = None - """The rank of this vLLM instance in the KV cache transfer. Typical value: - 0 for prefill instance, 1 for decode instance. - Currently only 1P1D is supported.""" - - kv_parallel_size: int = 1 - """The number of parallel instances for KV cache transfer. For - P2pNcclConnector, this should be 2.""" - - kv_ip: str = "127.0.0.1" - """The KV connector ip, used to build distributed connection.""" - - kv_port: int = 14579 - """The KV connector port, used to build distributed connection.""" - - kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) - """any extra config that the connector may need.""" - - kv_connector_module_path: Optional[str] = None - """The Python module path to dynamically load the KV connector from. - Only supported in V1.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. 
- factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self) -> None: - if self.engine_id is None: - self.engine_id = str(uuid.uuid4()) - - if self.kv_role is not None and self.kv_role not in get_args(KVRole): - raise ValueError(f"Unsupported kv_role: {self.kv_role}. " - f"Supported roles are {get_args(KVRole)}") - - if self.kv_connector is not None and self.kv_role is None: - raise ValueError("Please specify kv_disagg_role when kv_connector " - f"is set, supported roles are {get_args(KVRole)}") - - @property - def is_kv_transfer_instance(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVRole) - - @property - def is_kv_producer(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVProducer) - - @property - def is_kv_consumer(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVConsumer) - - def get_from_extra_config(self, key, default) -> Any: - return self.kv_connector_extra_config.get(key, default) - - -@config -@dataclass -class KVEventsConfig: - """Configuration for KV event publishing.""" - - enable_kv_cache_events: bool = False - """If True, enable KV cache events for tracking block storage and removal. - Events can be published externally by zmq using the event publisher config. - """ - - publisher: str = "null" - """The publisher to use for publishing kv events. Can be "null", "zmq". - """ - - endpoint: str = "tcp://*:5557" - """The zmq endpoint to use for publishing kv events. - """ - - replay_endpoint: Optional[str] = None - """The zmq endpoint to use for replaying kv events. - """ - - buffer_steps: int = 10_000 - """The number of steps to cache for replay endpoint. Will only save - events from the last N steps for the replay endpoint. - """ - - hwm: int = 100_000 - """The zmq high water mark for the event publisher. After queueing N events, - events will start dropping if the consumer is not keeping up. - """ - - max_queue_size: int = 100_000 - """The maximum number of events to queue while waiting for publishing. - """ - - topic: str = "" - """The topic to use for the event publisher. Consumers can subscribe to - this topic to receive events. 
- """ - - @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class VllmConfig: @@ -3683,16 +3467,37 @@ class VllmConfig: disable_chunked_prefill_reasons: list[str] = [] - if self.model_config and self.model_config.pooler_config: - pooling_type = self.model_config.pooler_config.pooling_type - if pooling_type is None or pooling_type.lower() != "last": + if self.model_config: + if self.model_config.pooler_config: + pooling_type = self.model_config.pooler_config.pooling_type + if pooling_type is None or pooling_type.lower() != "last": + disable_chunked_prefill_reasons.append( + "Only \"last\" pooling supports chunked " + "prefill and prefix caching; disabling both.") + if not getattr(self.model_config.hf_config, "is_causal", True): + disable_chunked_prefill_reasons.append( + "Only models using causal attention supports chunked " + "prefill and prefix caching; disabling both.") + elif self.model_config.is_encoder_decoder: + self.scheduler_config.max_num_encoder_input_tokens = \ + MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) + logger.debug( + "Encoder-decoder model detected: setting " + "`max_num_encoder_input_tokens` to encoder length (%s)", + self.scheduler_config.max_num_encoder_input_tokens) + self.scheduler_config.disable_chunked_mm_input = True disable_chunked_prefill_reasons.append( - "Only \"last\" pooling supports chunked " - "prefill and prefix caching; disabling both.") - elif not getattr(self.model_config.hf_config, "is_causal", True): - disable_chunked_prefill_reasons.append( - "Only models using causal attention supports chunked " - "prefill and prefix caching; disabling both.") + "Encoder-decoder models do not support chunked prefill nor" + " prefix caching; disabling both.") + if (self.model_config.architecture + == "WhisperForConditionalGeneration" + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") + != "spawn"): + logger.warning( + "Whisper is known to have issues with " + "forked workers. If startup is hanging, " + "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " + "to 'spawn'.") if disable_chunked_prefill_reasons: for reason in disable_chunked_prefill_reasons: @@ -3749,7 +3554,7 @@ class VllmConfig: # logger should only print warning message for hybrid models. As we # can't know whether the model is hybrid or not now, so we don't log # warning message here and will log it later. - if not (current_platform.is_cuda() or current_platform.is_rocm()): + if not current_platform.support_hybrid_kv_cache(): # Hybrid KV cache manager is not supported on non-GPU platforms. self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.kv_transfer_config is not None: @@ -3799,30 +3604,40 @@ class VllmConfig: def _set_cudagraph_sizes(self): """ - cudagraph batchsize padding logic: + vLLM defines the default candidate list of batch sizes for CUDA graph + capture as: - `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible - batch sizes that cudagraph will capture. - - Depending on the engine's configuration of `max_num_seqs`, the - candidate batch sizes to capture cudagraph will shrink to the subset - which just cover the range of `[1, max_num_seqs]`. In the common case, - `max_num_seqs` is 256, and the cudagraph batch sizes will be - `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`. - - However, if users specify the cudagraph capture sizes through - compilation config, we will use the specified sizes instead. 
+ ```python + max_graph_size = min(max_num_seqs * 2, 512) + # 1, 2, 4, then multiples of 8 up to max_graph_size + cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size] + ``` In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` will be the final sizes to capture cudagraph (in descending order). - During runtime, if batchsize is larger than - `vllm_config.compilation_config.cudagraph_capture_sizes`, - no cudagraph will be used. - If the batch size is no larger than - `vllm_config.compilation_config.cudagraph_capture_sizes`, - we can quickly find the padded graph size for a given batch size by - looking up `vllm_config.compilation_config.bs_to_padded_graph_size`. + These sizes are used to capture and reuse CUDA graphs for + performance-critical paths (e.g., decoding). Capturing enables + significantly faster kernel dispatch by avoiding Python overhead. The + list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on + most GPUs), which controls the total allowed number of tokens in a + batch. Since each sequence may have a variable number of tokens, the + maximum usable batch size will depend on actual sequence lengths. + + Example: + With `max_num_batched_tokens = 8192`, and typical sequences + averaging ~32 tokens, most practical batch sizes fall below 256. + However, the system will still allow capture sizes up to 512 if + shape and memory permit. + + Note: + If users explicitly specify cudagraph capture sizes in the + compilation config, those will override this default logic. + At runtime: + + - If batch size <= one of the `cudagraph_capture_sizes`, the closest + padded CUDA graph will be used. + - If batch size > largest `cudagraph_capture_sizes`, cudagraph will + not be used. """ # calculate the default `batch_size_capture_list` @@ -3932,6 +3747,7 @@ class VllmConfig: f"load_format={self.load_config.load_format}, " f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa + f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa f"quantization={self.model_config.quantization}, " f"enforce_eager={self.model_config.enforce_eager}, " @@ -4030,7 +3846,7 @@ def contains_object_print(text): Check if the text looks like a printed Python object, e.g. contains any substring matching the pattern: "at 0xFFFFFFF>" We match against 0x followed by 2-16 hex chars (there's - a max of 16 on a 64 bit system). + a max of 16 on a 64-bit system). Args: text (str): The text to check diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 5cc630b72846d..4c4e39c37ee50 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -24,7 +24,7 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] MambaDType = Literal["auto", "float32"] -PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"] +PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] @config @@ -63,17 +63,12 @@ class CacheConfig: """Sliding window size for the KV cache. This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" enable_prefix_caching: Optional[bool] = None - """Whether to enable prefix caching. Disabled by default for V0.
Enabled by - default for V1.""" - prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" + """Whether to enable prefix caching. Enabled by default for V1.""" + prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256" """Set the hash algorithm for prefix caching:\n - - "builtin" is Python's built-in hash.\n - - "sha256" is collision resistant but with certain overheads. - This option uses Pickle for object serialization before hashing.\n - - "sha256_cbor_64bit" provides a reproducible, cross-language compatible - hash. It serializes objects using canonical CBOR and hashes them with - SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256 - digest.""" + - "sha256" uses Pickle for object serialization before hashing.\n + - "sha256_cbor" provides a reproducible, cross-language compatible hash. It + serializes objects using canonical CBOR and hashes them with SHA-256.""" cpu_offload_gb: float = 0 """The space in GiB to offload to CPU, per GPU. Default is 0, which means no offloading. Intuitively, this argument can be seen as a virtual way to @@ -118,6 +113,15 @@ class CacheConfig: necessary for implementing this optimization in some models (e.g. Gemma3n) """ + kv_cache_memory_bytes: Optional[int] = None + """Size of KV Cache per GPU in bytes. By default, this is set to None + and vLLM can automatically infer the kv cache size based on + gpu_memory_utilization. However, users may want to manually specify + the kv cache memory size. kv_cache_memory_bytes allows finer-grained + control of how much memory gets used when compared with + gpu_memory_utilization. Note that kv_cache_memory_bytes + (when not None) takes precedence over gpu_memory_utilization.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 677fb069bc07a..f8ccc20222615 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -341,6 +341,7 @@ class CompilationConfig: "vllm.short_conv", "vllm.linear_attention", "vllm.plamo2_mamba_mixer", + "vllm.gdn_attention", ] def compute_hash(self) -> str: @@ -546,7 +547,8 @@ class CompilationConfig: # full cudagraph outside the fx graph. This reduces some cpu # overhead when the runtime batch_size is not cudagraph captured. # see https://github.com/vllm-project/vllm/pull/20059 for details. - self.splitting_ops = self._attention_ops + # make a copy to avoid mutating the class-level list via reference. + self.splitting_ops = list(self._attention_ops) elif len(self.splitting_ops) == 0: logger.warning_once("Using piecewise compilation with empty " "splitting_ops.") @@ -561,6 +563,18 @@ class CompilationConfig: self.cudagraph_mode = CUDAGraphMode.FULL self.splitting_ops = [] + if envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput": + # exclude MoE dispatch/combine from capture by ensuring + # piecewise splitting includes them, so communication remains + # outside CUDA graphs while compute can still be graphed.
+ moe_ops = [ + "vllm.moe_forward", + "vllm.moe_forward_shared", + ] + for op in moe_ops: + if op not in self.splitting_ops: + self.splitting_ops.append(op) + def splitting_ops_contain_attention(self) -> bool: return self.splitting_ops is not None and all( op in self.splitting_ops for op in self._attention_ops) diff --git a/vllm/config/kv_events.py b/vllm/config/kv_events.py new file mode 100644 index 0000000000000..1c6bdffa1281d --- /dev/null +++ b/vllm/config/kv_events.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class KVEventsConfig: + """Configuration for KV event publishing.""" + + enable_kv_cache_events: bool = False + """If True, enable KV cache events for tracking block storage and removal. + Events can be published externally by zmq using the event publisher config. + """ + + publisher: str = "null" + """The publisher to use for publishing kv events. Can be "null", "zmq". + """ + + endpoint: str = "tcp://*:5557" + """The zmq endpoint to use for publishing kv events. + """ + + replay_endpoint: Optional[str] = None + """The zmq endpoint to use for replaying kv events. + """ + + buffer_steps: int = 10_000 + """The number of steps to cache for replay endpoint. Will only save + events from the last N steps for the replay endpoint. + """ + + hwm: int = 100_000 + """The zmq high water mark for the event publisher. After queueing N events, + events will start dropping if the consumer is not keeping up. + """ + + max_queue_size: int = 100_000 + """The maximum number of events to queue while waiting for publishing. + """ + + topic: str = "" + """The topic to use for the event publisher. Consumers can subscribe to + this topic to receive events. + """ diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py new file mode 100644 index 0000000000000..9abf4acacfe81 --- /dev/null +++ b/vllm/config/kv_transfer.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import uuid +from dataclasses import field +from typing import Any, Literal, Optional, get_args + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + +KVProducer = Literal["kv_producer", "kv_both"] +KVConsumer = Literal["kv_consumer", "kv_both"] +KVRole = Literal[KVProducer, KVConsumer] + + +@config +@dataclass +class KVTransferConfig: + """Configuration for distributed KV cache transfer.""" + + kv_connector: Optional[str] = None + """The KV connector for vLLM to transmit KV caches between vLLM instances. + """ + + engine_id: Optional[str] = None + """The engine id for KV transfers.""" + + kv_buffer_device: Optional[str] = "cuda" + """The device used by kv connector to buffer the KV cache. + Currently only support 'cuda'.""" + + kv_buffer_size: float = 1e9 + """The buffer size for TorchDistributedConnector. Measured in number of + bytes. Recommended value: 1e9 (about 1GB).""" + + kv_role: Optional[KVRole] = None + """Whether this vLLM instance produces, consumes KV cache, or both. Choices + are 'kv_producer', 'kv_consumer', and 'kv_both'.""" + + kv_rank: Optional[int] = None + """The rank of this vLLM instance in the KV cache transfer. Typical value: + 0 for prefill instance, 1 for decode instance. 
+ Currently only 1P1D is supported.""" + + kv_parallel_size: int = 1 + """The number of parallel instances for KV cache transfer. For + P2pNcclConnector, this should be 2.""" + + kv_ip: str = "127.0.0.1" + """The KV connector ip, used to build distributed connection.""" + + kv_port: int = 14579 + """The KV connector port, used to build distributed connection.""" + + kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) + """any extra config that the connector may need.""" + + kv_connector_module_path: Optional[str] = None + """The Python module path to dynamically load the KV connector from. + Only supported in V1.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.engine_id is None: + self.engine_id = str(uuid.uuid4()) + + if self.kv_role is not None and self.kv_role not in get_args(KVRole): + raise ValueError(f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are {get_args(KVRole)}") + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + f"is set, supported roles are {get_args(KVRole)}") + + @property + def is_kv_transfer_instance(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVRole) + + @property + def is_kv_producer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVProducer) + + @property + def is_kv_consumer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVConsumer) + + def get_from_extra_config(self, key, default) -> Any: + return self.kv_connector_extra_config.get(key, default) diff --git a/vllm/config/load.py b/vllm/config/load.py new file mode 100644 index 0000000000000..26ffec23ad5c6 --- /dev/null +++ b/vllm/config/load.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from dataclasses import field +from typing import TYPE_CHECKING, Any, Optional, Union + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.model_executor.model_loader import LoadFormats + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +else: + LoadFormats = Any + TensorizerConfig = Any + +logger = init_logger(__name__) + + +@config +@dataclass +class LoadConfig: + """Configuration for loading the model weights.""" + + load_format: Union[str, LoadFormats] = "auto" + """The format of the model weights to load:\n + - "auto" will try to load the weights in the safetensors format and fall + back to the pytorch bin format if safetensors format is not available.\n + - "pt" will load the weights in the pytorch bin format.\n + - "safetensors" will load the weights in the safetensors format.\n + - "npcache" will load the weights in 
pytorch format and store a numpy cache + to speed up the loading.\n + - "dummy" will initialize the weights with random values, which is mainly + for profiling.\n + - "tensorizer" will use CoreWeave's tensorizer library for fast weight + loading. See the Tensorize vLLM Model script in the Examples section for + more information.\n + - "runai_streamer" will load the Safetensors weights using Run:ai Model + Streamer.\n + - "bitsandbytes" will load the weights using bitsandbytes quantization.\n + - "sharded_state" will load weights from pre-sharded checkpoint files, + supporting efficient loading of tensor-parallel models.\n + - "gguf" will load weights from GGUF format files (details specified in + https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n + - "mistral" will load weights from consolidated safetensors files used by + Mistral models. + - Other custom values can be supported via plugins.""" + download_dir: Optional[str] = None + """Directory to download and load the weights, default to the default + cache directory of Hugging Face.""" + safetensors_load_strategy: str = "lazy" + """Specifies the loading strategy for safetensors weights. + - "lazy" (default): Weights are memory-mapped from the file. This enables + on-demand loading and is highly efficient for models on local storage. + - "eager": The entire file is read into CPU memory upfront before loading. + This is recommended for models on network filesystems (e.g., Lustre, NFS) + as it avoids inefficient random reads, significantly speeding up model + initialization. However, it uses more CPU RAM. + """ + model_loader_extra_config: Union[dict, TensorizerConfig] = field( + default_factory=dict) + """Extra config for model loader. This will be passed to the model loader + corresponding to the chosen load_format.""" + device: Optional[str] = None + """Device to which model weights will be loaded, default to + device_config.device""" + ignore_patterns: Optional[Union[list[str], str]] = None + """The list of patterns to ignore when loading the model. Default to + "original/**/*" to avoid repeated loading of llama's checkpoints.""" + use_tqdm_on_load: bool = True + """Whether to enable tqdm for showing progress bar when loading model + weights.""" + pt_load_map_location: Union[str, dict[str, str]] = "cpu" + """ + pt_load_map_location: the map location for loading pytorch checkpoint, to + support loading checkpoints can only be loaded on certain devices like + "cuda", this is equivalent to {"": "cuda"}. Another supported format is + mapping from different devices like from GPU 1 to GPU 0: + {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings + in dictionary needs to be double quoted for json parsing. For more details, + see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
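The `safetensors_load_strategy` option documented above distinguishes memory-mapped loading from a single sequential read. A rough illustration using the `safetensors` library directly; the actual vLLM loader plumbing differs, and this only demonstrates the I/O trade-off the docstring describes:

```python
from pathlib import Path

from safetensors.torch import load, load_file


def load_weights(path: str, strategy: str = "lazy") -> dict:
    if strategy == "lazy":
        # Memory-maps the file; pages are faulted in on demand. Efficient
        # on local disks, but can degrade into random reads on NFS/Lustre.
        return load_file(path)
    if strategy == "eager":
        # One sequential read into CPU RAM, then deserialize. Uses more
        # memory but avoids pathological random I/O on network storage.
        return load(Path(path).read_bytes())
    raise ValueError(f"Unknown strategy: {strategy!r}")
```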
+ factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + self.load_format = self.load_format.lower() + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] diff --git a/vllm/config/lora.py b/vllm/config/lora.py new file mode 100644 index 0000000000000..3fe28f5dad4fa --- /dev/null +++ b/vllm/config/lora.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union + +import torch +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass + +import vllm.envs as envs +from vllm.config.utils import config +from vllm.logger import init_logger +from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.config.cache import CacheConfig +else: + ModelConfig = Any + CacheConfig = Any + +logger = init_logger(__name__) + +LoRADType = Literal["auto", "float16", "bfloat16"] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class LoRAConfig: + """Configuration for LoRA.""" + + max_lora_rank: int = 16 + """Max LoRA rank.""" + max_loras: int = 1 + """Max number of LoRAs in a single batch.""" + fully_sharded_loras: bool = False + """By default, only half of the LoRA computation is sharded with tensor + parallelism. Enabling this will use the fully sharded layers. At high + sequence length, max rank or tensor parallel size, this is likely faster. + """ + max_cpu_loras: Optional[int] = None + """Maximum number of LoRAs to store in CPU memory. Must be >= than + `max_loras`.""" + lora_dtype: Union[torch.dtype, LoRADType] = "auto" + """Data type for LoRA. If auto, will default to base model dtype.""" + lora_extra_vocab_size: int = 256 + """(Deprecated) Maximum size of extra vocabulary that can be present in a + LoRA adapter. Will be removed in v0.12.0.""" + lora_vocab_padding_size: ClassVar[int] = current_platform\ + .get_lora_vocab_padding_size() + default_mm_loras: Optional[dict[str, str]] = None + """Dictionary mapping specific modalities to LoRA model paths; this field + is only applicable to multimodal models and should be leveraged when a + model always expects a LoRA to be active when a given modality is present. + Note that currently, if a request provides multiple additional + modalities, each of which have their own LoRA, we do NOT apply + default_mm_loras because we currently only support one lora adapter + per prompt. When run in offline mode, the lora IDs for n modalities + will be automatically assigned to 1-n with the names of the modalities + in alphabetic order.""" + bias_enabled: bool = False + """[DEPRECATED] Enable bias for LoRA adapters. This option will be + removed in v0.12.0.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: list[Any] = [] + factors.append(self.max_lora_rank) + factors.append(self.max_loras) + factors.append(self.fully_sharded_loras) + factors.append(self.lora_dtype) + factors.append(self.lora_extra_vocab_size) + factors.append(self.lora_vocab_padding_size) + factors.append(self.bias_enabled) + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + # Deprecation warning for lora_extra_vocab_size + logger.warning( + "`lora_extra_vocab_size` is deprecated and will be removed " + "in v0.12.0. Additional vocabulary support for " + "LoRA adapters is being phased out.") + + # Deprecation warning for enable_lora_bias + if self.bias_enabled: + logger.warning("`enable_lora_bias` is deprecated " + "and will be removed in v0.12.0.") + + # Setting the maximum rank to 512 should be able to satisfy the vast + # majority of applications. + possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) + possible_lora_extra_vocab_size = (256, 512) + if self.max_lora_rank not in possible_max_ranks: + raise ValueError( + f"max_lora_rank ({self.max_lora_rank}) must be one of " + f"{possible_max_ranks}.") + if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " + f"must be one of {possible_lora_extra_vocab_size}.") + if self.max_loras < 1: + raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_loras ({self.max_loras})") + + def verify_with_cache_config(self, cache_config: CacheConfig): + if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1: + raise ValueError( + "V0 LoRA does not support CPU offload, please use V1.") + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 3a74b5fb7e64f..2f8ad5c6b6b04 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -368,8 +368,10 @@ class ParallelConfig: else: if self.eplb_config.num_redundant_experts != 0: raise ValueError( - "num_redundant_experts should be used with EPLB." - f"{self.eplb_config.num_redundant_experts}.") + "num_redundant_experts is set to " + f"{self.eplb_config.num_redundant_experts} but EPLB is not " + "enabled. Either enable EPLB or unset " + "num_redundant_experts.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 7d9b32cd4b674..ae876d131eb66 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -182,7 +182,7 @@ class NaiveBlockAllocator(BlockAllocator): # Increment refcount for each block. 
assert block.block_id is not None refcount = self._refcounter.incr(block.block_id) - assert refcount != 1, "can't fork free'd block" + assert refcount != 1, "can't fork freed block" forked_block = self._block_pool.init_block( prev_block=prev_block, diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py index 7a4a836ee348e..85ff6bc9ca610 100644 --- a/vllm/core/evictor.py +++ b/vllm/core/evictor.py @@ -58,7 +58,7 @@ class Evictor(ABC): class BlockMetaData: """Data structure for storing key data describe cached block, so that - evitor could use to make its decision which one to choose for eviction + evictor could use to make its decision which one to choose for eviction Here we use physical block id as the dict key, as there maybe several blocks with the same content hash, but their physical id is unique. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index d7864293e9647..92ebad778ea4b 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -11,7 +11,8 @@ from typing import Callable, Deque, Dict, Iterable, List, Optional from typing import Sequence as GenericSequence from typing import Set, Tuple, Union -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import CacheConfig, SchedulerConfig +from vllm.config.lora import LoRAConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 7963fb15c4191..af7ca6be1fca8 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -16,8 +16,11 @@ from typing import Any, Callable, Optional, Union import torch +from vllm.logger import init_logger from vllm.utils import is_pin_memory_available +logger = init_logger(__name__) + def find_loaded_library(lib_name) -> Optional[str]: """ @@ -165,6 +168,9 @@ class CuMemAllocator: py_d_mem = allocation_handle[2] self.pointer_to_data[py_d_mem] = AllocationData( allocation_handle, self.current_tag) + logger.debug( + "Allocated %s bytes for %s with address %s from cumem allocator", + allocation_handle[1], self.current_tag, py_d_mem) return def _python_free_callback(self, ptr: int) -> HandleType: @@ -174,6 +180,9 @@ class CuMemAllocator: data = self.pointer_to_data.pop(ptr) if data.cpu_backup_tensor is not None: data.cpu_backup_tensor = None + logger.debug( + "Freed %s bytes for %s with address %s from cumem allocator", + data.handle[1], data.tag, ptr) return data.handle def sleep( @@ -197,9 +206,14 @@ class CuMemAllocator: assert isinstance(offload_tags, tuple) + total_bytes = 0 + backup_bytes = 0 + for ptr, data in self.pointer_to_data.items(): handle = data.handle + total_bytes += handle[1] if data.tag in offload_tags: + backup_bytes += handle[1] size_in_bytes = handle[1] cpu_backup_tensor = torch.empty( size_in_bytes, @@ -211,6 +225,12 @@ class CuMemAllocator: data.cpu_backup_tensor = cpu_backup_tensor unmap_and_release(handle) + logger.info( + "CuMemAllocator: sleep freed %.2f GiB memory in total, of which " + "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded " + "directly.", total_bytes / 1024**3, backup_bytes / 1024**3, + (total_bytes - backup_bytes) / 1024**3) + gc.collect() torch.cuda.empty_cache() @@ -267,12 +287,17 @@ class CuMemAllocator: # when using pluggable allocator, see # https://github.com/pytorch/pytorch/issues/145168 . # if we have some memory allocated and then freed, - # the memory will not be released. 
- # right now it is fine, because we only use this allocator - # during weight loading and kv cache creation, where we only - # allocate memory. - # TODO: we need to find a way to release the memory, - # i.e. calling torch.cuda.empty_cache() + # the memory will not be released, e.g. in online quantization, + # where the model is created in higher precision, and then + # quantized in lower precision. + # Find all unused allocations and manually release them. + # TODO: we should expose `empty_cache` method in the memory pool. + # TODO: ask for help from PyTorch team to expose this method. + allocations = data[0].snapshot() + for allocation in allocations: + if allocation["allocated_size"] == 0: + handle = self._python_free_callback(allocation["address"]) + unmap_and_release(handle) self.current_tag = old_tag def get_current_usage(self) -> int: diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index 5c64e7d5c4ba3..805a88854b77c 100644 --- a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -36,8 +36,8 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = { "10.0": { 2: 2 * MiB, # 2 MB 4: 2 * MiB, # 2 MB - 6: 2 * MiB, # 2 MB - 8: 2 * MiB, # 2 MB + 6: 1 * MiB, # 1 MB + 8: 1 * MiB, # 1 MB } } diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index eef3f9f75f9f1..78c90b006ffc8 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -57,11 +57,19 @@ class CudaCommunicator(DeviceCommunicatorBase): self.ca_comm: Optional[CustomAllreduce] = None self.qr_comm: Optional[QuickAllReduce] = None self.symm_mem_comm: Optional[SymmMemCommunicator] = None + if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): + self.symm_mem_comm = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + ) + if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( group=self.cpu_group, device=self.device, + symm_mem_enabled=(self.symm_mem_comm is not None + and not self.symm_mem_comm.disabled), ) if current_platform.is_rocm(): @@ -72,11 +80,6 @@ class CudaCommunicator(DeviceCommunicatorBase): # currently be an MI300 series. self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) - if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): - self.symm_mem_comm = SymmMemCommunicator( - group=self.cpu_group, - device=self.device, - ) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index c8cc35f99785c..3cc4bbb258244 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -54,7 +54,8 @@ class CustomAllreduce: def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], - max_size=8192 * 1024) -> None: + max_size=8192 * 1024, + symm_mem_enabled=False) -> None: """ Args: group: the process group to work on. 
If None, it will use the @@ -111,7 +112,7 @@ class CustomAllreduce: self.device = device device_capability = current_platform.get_device_capability( ).as_version_str() - if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM + if (current_platform.is_cuda() and symm_mem_enabled and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): max_size = min( CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py new file mode 100644 index 0000000000000..3fac104bda1e8 --- /dev/null +++ b/vllm/distributed/device_communicators/shm_object_storage.py @@ -0,0 +1,635 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pickle +from abc import ABC, abstractmethod +from collections.abc import Iterable +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import chain +from multiprocessing import shared_memory +from multiprocessing.synchronize import Lock as LockType +from typing import Any, Callable, Optional, Union +from unittest.mock import patch + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SingleWriterShmRingBuffer: + """ + A single-writer, multiple-reader ring buffer implementation using shared + memory. This class provides a thread-safe ring buffer where one process + can write data while multiple processes/threads can read from it. + + Architecture: + - Uses shared memory for cross-process communication + - Maintains metadata for each allocated buffer chunk in the writer process + - Supports custom "is_free_fn" functions to determine when buffers can be + reused + - Each buffer chunk contains: [4-byte id][4-byte size][actual_data] + + Key Concepts: + - monotonic_id_start/end: Track the range of active buffer IDs + - data_buffer_start/end: Track the physical memory range in use + - Automatic wraparound when reaching buffer end + - Lazy garbage collection based on is_free_fn checks + + Example Usage Scenarios: + + Scenario 1: Simple Linear Allocation + ``` + Buffer size: 100 bytes + Initial state: [................................................. ] + ^start=end(0) + + After allocating 20 bytes (id=0): + [id:0|size:20|data........][...................................] + ^start(0) ^end(28) + + After allocating 30 bytes (id=1): + [id:0|size:20|data........][id:1|size:30|data..............][..] + ^start(0) ^end(66) + ``` + + Scenario 2: Memory Reclamation + ``` + Before freeing (both buffers still in use): + [id:0|size:20|data........][id:1|size:30|data..............][..] + ^start(0) ^end(66) + + After id:0 is marked free by readers: + [FREED.................... ][id:1|size:30|data..............][..] + ^start(28) ^end(66) + + After both are freed: + [FREED..............................................][..] + ^start=end(66) + ``` + + Scenario 3: Wraparound Allocation (continuing from Scenario 2) + ``` + Starting from after memory reclamation in Scenario 2: + [FREED..............................................][..] + ^start=end(66) + + Allocate 40 bytes (id=2) - only 34 bytes available at end, so wraparound: + [id:2|size:40|data........................][FREED.............][..] + ^end(148) ^start(66) + ``` + + Scenario 4: Error Handling - Out of Space + ``` + Starting from after wraparound allocation in Scenario 3: + [id:2|size:40|data........................][FREED.............][..] 
+ ^end(148) ^start(66) + + Trying to allocate 20 more bytes: + occupied_size_new = end + size - start = 148 + 28 - 66 > buffer_size(100) + -> Raises MemoryError: "Not enough space in the data buffer" + ``` + + Thread Safety: + - Single writer: Only one process/thread should write (allocate_buf) + - Multiple readers: Multiple processes/threads can read (access_buf) + - Reader synchronization handled by is_free_fn callback + - Writer handles garbage collection (free_buf) based on reader feedback + + Memory Layout per Buffer Chunk: + [4-byte monotonic_id][4-byte chunk_size][actual_data...] + ^metadata_start ^data_start + + The monotonic_id ensures data integrity - readers can verify they're + accessing the correct data even after buffer wraparound or reuse. + """ + + def __init__( + self, + data_buffer_size: int, + name: Optional[str] = None, + create: bool = False, + ): + self.data_buffer_size = data_buffer_size + self.is_writer = create + + self.ID_NBYTES = 4 + self.ID_MAX = 2**31 # exclusive, so 2**31 - 1 is the max value + self.SIZE_NBYTES = 4 + # 4 bytes for id, 4 bytes for buffer size + self.MD_SIZE = self.ID_NBYTES + self.SIZE_NBYTES + self.monotonic_id_end = 0 + self.monotonic_id_start = 0 + self.data_buffer_start = 0 + self.data_buffer_end = 0 + + if create: + # we are creating a buffer + self.metadata = { + self.monotonic_id_end: self.data_buffer_end + } # monotonic_id -> start address + self.shared_memory = shared_memory.SharedMemory( + create=True, size=self.data_buffer_size, name=name) + else: + # we are opening an existing buffer + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): + self.shared_memory = shared_memory.SharedMemory(name=name) + # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa + # Some platforms allocate memory based on page size, + # so the shared memory block size may be larger than or equal + # to the requested size. The size parameter is ignored + # when attaching to an existing block. + assert self.shared_memory.size >= self.data_buffer_size + + logger.debug("Shared memory created/opened with name: %s, size: %d", + self.shared_memory.name, self.data_buffer_size) + + def handle(self): + return ( + self.data_buffer_size, + self.shared_memory.name, + ) + + def clear(self) -> None: + """Clear the ring buffer.""" + assert self.is_writer, "Only the writer can clear the buffer." + self.metadata.clear() + self.monotonic_id_end = 0 + self.monotonic_id_start = 0 + self.data_buffer_start = 0 + self.data_buffer_end = 0 + + def __del__(self): + if hasattr(self, "shared_memory"): + self.shared_memory.close() + if self.is_writer: + self.shared_memory.unlink() + + def int2byte(self, integer: int) -> bytes: + """Convert an integer to bytes.""" + return integer.to_bytes(self.ID_NBYTES, "little", signed=True) + + def byte2int(self, byte_data: bytes) -> int: + """Convert bytes back to an integer.""" + return int.from_bytes(byte_data, "little", signed=True) + + def allocate_buf(self, size: int) -> tuple[int, int]: + ''' + Allocate a buffer of `MD_SIZE` + `size` bytes in the shared memory. + Memory layout: + [4-byte monotonic_id][4-byte size][buffer data...] + ''' + assert self.is_writer, "Only the writer can allocate buffers."
+ assert size > 0, "Size must be greater than 0" + size += self.MD_SIZE # add metadata size to the buffer size + # reset to beginning if the buffer does not have enough contiguous space + buffer_end_reset = self.data_buffer_end % self.data_buffer_size + if buffer_end_reset + size > self.data_buffer_size: + buffer_end_reset = (self.data_buffer_end // self.data_buffer_size + + 1) * self.data_buffer_size + else: # no reset needed + buffer_end_reset = self.data_buffer_end + + # check if we have enough space in the data buffer + # i.e. if the new end (self.data_buffer_end + size) + # exceeds the start of the data buffer + occupied_size_new = buffer_end_reset + size - self.data_buffer_start + if occupied_size_new > self.data_buffer_size: + raise MemoryError("Not enough space in the data buffer, " + "try calling free_buf() to free up space") + self.data_buffer_end = buffer_end_reset + + # first 4 bytes as the monotonic id + buf_idx = self.data_buffer_end % self.data_buffer_size + self.shared_memory.buf[buf_idx:buf_idx + self.ID_NBYTES] = \ + self.int2byte(self.monotonic_id_end) + # next 4 bytes as the size of the data buffer + self.shared_memory.buf[buf_idx + self.ID_NBYTES: \ + buf_idx + self.MD_SIZE] = self.int2byte(size) + + # record metadata + self.metadata[self.monotonic_id_end % + self.ID_MAX] = self.data_buffer_end + # update buffer and monotonic id indices + current_buffer_end = self.data_buffer_end + current_id_end = self.monotonic_id_end + self.data_buffer_end += size + self.monotonic_id_end = (self.monotonic_id_end + 1) % self.ID_MAX + return current_buffer_end, current_id_end + + @contextmanager + def access_buf(self, address: int): + buf_idx = address % self.data_buffer_size + + # read metadata + metadata_buff = self.shared_memory.buf[buf_idx:buf_idx + self.MD_SIZE] + id = self.byte2int(metadata_buff[:self.ID_NBYTES]) + size = self.byte2int(metadata_buff[self.ID_NBYTES:self.MD_SIZE]) + + # yield the data buffer and metadata + data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE:buf_idx + + size] + with memoryview(data_buff) as data_view: + yield data_view, (id, size) + + def free_buf(self, + is_free_fn: Callable[[int, memoryview], bool], + nbytes: Optional[int] = None) -> Iterable[int]: + ''' + Free a buffer of the given size. This is a no-op in shared memory, + but we need to keep track of the metadata. + + If freed memory spreads across the end and start of the ring buffer, + the actual freed memory will be in two segments. In this case there + still might not be a contiguous space of `nbytes` available. + + Args: + nbytes (int, optional): The size of the buffer to free. If None, + frees the maximum size of the ring buffer. + ''' + + assert self.is_writer, "Only the writer can free buffers."
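+        # Illustrative walk-through (mirrors Scenario 2 in the class
+        # docstring; sizes include the 8-byte metadata header): with a
+        # 100-byte buffer holding id:0 (28 bytes at address 0) and id:1
+        # (38 bytes at address 28), the loop below checks id:0 first. If
+        # is_free_fn approves, data_buffer_start advances to 28 and id:1
+        # is examined next; if is_free_fn declines, the scan stops at
+        # once, because reclamation is strictly FIFO.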
+ logger.debug( + "Freeing up space in the ring buffer, " + "monotonic_id_start: %d, monotonic_id_end: %d", + self.monotonic_id_start, self.monotonic_id_end) + monotonic_id_before = self.monotonic_id_start + # if nbytes is None, free up the maximum size of the ring buffer + if nbytes is None: + nbytes = self.data_buffer_size + freed_bytes = 0 + while self.monotonic_id_start in self.metadata and freed_bytes < nbytes: + address = self.metadata[self.monotonic_id_start] + with self.access_buf(address) as (data_buff, metadata): + if is_free_fn(self.monotonic_id_start, data_buff): + # check passed, we can free the buffer + del self.metadata[self.monotonic_id_start] + self.monotonic_id_start = ((self.monotonic_id_start + 1) % + self.ID_MAX) + # advance past the freed chunk, as in Scenario 2 above + self.data_buffer_start = address + metadata[1] + freed_bytes += metadata[1] + else: + # there are still readers, we cannot free the buffer + break + + logger.debug( + "Freed %d bytes from the ring buffer, " + "monotonic_id_start: %d, monotonic_id_end: %d", freed_bytes, + self.monotonic_id_start, self.monotonic_id_end) + + # buffer wrap around + if self.data_buffer_start >= self.data_buffer_size: + self.data_buffer_start -= self.data_buffer_size + self.data_buffer_end -= self.data_buffer_size + + monotonic_id_after = self.monotonic_id_start + # id wrap around + if monotonic_id_after >= monotonic_id_before: + return range(monotonic_id_before, monotonic_id_after) + else: + return chain(range(monotonic_id_before, self.ID_MAX), + range(0, monotonic_id_after)) + + +class ObjectSerde(ABC): + + @abstractmethod + def serialize(self, value: Any) -> tuple[Any, int, bytes, int]: + """Serialize an object to bytes.""" + raise NotImplementedError + + @abstractmethod + def deserialize(self, data: memoryview) -> Any: + """Deserialize bytes back to an object.""" + raise NotImplementedError + + +class MsgpackSerde(ObjectSerde): + + def __init__(self): + # Delayed import to avoid circular dependency + from vllm.multimodal.inputs import MultiModalKwargsItem + from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder + + self.encoder = MsgpackEncoder() + self.tensor_decoder = MsgpackDecoder(torch.Tensor) + self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem) + self._mm_kwargs_item_cls = MultiModalKwargsItem + + def serialize( + self, + value: Any) -> tuple[Union[bytes, list[bytes]], int, bytes, int]: + len_arr = None + if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)): + type_name = type(value).__name__ + value = self.encoder.encode(value) + len_arr = [len(s) for s in value] + nbytes = sum(len_arr) + else: + value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL) + type_name = type(value).__name__ + nbytes = len(value) + + object_metadata = (type_name, nbytes, len_arr) + serialized_metadata = pickle.dumps(object_metadata, + protocol=pickle.HIGHEST_PROTOCOL) + return value, nbytes, serialized_metadata, len(serialized_metadata) + + def deserialize(self, data_view: memoryview) -> Any: + # pickle.loads does not read past the end of a pickled object + # within a large buffer, so we can skip storing the metadata size + type_name, nbytes, len_arr = pickle.loads(data_view) + serialized_data = bytearray(data_view[-nbytes:]) + + if type_name == torch.Tensor.__name__: + obj = [] + start_idx = 0 + for length in len_arr: + item_bytes = serialized_data[start_idx:start_idx + length] + obj.append(item_bytes) + start_idx += length + obj = self.tensor_decoder.decode(obj) + elif type_name == self._mm_kwargs_item_cls.__name__: + obj = [] + start_idx = 0 + for length in len_arr: +
item_bytes = serialized_data[start_idx:start_idx + length] + obj.append(item_bytes) + start_idx += length + obj = self.mm_decoder.decode(obj) + elif type_name == bytes.__name__: + obj = pickle.loads(serialized_data) + else: + raise ValueError( + f"Unsupported object type '{type_name}' in metadata") + + return obj + + +@dataclass +class ShmObjectStorageHandle: + max_object_size: int + n_readers: int + ring_buffer_handle: tuple[int, str] + serde_class: type[ObjectSerde] + reader_lock: Optional[LockType] + + +class SingleWriterShmObjectStorage: + """ + A single-writer, multiple-reader object storage system built on top of a + shared memory ring buffer. Provides key-value storage with automatic memory + management and cross-process serialization support. + + This storage system follows a FIFO (First-In-First-Out) eviction policy + where the oldest objects are automatically freed when memory runs low. + Memory is reclaimed based on reader reference counting - objects are only + freed when all readers have finished accessing them. + + Architecture: + - Single writer process can put(key, value) objects + - Multiple reader processes can get(address, monotonic_id) objects + - Built on SingleWriterShmRingBuffer for efficient shared memory management + - Thread-safe operations with reader synchronization via locks + + Key Features: + - FIFO Eviction: Oldest objects are evicted first when memory is full + - Reference Counting: Objects are only freed when no readers are + accessing them + - Duplicate Key Handling: Existing keys are not overwritten, just + re-referenced + - Customized Serialization: By default uses Msgpack for efficient + serialization of Python objects, but can be extended for custom types + - Cross-Process Safety: Uses shared memory with proper synchronization + - Automatic Cleanup: Garbage collection happens transparently during + allocation + + Memory Layout per Object: + [4-byte reference_count][metadata_size][serialized_object_data] + + Thread Safety: + - Writer operations (put, clear) are single-threaded by design + - Reader operations (get) are thread-safe with lock-based reference + counting + - Memory reclamation is handled exclusively by the writer process + """ + + def __init__( + self, + max_object_size: int, + n_readers: int, + ring_buffer: SingleWriterShmRingBuffer, + serde_class: type[ObjectSerde] = MsgpackSerde, + reader_lock: Optional[LockType] = None, + ): + """ + Initialize the object storage. + + Args: + max_object_size: Maximum size for a single object in bytes. + n_readers: Number of reader processes that can access the storage. + ring_buffer: The shared memory ring buffer for storing objects. + serde_class: Serializer/deserializer for objects. + reader_lock: Optional lock for synchronizing reader access. + Raises: + ValueError: If reader_lock is None for readers. 
+ """ + + self.max_object_size = max_object_size + self.n_readers = n_readers + self.serde_class = serde_class + self.ser_de = serde_class() + self.ring_buffer = ring_buffer + self.is_writer = self.ring_buffer.is_writer + + self.flag_bytes = 4 # for in-use flag + + if self.is_writer: + # Key-value mapping: key -> (address, monotonic_id) + self.key_index: dict[str, tuple[int, int]] = {} + # Reverse mapping: monotonic_id -> key + self.id_index: dict[int, str] = {} + # Writer flag to track in-use status: monotonic_id -> count + self.writer_flag: dict[int, int] = {} + else: + if reader_lock is None: + raise ValueError("Lock must be provided for readers.") + + self._reader_lock = reader_lock + + def clear(self) -> None: + """Clear the object storage.""" + if self.is_writer: + self.ring_buffer.clear() + self.key_index.clear() + self.id_index.clear() + self.writer_flag.clear() + logger.debug("Object storage cleared and reinitialized.") + + def copy_to_buffer( + self, + data: Union[bytes, list[bytes]], + data_bytes: int, + metadata: bytes, + md_bytes: int, + data_view: memoryview, + ) -> None: + data_view[self.flag_bytes:self.flag_bytes + md_bytes] = metadata + if isinstance(data, bytes): + data_view[-data_bytes:] = data + elif isinstance(data, list): + start_idx = self.flag_bytes + md_bytes + for item_bytes in data: + item_size = len(item_bytes) + data_view[start_idx:start_idx + item_size] = item_bytes + start_idx += item_size + else: + raise ValueError( + f"Unsupported data type for serialization: {type(data)}") + + def increment_writer_flag(self, id: int) -> None: + """Set the in-use flag for the writer.""" + self.writer_flag[id] = self.writer_flag.get(id, 0) + 1 + + def increment_reader_flag(self, data_view: memoryview) -> None: + """Set the in-use flag for the reader.""" + # >0 for in-use flag + reader_count = self.ring_buffer.byte2int(data_view) + data_view[:] = self.ring_buffer.int2byte(reader_count + 1) + + def free_unused(self) -> None: + """Free unused buffers in the ring buffer.""" + # try to free up 2*max_object_size bytes of space in the ring buffer, + # since the buffer might be fragmented + freed_ids = self.ring_buffer.free_buf(self.default_is_free_check, + 2 * self.max_object_size) + # update the metadata after freeing up space + for freed_id in freed_ids: + key_to_free = self.id_index[freed_id] + del self.key_index[key_to_free] + del self.id_index[freed_id] + del self.writer_flag[freed_id] + + def is_cached(self, key: str) -> bool: + """ + Check if the object with the given key is cached. + """ + return key in self.key_index + + def get_cached(self, key: str) -> tuple[int, int]: + """ + Get the cached object by key if it exists. + """ + address, monotonic_id = self.key_index[key] + self.increment_writer_flag(monotonic_id) + return address, monotonic_id + + def put(self, key: str, value: Any) -> tuple[int, int]: + """ + Store a key-value pair in the object storage. + Attempts to free max_object_size bytes using FIFO order + when the ring buffer runs out of space during a put() operation. 
+ + Args: + key: String key to identify the object + value: Any serializable Python object + + Raises: + MemoryError: If there's not enough space in the buffer + ValueError: If the serialized object is too large + ValueError: If the key already exists in the storage + """ + if key in self.key_index: + raise ValueError(f"Key '{key}' already exists in the storage.") + + object_data, data_bytes, object_metadata, md_bytes = \ + self.ser_de.serialize(value) + buffer_size = self.flag_bytes + data_bytes + md_bytes + + # Sanity checks + if buffer_size > self.max_object_size: + raise ValueError( + f"Serialized object size ({buffer_size} bytes) exceeds " + f"max object size ({self.max_object_size} bytes)") + + # Allocate new buffer + try: + address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size) + except MemoryError: + self.free_unused() + # try again after freeing up space + address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size) + + # Write data to buffer + with self.ring_buffer.access_buf(address) as (data_view, metadata): + data_view[:self.flag_bytes] = self.ring_buffer.int2byte(0) + self.copy_to_buffer(object_data, data_bytes, object_metadata, + md_bytes, data_view) + self.increment_writer_flag(monotonic_id) + + # Update key index + self.key_index[key] = (address, monotonic_id) + self.id_index[monotonic_id] = key + return address, monotonic_id + + def get(self, address: int, monotonic_id: int) -> Any: + # Read data from buffer + with self.ring_buffer.access_buf(address) as (data_view, buf_metadata): + # check id from metadata + if buf_metadata[0] != monotonic_id: + raise ValueError( + f"Data for address:id '{address}:{monotonic_id}'" + " has been modified or is invalid.") + + obj = self.ser_de.deserialize(data_view[self.flag_bytes:]) + + # increment the reader count to record this read + if self._reader_lock is not None: + with self._reader_lock: + self.increment_reader_flag(data_view[:self.flag_bytes]) + else: + # if self._reader_lock is None, it means we are the writer + # in this case, we do not need to update the reader count + assert self.is_writer + + return obj + + def handle(self): + """Get handle for sharing across processes.""" + return ShmObjectStorageHandle( + max_object_size=self.max_object_size, + n_readers=self.n_readers, + ring_buffer_handle=self.ring_buffer.handle(), + serde_class=self.serde_class, + reader_lock=self._reader_lock, + ) + + @staticmethod + def create_from_handle( + handle: ShmObjectStorageHandle) -> "SingleWriterShmObjectStorage": + logger.debug("Creating storage from handle: %s", handle) + ring_buffer = SingleWriterShmRingBuffer(*handle.ring_buffer_handle) + return SingleWriterShmObjectStorage( + max_object_size=handle.max_object_size, + n_readers=handle.n_readers, + ring_buffer=ring_buffer, + serde_class=handle.serde_class, + reader_lock=handle.reader_lock, + ) + + def default_is_free_check(self, id: int, buf: memoryview) -> bool: + """ + Default is_free function: it compares the reader count stored in the + first 4 bytes of the buffer against the expected number of reads + (the writer count times n_readers); the buffer is free once every + recorded write has been read by all readers.
+ """ + reader_count = int.from_bytes(buf[0:4], "little", signed=True) + writer_count = self.writer_flag[id] + return reader_count >= writer_count * self.n_readers diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py index d907e1b833d04..09012d16978d9 100644 --- a/vllm/distributed/device_communicators/symm_mem.py +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -27,8 +27,13 @@ class SymmMemCommunicator: "10.0": [6, 8], } - def __init__(self, group: ProcessGroup, device: Union[int, str, - torch.device]): + def __init__( + self, + group: ProcessGroup, + device: Union[int, str, torch.device], + # add options for testing + force_multimem: Optional[bool] = None, + max_size_override: Optional[int] = None): self.disabled = True if not symm_mem_available: @@ -64,8 +69,17 @@ class SymmMemCommunicator: self.world_size, ) return - self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ - self.world_size] + # Use override max_size if provided, otherwise use default + if max_size_override is not None: + self.max_size = max_size_override + logger.info( + "SymmMemCommunicator: Using override max_size: %s bytes", + self.max_size, + ) + else: + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[ + self.device_capability][self.world_size] + self.buffer = torch_symm_mem.empty( self.max_size // self.dtype.itemsize, device=self.device, @@ -76,6 +90,7 @@ class SymmMemCommunicator: logger.warning("SymmMemCommunicator: symmetric memory " "multicast operations are not supported.") return + self.force_multimem = force_multimem self.disabled = False def should_use_symm_mem(self, inp: torch.Tensor): @@ -98,8 +113,18 @@ class SymmMemCommunicator: if out is None: out = torch.empty_like(inp) self.buffer[:inp.numel()].copy_(inp.view(-1)) - if self.world_size in self._WORLD_SIZES_MULTIMEM[ - self.device_capability]: + + # Determine which algorithm to use + use_multimem = False + if self.force_multimem is not None: + # Test override: use forced setting + use_multimem = self.force_multimem + else: + # Normal logic: use multimem for supported world sizes + use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[ + self.device_capability] + + if use_multimem: torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], "sum", self.group.group_name) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index d5ab61473ab01..8f8baa7d59db7 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -337,11 +337,11 @@ class EplbState: Args: model (MixtureOfExperts): The MoE model. is_dummy (bool): If `True`, this is a dummy step and the load - metrics recorded in this forward pass will not count. Defaults - to `False`. + metrics recorded in this forward pass will not count. Defaults + to `False`. is_profile (bool): If `True`, perform a dummy rearrangement - with maximum communication cost. This is used in `profile_run` - to reserve enough memory for the communication buffer. + with maximum communication cost. This is used in `profile_run` + to reserve enough memory for the communication buffer. log_stats (bool): If `True`, log the expert load metrics. 
# Stats diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index 879b5b9f18240..3564a10dfc684 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -102,14 +102,14 @@ def rebalance_experts_hierarchical( num_groups: int, num_nodes: int, num_gpus: int, -): +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Parameters: weight: [num_moe_layers, num_logical_experts] num_physical_experts: number of physical experts after replication num_groups: number of expert groups - num_nodes: number of server nodes, where the intra-node network - (e.g, NVLink) is faster + num_nodes: number of server nodes, where the intra-node network + (e.g., NVLink) is faster num_gpus: number of GPUs, must be a multiple of `num_nodes` Returns: diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 37f8f72fa9056..46f0cd9289b23 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -14,8 +14,9 @@ from typing import Any, Callable, Optional, Union import msgspec import zmq -from vllm.config import KVEventsConfig +from vllm.config.kv_events import KVEventsConfig from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import ExternalBlockHash logger = init_logger(__name__) @@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU" class BlockStored(KVCacheEvent): - block_hashes: list[int] - parent_block_hash: Optional[int] + block_hashes: list[ExternalBlockHash] + parent_block_hash: Optional[ExternalBlockHash] token_ids: list[int] block_size: int lora_id: Optional[int] @@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent): class BlockRemoved(KVCacheEvent): - block_hashes: list[int] + block_hashes: list[ExternalBlockHash] medium: Optional[str] diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py index fa9b7e4f14c02..cf58e7914972c 100644 --- a/vllm/distributed/kv_transfer/__init__.py +++ b/vllm/distributed/kv_transfer/__init__.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.distributed.kv_transfer.kv_transfer_state import ( - KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group, - has_kv_transfer_group, is_v1_kv_transfer_group) + KVConnectorBaseType, ensure_kv_transfer_initialized, + ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group, + is_v1_kv_transfer_group) __all__ = [ "get_kv_transfer_group", "has_kv_transfer_group", "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", - "KVConnectorBaseType" + "ensure_kv_transfer_shutdown", "KVConnectorBaseType" ] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 584fc1d655951..670f9c26b2104 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -14,7 +14,8 @@ from vllm.logger import init_logger # yapf: enable if TYPE_CHECKING: - from vllm.config import KVTransferConfig, VllmConfig + from vllm.config import VllmConfig + from vllm.config.kv_transfer import KVTransferConfig logger = init_logger(__name__) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2804003f5a708..7e0b927c5b78f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -149,7 +149,7 @@ class KVConnectorBase_V1(ABC): @abstractmethod def
start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged KV buffer. This is called from the forward context before the @@ -182,7 +182,8 @@ class KVConnectorBase_V1(ABC): @abstractmethod def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + attn_metadata: "AttentionMetadata", + **kwargs: Any) -> None: """ Start saving a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to @@ -226,6 +227,14 @@ class KVConnectorBase_V1(ABC): """ return None, None + def shutdown(self): + """ + Shut down the connector. This is called when the worker process + is shutting down to ensure that all the async operations are + completed and the connector is cleaned up properly. + """ + return None + # ============================== # Scheduler-side methods # ============================== @@ -235,7 +244,7 @@ class KVConnectorBase_V1(ABC): self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> tuple[Optional[int], bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -247,8 +256,11 @@ Returns: A tuple with the following elements: - - The number of tokens that can be loaded from the - external KV cache beyond what is already computed. + - An optional number of tokens that can be loaded from the + external KV cache beyond what is already computed. + If None, it means that the connector needs more time to + determine the number of matched tokens, and the scheduler + should query for this request again later. - `True` if external KV cache tokens will be loaded asynchronously (between scheduler steps). Must be 'False' if the first element is 0. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index e838ac2499c04..2b0abe983fbb3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -30,7 +30,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): # Worker-side methods # ============================== def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged KV buffer. This is called from the forward context before the @@ -61,7 +61,8 @@ class LMCacheConnectorV1(KVConnectorBase_V1): self._lmcache_engine.wait_for_layer_load(layer_name) def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + attn_metadata: "AttentionMetadata", + **kwargs: Any) -> None: """ Start saving a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to @@ -110,7 +111,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> tuple[Optional[int], bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens.
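To illustrate the new `Optional[int]` contract, a minimal scheduler-side sketch is shown below; the `poll_matched_tokens` helper and its `connector`/`request` arguments are hypothetical and not part of this change:

```python
from typing import Optional


def poll_matched_tokens(connector, request,
                        num_computed_tokens: int) -> Optional[int]:
    """Hypothetical helper: polls a connector under the new contract."""
    num_tokens, load_async = connector.get_num_new_matched_tokens(
        request, num_computed_tokens)
    if num_tokens is None:
        # The connector is still looking up matches; the scheduler should
        # skip this request for now and query again on a later step.
        return None
    if num_tokens == 0:
        # Per the docstring above, load_async must be False when the
        # first element is 0.
        assert not load_async
    return num_tokens
```

Returning `None` keeps the scheduling loop non-blocking: the request simply stays pending instead of stalling the step on a slow external lookup.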
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 65bcb4d93b1e1..616d158d67670 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional import torch -from vllm.config import KVTransferConfig, VllmConfig +from vllm.config import VllmConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) @@ -87,6 +88,18 @@ class MultiConnector(KVConnectorBase_V1): for c in self._connectors: c.clear_connector_metadata() + def shutdown(self): + exception: Optional[Exception] = None + for c in self._connectors: + try: + c.shutdown() + except Exception as e: + logger.exception("Exception during connector %s shutdown.", + c.__class__.__name__) + exception = e + if exception: + raise exception + # ============================== # Worker-side methods # ============================== @@ -142,11 +155,15 @@ class MultiConnector(KVConnectorBase_V1): self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> tuple[Optional[int], bool]: to_return = (0, False) for i, c in enumerate(self._connectors): toks, load_async = c.get_num_new_matched_tokens( request, num_computed_tokens) + # If there is a connector still looking up the matches, + # we return None to indicate that we are not done yet. + if toks is None: + return (None, False) # The first connector that has new matched tokens will be assigned # to this request. if to_return[0] == 0 and toks > 0: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c2f73fa281555..c306eeb5aa7ab 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -162,7 +162,7 @@ class NixlConnector(KVConnectorBase_V1): def get_num_new_matched_tokens( self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + num_computed_tokens: int) -> tuple[Optional[int], bool]: assert self.connector_scheduler is not None return self.connector_scheduler.get_num_new_matched_tokens( request, num_computed_tokens) @@ -708,8 +708,6 @@ class NixlConnectorWorker: caches_data = [] # With hybrid allocator, layers can share a kv cache tensor seen_base_addresses = [] - xfer_buffers = (self.host_xfer_buffers - if self.use_host_buffer else kv_caches) # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -770,7 +768,7 @@ class NixlConnectorWorker: # with joint KV for each block. This minimizes the overhead in # registerMem allowing faster descs queries. In order to be able to # split on kv_heads dim as required by heterogeneous TP, one must - # be able to index K/V separately. Hence the we double the number + # be able to index K/V separately. Hence we double the number # of 'virtual' regions here and halve `block_len` below. 
self.num_regions *= 2 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 2485c57d86ecc..ec72905a0d3ec 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -91,7 +91,7 @@ class P2pNcclConnector(KVConnectorBase_V1): # ============================== def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + **kwargs: Any) -> None: """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. @@ -212,7 +212,8 @@ class P2pNcclConnector(KVConnectorBase_V1): return def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + attn_metadata: "AttentionMetadata", + **kwargs: Any) -> None: """Start saving the KV cache of the layer from vLLM's paged buffer to the connector. @@ -278,7 +279,7 @@ class P2pNcclConnector(KVConnectorBase_V1): def get_finished( self, finished_req_ids: set[str], - **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]: + **kwargs: Any) -> tuple[Optional[set[str]], Optional[set[str]]]: """ Notifies worker-side connector ids of requests that have finished generating tokens. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index dfd95548c4632..fa7cc66ab654d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -15,7 +15,7 @@ import msgpack import torch import zmq -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.device_communicators.pynccl_wrapper import ( NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py index b775276d4a846..26070488bad89 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py @@ -218,8 +218,9 @@ class TensorMemoryPool: return addr - def load_tensor(self, addr: int, dtype: torch.dtype, - shape: tuple[int, ...], device) -> torch.Tensor: + def load_tensor(self, addr: int, dtype: torch.dtype, shape: tuple[int, + ...], + device: torch.device) -> torch.Tensor: """Loads a tensor from pinned host memory to the specified device. 
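
        Illustrative call (a sketch; ``pool`` is a TensorMemoryPool and
        ``addr`` refers to a tensor previously stored in the pool):

            tensor = pool.load_tensor(addr, torch.float16, (2, 8),
                                      torch.device("cuda:0"))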
Args: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index fd79387269d56..48fa1a82c6775 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -3,7 +3,7 @@ import hashlib import os from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional import safetensors import torch @@ -90,7 +90,7 @@ class SharedStorageConnector(KVConnectorBase_V1): logger.info("Shared storage path is %s", self._storage_path) def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + **kwargs: Any) -> None: """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. @@ -191,7 +191,8 @@ class SharedStorageConnector(KVConnectorBase_V1): return def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + attn_metadata: "AttentionMetadata", + **kwargs: Any) -> None: """Start saving the KV cache of the layer from vLLM's paged buffer to the connector. @@ -238,7 +239,7 @@ class SharedStorageConnector(KVConnectorBase_V1): self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> tuple[Optional[int], bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -300,11 +301,12 @@ class SharedStorageConnector(KVConnectorBase_V1): total_need_load = 0 for new_req in scheduler_output.scheduled_new_reqs: if new_req.req_id in self._requests_need_load: - meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], - block_size=self._block_size, - is_store=False, - mm_hashes=new_req.mm_hashes) + meta.add_request( + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=False, + mm_hashes=[f.identifier for f in new_req.mm_features]) total_need_load += 1 else: # NOTE: here, we set the store and load being exclusive, @@ -312,11 +314,12 @@ class SharedStorageConnector(KVConnectorBase_V1): # NOTE(rob): for this debug implementation, we only cache # the original prompt tokens. if not self._found_match_for_request(new_req): - meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], - block_size=self._block_size, - is_store=True, - mm_hashes=new_req.mm_hashes) + meta.add_request( + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=True, + mm_hashes=[f.identifier for f in new_req.mm_features]) cached_reqs = scheduler_output.scheduled_cached_reqs for i, req_id in enumerate(cached_reqs.req_ids): @@ -341,11 +344,12 @@ class SharedStorageConnector(KVConnectorBase_V1): # of the block_ids for the request. 
block_ids = new_block_ids[0] - meta.add_request(token_ids=token_ids, - block_ids=block_ids, - block_size=self._block_size, - is_store=False, - mm_hashes=request.mm_hashes) + meta.add_request( + token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + is_store=False, + mm_hashes=[f.identifier for f in request.mm_features]) total_need_load += 1 assert total_need_load == len(self._requests_need_load) @@ -364,10 +368,10 @@ class SharedStorageConnector(KVConnectorBase_V1): """ num_tokens_to_check = align_to_block_size( len(request.prompt_token_ids) - 1, self._block_size) - foldername = self._generate_foldername_debug(torch.tensor( - request.prompt_token_ids)[:num_tokens_to_check], - request.mm_hashes, - create_folder=False) + foldername = self._generate_foldername_debug( + torch.tensor(request.prompt_token_ids)[:num_tokens_to_check], + [f.identifier for f in request.mm_features], + create_folder=False) return os.path.exists(foldername) def _generate_foldername_debug( diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index 0b560d1b3b3ce..2a434e280179e 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -13,7 +13,7 @@ import zmq from safetensors.torch import load as safetensors_load from safetensors.torch import save as safetensors_save -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger from vllm.utils import join_host_port, make_zmq_path, split_host_port diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index 09de0b682efca..7a79a8cc0c932 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -20,7 +20,7 @@ from typing import Callable, Optional import torch -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.distributed.utils import StatelessProcessGroup @@ -251,8 +251,8 @@ class PyNcclPipe(KVPipeBase): """ Receives a tensor and its metadata from the source rank. Blocking call. - Args: - tensor: The received tensor, or `None` if no tensor is received. + Returns: + The received tensor, or `None` if no tensor is received. 
""" if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 5e0f64fca220c..d5747bed92771 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: config=vllm_config, role=KVConnectorRole.WORKER) else: raise ValueError("V0 is no longer supported") + + +def ensure_kv_transfer_shutdown() -> None: + global _KV_CONNECTOR_AGENT + if _KV_CONNECTOR_AGENT is not None: + _KV_CONNECTOR_AGENT.shutdown() + _KV_CONNECTOR_AGENT = None diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 522dfc8d8b5a0..ef229299b6848 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -29,6 +29,7 @@ import weakref from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass +from datetime import timedelta from multiprocessing import shared_memory from typing import Any, Callable, Optional, Union from unittest.mock import patch @@ -978,13 +979,12 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable -def init_distributed_environment( - world_size: int = -1, - rank: int = -1, - distributed_init_method: str = "env://", - local_rank: int = -1, - backend: str = "nccl", -): +def init_distributed_environment(world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "nccl", + timeout: Optional[timedelta] = None): logger.debug( "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", world_size, rank, local_rank, @@ -1020,7 +1020,8 @@ def init_distributed_environment( backend=backend, init_method=distributed_init_method, world_size=world_size, - rank=rank) + rank=rank, + timeout=timeout) # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -1117,7 +1118,7 @@ def initialize_model_parallel( "decode context model parallel group is already initialized") # Note(hc): In the current implementation of decode context parallel, # dcp_size must not exceed tp_size, because the world size does not - # change by DCP, it simply reuse the GPUs of TP group, and split one + # change by DCP, it simply reuses the GPUs of TP group, and split one # TP group into tp_size//dcp_size DCP groups. 
group_ranks = all_ranks.reshape( -1, decode_context_model_parallel_size).unbind(0) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fdd25a2f9ce2f..ab43c0edc98d7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -22,13 +22,13 @@ from typing_extensions import TypeIs, deprecated import vllm.envs as envs from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, - ConfigFormat, ConfigType, ConvertOption, - DecodingConfig, DetailedTraceModules, Device, - DeviceConfig, DistributedExecutorBackend, EPLBConfig, + ConfigType, ConvertOption, DecodingConfig, + DetailedTraceModules, Device, DeviceConfig, + DistributedExecutorBackend, EPLBConfig, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, - ModelDType, ModelImpl, MultiModalConfig, + LoRAConfig, MambaDType, MMCacheType, MMEncoderTPMode, + ModelConfig, ModelDType, ModelImpl, MultiModalConfig, ObservabilityConfig, ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, RunnerOption, SchedulerConfig, SchedulerPolicy, SpeculativeConfig, TaskOption, @@ -227,8 +227,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: elif contains_type(type_hints, int): kwargs[name]["type"] = int # Special case for large integers - if name in {"max_model_len", "max_num_batched_tokens"}: + human_readable_ints = { + "max_model_len", + "max_num_batched_tokens", + "kv_cache_memory_bytes", + } + if name in human_readable_ints: kwargs[name]["type"] = human_readable_int + kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}" elif contains_type(type_hints, float): kwargs[name]["type"] = float elif (contains_type(type_hints, dict) @@ -289,6 +295,7 @@ class EngineArgs: trust_remote_code: bool = ModelConfig.trust_remote_code allowed_local_media_path: str = ModelConfig.allowed_local_media_path download_dir: Optional[str] = LoadConfig.download_dir + safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy load_format: Union[str, LoadFormats] = LoadConfig.load_format config_format: str = ModelConfig.config_format dtype: ModelDType = ModelConfig.dtype @@ -334,6 +341,7 @@ class EngineArgs: swap_space: float = CacheConfig.swap_space cpu_offload_gb: float = CacheConfig.cpu_offload_gb gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization + kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes max_num_batched_tokens: Optional[ int] = SchedulerConfig.max_num_batched_tokens max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills @@ -365,6 +373,10 @@ class EngineArgs: MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb + mm_processor_cache_type: Optional[MMCacheType] = \ + MultiModalConfig.mm_processor_cache_type + mm_shm_cache_max_object_size_mb: int = \ + MultiModalConfig.mm_shm_cache_max_object_size_mb mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode io_processor_plugin: Optional[str] = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling @@ -547,7 +559,6 @@ class EngineArgs: help="Disable async output processing. This may result in " "lower performance.") model_group.add_argument("--config-format", - choices=[f.value for f in ConfigFormat], **model_kwargs["config_format"]) # This one is a special case because it can bool # or str. 
TODO: Handle this in get_kwargs @@ -588,6 +599,8 @@ class EngineArgs: load_group.add_argument("--load-format", **load_kwargs["load_format"]) load_group.add_argument("--download-dir", **load_kwargs["download_dir"]) + load_group.add_argument("--safetensors-load-strategy", + **load_kwargs["safetensors_load_strategy"]) load_group.add_argument("--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]) load_group.add_argument("--ignore-patterns", @@ -732,6 +745,8 @@ class EngineArgs: cache_group.add_argument("--block-size", **cache_kwargs["block_size"]) cache_group.add_argument("--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]) + cache_group.add_argument("--kv-cache-memory-bytes", + **cache_kwargs["kv_cache_memory_bytes"]) cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"]) cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"]) @@ -771,6 +786,12 @@ class EngineArgs: multimodal_group.add_argument("--disable-mm-preprocessor-cache", action="store_true", deprecated=True) + multimodal_group.add_argument( + "--mm-processor-cache-type", + **multimodal_kwargs["mm_processor_cache_type"]) + multimodal_group.add_argument( + "--mm-shm-cache-max-object-size-mb", + **multimodal_kwargs["mm_shm_cache_max_object_size_mb"]) multimodal_group.add_argument( "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]) multimodal_group.add_argument( @@ -987,6 +1008,9 @@ class EngineArgs: config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_processor_cache_type=self.mm_processor_cache_type, + mm_shm_cache_max_object_size_mb=self. + mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=self.mm_encoder_tp_mode, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, @@ -1024,6 +1048,7 @@ class EngineArgs: return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, + safetensors_load_strategy=self.safetensors_load_strategy, device="cpu" if is_online_quantization(self.quantization) else None, model_loader_extra_config=self.model_loader_extra_config, @@ -1053,9 +1078,10 @@ class EngineArgs: SpeculatorsConfig) if self.speculative_config is None: - hf_config = get_config(self.hf_config_path or self.model, - self.trust_remote_code, self.revision, - self.code_revision, self.config_format) + hf_config = get_config( + self.hf_config_path or target_model_config.model, + self.trust_remote_code, self.revision, self.code_revision, + self.config_format) # if loading a SpeculatorsConfig, load the speculative_config # details from the config directly @@ -1065,7 +1091,7 @@ class EngineArgs: self.speculative_config = {} self.speculative_config[ "num_speculative_tokens"] = hf_config.num_lookahead_tokens - self.speculative_config["model"] = self.model + self.speculative_config["model"] = target_model_config.model self.speculative_config["method"] = hf_config.method else: return None @@ -1159,7 +1185,7 @@ class EngineArgs: # Note(hc): In the current implementation of decode context # parallel(DCP), tp_size needs to be divisible by dcp_size, # because the world size does not change by dcp, it simply - # reuse the GPUs of TP group, and split one TP group into + # reuses the GPUs of the TP group, and splits one TP group into # tp_size//dcp_size DCP groups.
assert self.tensor_parallel_size % self.decode_context_parallel_size \ == 0, ( @@ -1170,6 +1196,7 @@ class EngineArgs: cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, + kv_cache_memory_bytes=self.kv_cache_memory_bytes, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, is_attention_free=model_config.is_attention_free, @@ -1269,11 +1296,8 @@ class EngineArgs: # Async scheduling does not work with the uniprocess backend. if self.distributed_executor_backend is None: self.distributed_executor_backend = "mp" - logger.info("Using mp-based distributed executor backend " - "for async scheduling.") - if self.distributed_executor_backend == "uni": - raise ValueError("Async scheduling is not supported with " - "uni-process backend.") + logger.info("Defaulting to mp-based distributed executor " + "backend for async scheduling.") if self.pipeline_parallel_size > 1: raise ValueError("Async scheduling is not supported with " "pipeline-parallel-size > 1.") @@ -1477,12 +1501,6 @@ class EngineArgs: recommend_to_remove=False) return False - # No OTLP observability so far. - if (self.otlp_traces_endpoint or self.collect_detailed_traces): - _raise_or_fallback(feature_name="--otlp-traces-endpoint", - recommend_to_remove=False) - return False - # V1 supports N-gram, Medusa, and Eagle speculative decoding. if (self.speculative_config is not None and self.speculative_config.get("method") == "draft_model"): @@ -1504,6 +1522,7 @@ class EngineArgs: "FLASH_ATTN_MLA", "FLASHINFER", "FLASHINFER_VLLM_V1", + "FLASHINFER_MLA", "ROCM_AITER_MLA", "TORCH_SDPA_VLLM_V1", "FLEX_ATTENTION", @@ -1592,20 +1611,12 @@ class EngineArgs: "in low performance due to small KV cache size. Consider " "setting --max-model-len to a smaller value.", max_model_len) - # if using prefix caching, we must set a hash algo - if self.enable_prefix_caching: - # Disable prefix caching for multimodal models for VLLM_V0. - if model_config.is_multimodal_model: - logger.warning( - "--enable-prefix-caching is not supported for multimodal " - "models in V0 and has been disabled.") - self.enable_prefix_caching = False - - # VLLM_V0 only supports builtin hash algo for prefix caching. - if self.prefix_caching_hash_algo == "sha256": - raise ValueError( - "sha256 is not supported for prefix caching in V0 engine. " - "Please use 'builtin'.") + # Disable prefix caching for multimodal models for VLLM_V0. + if self.enable_prefix_caching and model_config.is_multimodal_model: + logger.warning( + "--enable-prefix-caching is not supported for multimodal " + "models in V0 and has been disabled.") + self.enable_prefix_caching = False # Set max_num_seqs to 256 for VLLM_V0. 
if self.max_num_seqs is None: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6010a4647a0af..c53ece18964cb 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -10,8 +10,9 @@ from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, from weakref import ReferenceType import vllm.envs as envs -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.config import (DecodingConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VllmConfig) +from vllm.config.lora import LoRAConfig from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 47f56e58130fa..c303d093f6324 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -16,9 +16,9 @@ import torch from typing_extensions import TypeVar import vllm.envs as envs -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig, - VllmConfig) +from vllm.config import (DecodingConfig, ModelConfig, ObservabilityConfig, + ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.config.lora import LoRAConfig from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase, Stats @@ -278,7 +278,8 @@ class LLMEngine: self.cache_config.block_size, "gpu_memory_utilization": self.cache_config.gpu_memory_utilization, - + "kv_cache_memory_bytes": + self.cache_config.kv_cache_memory_bytes, # Quantization "quantization": self.model_config.quantization, diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 0a8709db40880..2762175c430fb 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -379,7 +379,7 @@ class LoggingStatLogger(StatLoggerBase): if local_interval_elapsed(stats.now, self.last_local_log, self.local_interval): # Compute summary metrics for tracked stats (and log them - # to promethus if applicable). + # to prometheus if applicable). prompt_throughput = get_throughput(self.num_prompt_tokens, now=stats.now, last_log=self.last_local_log) @@ -432,7 +432,7 @@ class LoggingStatLogger(StatLoggerBase): class PrometheusStatLogger(StatLoggerBase): - """PrometheusStatLogger is used LLMEngine to log to Promethus.""" + """PrometheusStatLogger is used by LLMEngine to log to Prometheus.""" _metrics_cls = Metrics _gauge_cls = prometheus_client.Gauge diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 0beb9c8cc0b97..7d1f29a9824d7 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -235,7 +235,7 @@ class MQLLMEngineClient(EngineClient): # therefore we have to inform that the current # processed requests failed as well. Send back a dead # engine error give this feedback and also give a - # 'hint' to the server to shutdown next. + # 'hint' to the server to shut down next.
exception = self.dead_error if request_id is None: diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index b0b11a33a4443..94eacfbdfb301 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -78,6 +78,7 @@ class EngineClient(ABC): preprocessor = await self.get_input_preprocessor() tokenizer_group = preprocessor.get_tokenizer_group() tokenizer = await tokenizer_group.get_lora_tokenizer_async() + eos_token_id = tokenizer.eos_token_id if is_explicit_encoder_decoder_prompt(prompt): raise NotImplementedError @@ -104,7 +105,7 @@ class EngineClient(ABC): tokenized_length = len(prompt_token_ids) sort_beams_key = create_sort_beams_key_function( - tokenizer.eos_token_id, length_penalty) + eos_token_id, length_penalty) beam_search_params = SamplingParams( logprobs=2 * beam_width, @@ -154,7 +155,7 @@ class EngineClient(ABC): if result.outputs[0].logprobs is not None: logprobs = result.outputs[0].logprobs[0] for token_id, logprob_obj in logprobs.items(): - if token_id == tokenizer.eos_token_id and \ + if token_id == eos_token_id and \ not ignore_eos: completed.append( BeamSearchSequence( @@ -166,7 +167,7 @@ class EngineClient(ABC): cum_logprob=current_beam.cum_logprob + logprob_obj.logprob, finish_reason="stop", - stop_reason=tokenizer.eos_token_id)) + stop_reason=eos_token_id)) else: new_beams.append( BeamSearchSequence( @@ -189,14 +190,14 @@ class EngineClient(ABC): best_beams = sorted_completed[:beam_width] for beam in best_beams: - if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos): + if (beam.tokens[-1] == eos_token_id and not ignore_eos): # Skip the eos token in the text. tokens = beam.tokens[tokenized_length:-1] else: tokens = beam.tokens[tokenized_length:] beam.text = tokenizer.decode(tokens) - beam_search_output = RequestOutput( + yield RequestOutput( request_id=request_id, prompt=prompt_text, outputs=[ @@ -214,8 +215,6 @@ class EngineClient(ABC): prompt_token_ids=prompt_token_ids, prompt_logprobs=None) - yield beam_search_output - @abstractmethod def encode( self, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 80e2c44a02513..aa231de93c0c3 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -41,7 +41,8 @@ from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.models import SupportsMultiModal -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, + MultiModalUUIDDict) from vllm.multimodal.utils import MediaConnector # yapf: disable from vllm.transformers_utils.chat_templates import ( @@ -75,7 +76,7 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False): class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): - image_embeds: Required[Union[str, dict[str, str]]] + image_embeds: Optional[Union[str, dict[str, str]]] """ The image embeddings. It can be either: - A single base64 string. @@ -83,6 +84,11 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): """ type: Required[Literal["image_embeds"]] """The type of the content part.""" + uuid: Optional[str] + """ + User-provided UUID of a media item. The user must guarantee that it is + properly generated and unique across different media items. 
+ """ class VideoURL(TypedDict, total=False): @@ -117,7 +123,12 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False): } """ - image_pil: Required[PILImage] + image_pil: Optional[PILImage] + uuid: Optional[str] + """ + User-provided UUID of a media item. The user must guarantee that it is + properly generated and unique across different media items. + """ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): @@ -130,7 +141,12 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): } """ - image_url: Required[str] + image_url: Optional[str] + uuid: Optional[str] + """ + User-provided UUID of a media item. The user must guarantee that it is + properly generated and unique across different media items. + """ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): @@ -142,7 +158,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): } """ - audio_url: Required[str] + audio_url: Optional[str] class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): @@ -154,7 +170,12 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): } """ - video_url: Required[str] + video_url: Optional[str] + uuid: Optional[str] + """ + User-provided UUID of a media item. The user must guarantee that it is + properly generated and unique across different media items. + """ class CustomThinkCompletionContentParam(TypedDict, total=False): @@ -566,7 +587,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): self._model_config = model_config self._tokenizer = tokenizer - self._items_by_modality = defaultdict[str, list[_T]](list) + self._items_by_modality = defaultdict[str, list[Optional[_T]]](list) + self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list) @property def model_config(self) -> ModelConfig: @@ -591,10 +613,18 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): def mm_processor(self): return self.mm_registry.create_processor(self.model_config) - def add(self, modality: ModalityStr, item: _T) -> Optional[str]: + def add( + self, + modality: ModalityStr, + item: Optional[_T], + uuid: Optional[str] = None, + ) -> Optional[str]: """ Add a multi-modal item to the current prompt and returns the placeholder string to use, if any. + + An optional UUID can be provided which serves as a unique identifier + of the media item.
""" input_modality = modality.replace("_embeds", "") num_items = len(self._items_by_modality[modality]) + 1 @@ -602,9 +632,35 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): self.mm_processor.validate_num_items(input_modality, num_items) self._items_by_modality[modality].append(item) + self._uuids_by_modality[modality].append(uuid) return self.model_cls.get_placeholder_str(modality, num_items) + def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]: + if not self._items_by_modality: + return None + mm_uuids = {} + uuids_by_modality = dict(self._uuids_by_modality) + if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality: + raise ValueError( + "Mixing raw image and embedding inputs is not allowed" + ) + + if "image_embeds" in uuids_by_modality: + image_embeds_uuids = uuids_by_modality["image_embeds"] + if len(image_embeds_uuids) > 1: + raise ValueError( + "Only one message can have {'type': 'image_embeds'}" + ) + mm_uuids["image"] = uuids_by_modality["image_embeds"] + if "image" in uuids_by_modality: + mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images + if "audio" in uuids_by_modality: + mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios + if "video" in uuids_by_modality: + mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos + return mm_uuids + @abstractmethod def create_parser(self) -> "BaseMultiModalContentParser": raise NotImplementedError @@ -645,10 +701,15 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): if not self._items_by_modality: return None mm_inputs = {} - items_by_modality = { - modality: await asyncio.gather(*items) - for modality, items in self._items_by_modality.items() - } + items_by_modality = {} + for modality, items in self._items_by_modality.items(): + coros = [] + for item in items: + if item is not None: + coros.append(item) + else: + coros.append(asyncio.sleep(0)) + items_by_modality[modality] = await asyncio.gather(*coros) if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError( @@ -697,29 +758,40 @@ class BaseMultiModalContentParser(ABC): return dict(self._placeholder_storage) @abstractmethod - def parse_image(self, image_url: str) -> None: + def parse_image( + self, image_url: Optional[str], uuid: Optional[str] = None) -> None: raise NotImplementedError @abstractmethod def parse_image_embeds( - self, image_embeds: Union[str, dict[str, str]] + self, + image_embeds: Union[str, dict[str, str], None], + uuid: Optional[str] = None, ) -> None: raise NotImplementedError @abstractmethod - def parse_image_pil(self, image_pil: Image.Image) -> None: + def parse_image_pil( + self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + ) -> None: raise NotImplementedError @abstractmethod - def parse_audio(self, audio_url: str) -> None: + def parse_audio( + self, audio_url: Optional[str], uuid: Optional[str] = None + ) -> None: raise NotImplementedError @abstractmethod - def parse_input_audio(self, input_audio: InputAudio) -> None: + def parse_input_audio( + self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + ) -> None: raise NotImplementedError @abstractmethod - def parse_video(self, video_url: str) -> None: + def parse_video( + self, video_url: Optional[str], uuid: Optional[str] = None + ) -> None: raise NotImplementedError @@ -734,49 +806,75 @@ class MultiModalContentParser(BaseMultiModalContentParser): allowed_local_media_path=tracker.allowed_local_media_path, ) - def parse_image(self, image_url: str) -> None: - 
image = self._connector.fetch_image(image_url) + def parse_image( + self, image_url: Optional[str], uuid: Optional[str] = None + ) -> None: + image = self._connector.fetch_image(image_url) if image_url else None - placeholder = self._tracker.add("image", image) + placeholder = self._tracker.add("image", image, uuid) self._add_placeholder("image", placeholder) def parse_image_embeds( - self, image_embeds: Union[str, dict[str, str]] + self, + image_embeds: Union[str, dict[str, str], None], + uuid: Optional[str] = None, ) -> None: if isinstance(image_embeds, dict): embeds = { k: self._connector.fetch_image_embedding(v) for k, v in image_embeds.items() } - placeholder = self._tracker.add("image_embeds", embeds) + placeholder = self._tracker.add("image_embeds", embeds, uuid) if isinstance(image_embeds, str): embedding = self._connector.fetch_image_embedding(image_embeds) - placeholder = self._tracker.add("image_embeds", embedding) + placeholder = self._tracker.add("image_embeds", embedding, uuid) + + if image_embeds is None: + placeholder = self._tracker.add("image_embeds", None, uuid) self._add_placeholder("image", placeholder) - def parse_image_pil(self, image_pil: Image.Image) -> None: - placeholder = self._tracker.add("image", image_pil) + def parse_image_pil( + self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + ) -> None: + placeholder = self._tracker.add("image", image_pil, uuid) self._add_placeholder("image", placeholder) - def parse_audio(self, audio_url: str) -> None: - audio = self._connector.fetch_audio(audio_url) + def parse_audio( + self, audio_url: Optional[str], uuid: Optional[str] = None + ) -> None: + audio = self._connector.fetch_audio(audio_url) if audio_url else None - placeholder = self._tracker.add("audio", audio) + placeholder = self._tracker.add("audio", audio, uuid) self._add_placeholder("audio", placeholder) - def parse_input_audio(self, input_audio: InputAudio) -> None: - audio_data = input_audio.get("data", "") - audio_format = input_audio.get("format", "") - audio_url = f"data:audio/{audio_format};base64,{audio_data}" + def parse_input_audio( + self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + ) -> None: + if input_audio: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + if audio_data: + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + else: + # If a UUID is provided, audio data may be empty. 
+ audio_url = None + else: + audio_url = None - return self.parse_audio(audio_url) + return self.parse_audio(audio_url, uuid) - def parse_video(self, video_url: str) -> None: - video = self._connector.fetch_video(video_url=video_url) + def parse_video( + self, video_url: Optional[str], uuid: Optional[str] = None + ) -> None: + video = ( + self._connector.fetch_video(video_url=video_url) + if video_url + else None + ) - placeholder = self._tracker.add("video", video) + placeholder = self._tracker.add("video", video, uuid) self._add_placeholder("video", placeholder) @@ -790,16 +888,24 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): allowed_local_media_path=tracker.allowed_local_media_path, ) - def parse_image(self, image_url: str) -> None: - image_coro = self._connector.fetch_image_async(image_url) + def parse_image( + self, image_url: Optional[str], uuid: Optional[str] = None + ) -> None: + image_coro = ( + self._connector.fetch_image_async(image_url) if image_url else None + ) - placeholder = self._tracker.add("image", image_coro) + placeholder = self._tracker.add("image", image_coro, uuid) self._add_placeholder("image", placeholder) def parse_image_embeds( - self, image_embeds: Union[str, dict[str, str]] + self, + image_embeds: Union[str, dict[str, str], None], + uuid: Optional[str] = None, ) -> None: - future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future() + future: asyncio.Future[Union[str, dict[str, str], None]] = ( + asyncio.Future() + ) if isinstance(image_embeds, dict): embeds = { @@ -812,33 +918,60 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): embedding = self._connector.fetch_image_embedding(image_embeds) future.set_result(embedding) - placeholder = self._tracker.add("image_embeds", future) + if image_embeds is None: + future.set_result(None) + + placeholder = self._tracker.add("image_embeds", future, uuid) self._add_placeholder("image", placeholder) - def parse_image_pil(self, image_pil: Image.Image) -> None: - future: asyncio.Future[Image.Image] = asyncio.Future() - future.set_result(image_pil) + def parse_image_pil( + self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + ) -> None: + future: asyncio.Future[Optional[Image.Image]] = asyncio.Future() + if image_pil: + future.set_result(image_pil) + else: + future.set_result(None) - placeholder = self._tracker.add("image", future) + placeholder = self._tracker.add("image", future, uuid) self._add_placeholder("image", placeholder) - def parse_audio(self, audio_url: str) -> None: - audio_coro = self._connector.fetch_audio_async(audio_url) + def parse_audio( + self, audio_url: Optional[str], uuid: Optional[str] = None + ) -> None: + audio_coro = ( + self._connector.fetch_audio_async(audio_url) if audio_url else None + ) - placeholder = self._tracker.add("audio", audio_coro) + placeholder = self._tracker.add("audio", audio_coro, uuid) self._add_placeholder("audio", placeholder) - def parse_input_audio(self, input_audio: InputAudio) -> None: - audio_data = input_audio.get("data", "") - audio_format = input_audio.get("format", "") - audio_url = f"data:audio/{audio_format};base64,{audio_data}" + def parse_input_audio( + self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + ) -> None: + if input_audio: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + if audio_data: + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + else: + # If a UUID is provided, audio data may be empty. 
+ audio_url = None + else: + audio_url = None - return self.parse_audio(audio_url) + return self.parse_audio(audio_url, uuid) - def parse_video(self, video_url: str) -> None: - video = self._connector.fetch_video_async(video_url=video_url) + def parse_video( + self, video_url: Optional[str], uuid: Optional[str] = None + ) -> None: + video = ( + self._connector.fetch_video_async(video_url=video_url) + if video_url + else None + ) - placeholder = self._tracker.add("video", video) + placeholder = self._tracker.add("video", video, uuid) self._add_placeholder("video", placeholder) @@ -1049,8 +1182,9 @@ def _parse_chat_message_content_mm_part( part, dict ) # This is needed to avoid mypy errors: part.get() from str part_type = part.get("type", None) + uuid = part.get("uuid", None) - if isinstance(part_type, str) and part_type in MM_PARSER_MAP: + if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501 content = MM_PARSER_MAP[part_type](part) # Special case for 'image_url.detail' @@ -1065,25 +1199,54 @@ def _parse_chat_message_content_mm_part( # Handle missing 'type' but provided direct URL fields. # 'type' is required field by pydantic - if part_type is None: - if part.get("image_url") is not None: + if part_type is None or uuid is not None: + if "image_url" in part: image_params = cast( CustomChatCompletionContentSimpleImageParam, part ) - return "image_url", image_params.get("image_url", "") - if part.get("audio_url") is not None: + image_url = image_params.get("image_url", None) + if isinstance(image_url, dict): + # Can potentially happen if user provides a uuid + # with url as a dict of {"url": url} + image_url = image_url.get("url", None) + return "image_url", image_url + if "image_pil" in part: + # "image_pil" could be None if UUID is provided. + image_params = cast( # type: ignore + CustomChatCompletionContentPILImageParam, part + ) + image_pil = image_params.get("image_pil", None) + return "image_pil", image_pil + if "image_embeds" in part: + # "image_embeds" could be None if UUID is provided. + image_params = cast( # type: ignore + ChatCompletionContentPartImageEmbedsParam, part + ) + image_embeds = image_params.get("image_embeds", None) + return "image_embeds", image_embeds + if "audio_url" in part: audio_params = cast( CustomChatCompletionContentSimpleAudioParam, part ) - return "audio_url", audio_params.get("audio_url", "") + audio_url = audio_params.get("audio_url", None) + if isinstance(audio_url, dict): + # Can potentially happen if user provides a uuid + # with url as a dict of {"url": url} + audio_url = audio_url.get("url", None) + return "audio_url", audio_url if part.get("input_audio") is not None: input_audio_params = cast(dict[str, str], part) return "input_audio", input_audio_params - if part.get("video_url") is not None: + if "video_url" in part: video_params = cast( CustomChatCompletionContentSimpleVideoParam, part ) - return "video_url", video_params.get("video_url", "") + video_url = video_params.get("video_url", None) + if isinstance(video_url, dict): + # Can potentially happen if user provides a uuid + # with url as a dict of {"url": url} + video_url = video_url.get("url", None) + return "video_url", video_url # Raise an error if no 'type' or direct URL is found. 
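The parser changes above let a multimodal content part carry a user-provided `uuid`, with the media payload itself optionally omitted. A sketch of what such parts could look like on the wire (the URL and UUID values are illustrative):

```python
# A part that ships both the media and a caller-chosen UUID.
image_part = {
    "type": "image_url",
    "image_url": {"url": "https://example.com/cat.png"},
    "uuid": "img-1234",
}

# With a UUID present, the payload may be omitted (content parses to
# None); the tracker records the UUID alongside the empty item, which
# the validation added in llm.py below requires.
embeds_part = {
    "type": "image_embeds",
    "image_embeds": None,
    "uuid": "img-1234",
}
```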
raise ValueError("Missing 'type' field in multimodal part.") @@ -1092,15 +1255,9 @@ def _parse_chat_message_content_mm_part( return part_type, "unknown part_type content" -VALID_MESSAGE_CONTENT_MM_PART_TYPES = ( +PART_TYPES_TO_SKIP_NONE_CONTENT = ( "text", "refusal", - "image_url", - "image_embeds", - "image_pil", - "audio_url", - "input_audio", - "video_url", ) @@ -1161,7 +1318,7 @@ def _parse_chat_message_content_part( part_type, content = _parse_chat_message_content_mm_part(part) # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but # content is None, log a warning and skip - if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None: + if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None: logger.warning( "Skipping multimodal part '%s' (type: '%s') " "with empty / unparsable content.", @@ -1177,30 +1334,42 @@ def _parse_chat_message_content_part( else: return str_content + # For media items, if a user has provided one, use it. Otherwise, insert + # a placeholder empty uuid. + uuid = part.get("uuid", None) + if uuid is not None: + uuid = str(uuid) + modality = None if part_type == "image_pil": - image_content = cast(Image.Image, content) - mm_parser.parse_image_pil(image_content) + if content is not None: + image_content = cast(Image.Image, content) + else: + image_content = None + mm_parser.parse_image_pil(image_content, uuid) modality = "image" elif part_type in ("image_url", "input_image"): str_content = cast(str, content) - mm_parser.parse_image(str_content) + mm_parser.parse_image(str_content, uuid) modality = "image" elif part_type == "image_embeds": - content = cast(Union[str, dict[str, str]], content) - mm_parser.parse_image_embeds(content) + if content is not None: + content = cast(Union[str, dict[str, str]], content) + else: + content = None + mm_parser.parse_image_embeds(content, uuid) modality = "image" elif part_type == "audio_url": str_content = cast(str, content) - mm_parser.parse_audio(str_content) + mm_parser.parse_audio(str_content, uuid) modality = "audio" elif part_type == "input_audio": dict_content = cast(InputAudio, content) - mm_parser.parse_input_audio(dict_content) + mm_parser.parse_input_audio(dict_content, uuid) modality = "audio" elif part_type == "video_url": str_content = cast(str, content) - mm_parser.parse_video(str_content) + mm_parser.parse_video(str_content, uuid) modality = "video" else: raise NotImplementedError(f"Unknown part type: {part_type}") @@ -1288,7 +1457,11 @@ def parse_chat_messages( model_config: ModelConfig, tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, -) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]: +) -> tuple[ + list[ConversationMessage], + Optional[MultiModalDataDict], + Optional[MultiModalUUIDDict], +]: conversation: list[ConversationMessage] = [] mm_tracker = MultiModalItemTracker(model_config, tokenizer) @@ -1308,7 +1481,7 @@ def parse_chat_messages( _postprocess_messages(conversation) - return conversation, mm_tracker.all_mm_data() + return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() def parse_chat_messages_futures( @@ -1316,7 +1489,11 @@ def parse_chat_messages_futures( model_config: ModelConfig, tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, -) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: +) -> tuple[ + list[ConversationMessage], + Awaitable[Optional[MultiModalDataDict]], + Optional[MultiModalUUIDDict], +]: conversation: list[ConversationMessage] = [] 
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) @@ -1336,7 +1513,7 @@ def parse_chat_messages_futures( _postprocess_messages(conversation) - return conversation, mm_tracker.all_mm_data() + return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() def apply_hf_chat_template( diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index 7723c5d5cbcfc..9012639457cad 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import contextlib import json import logging from abc import ABC, abstractmethod @@ -57,9 +59,14 @@ class ConversationContext(ABC): @abstractmethod async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack) -> None: + exit_stack: AsyncExitStack, + request_id: str) -> None: pass + @abstractmethod + async def cleanup_session(self) -> None: + raise NotImplementedError("Should not be called.") + class SimpleContext(ConversationContext): @@ -89,9 +96,13 @@ class SimpleContext(ConversationContext): raise NotImplementedError("Should not be called.") async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack) -> None: + exit_stack: AsyncExitStack, + request_id: str) -> None: pass + async def cleanup_session(self) -> None: + raise NotImplementedError("Should not be called.") + class HarmonyContext(ConversationContext): @@ -103,6 +114,7 @@ class HarmonyContext(ConversationContext): self._messages = messages self.available_tools = available_tools self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {} + self.called_tools: set[str] = set() self.parser = get_streamable_parser_for_assistant() self.num_init_messages = len(messages) @@ -234,7 +246,8 @@ class HarmonyContext(ConversationContext): last_msg = self.messages[-1] recipient = last_msg.recipient return recipient is not None and (recipient.startswith("browser.") - or recipient.startswith("python")) + or recipient.startswith("python") or + recipient.startswith("container.")) async def call_tool(self) -> list[Message]: if not self.messages: @@ -248,6 +261,9 @@ class HarmonyContext(ConversationContext): elif recipient.startswith("python"): return await self.call_python_tool( self._tool_sessions["python"], last_msg) + elif recipient.startswith("container."): + return await self.call_container_tool( + self._tool_sessions["container"], last_msg) raise ValueError("No tool call found") def render_for_completion(self) -> list[int]: @@ -256,6 +272,7 @@ class HarmonyContext(ConversationContext): async def call_search_tool(self, tool_session: Union["ClientSession", Tool], last_msg: Message) -> list[Message]: + self.called_tools.add("browser") if isinstance(tool_session, Tool): return await tool_session.get_result(self) tool_name = last_msg.recipient.split(".")[1] @@ -265,12 +282,16 @@ class HarmonyContext(ConversationContext): content = TextContent(text=result_str) author = Author(role=Role.TOOL, name=last_msg.recipient) return [ - Message(author=author, content=[content], recipient=Role.ASSISTANT) + Message(author=author, + content=[content], + recipient=Role.ASSISTANT, + channel=last_msg.channel) ] async def call_python_tool(self, tool_session: Union["ClientSession", Tool], last_msg: Message) -> list[Message]: + self.called_tools.add("python") if isinstance(tool_session, Tool): return await tool_session.get_result(self) param = { @@ -290,13 +311,63 @@ 
class HarmonyContext(ConversationContext): ] async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack) -> None: + exit_stack: AsyncExitStack, + request_id: str) -> None: if tool_server: for tool_name in self.available_tools: if tool_name not in self._tool_sessions: - self._tool_sessions[ - tool_name] = await exit_stack.enter_async_context( - tool_server.new_session(tool_name)) + tool_session = await exit_stack.enter_async_context( + tool_server.new_session(tool_name, request_id)) + self._tool_sessions[tool_name] = tool_session + exit_stack.push_async_exit(self.cleanup_session) + + async def call_container_tool(self, tool_session: Union["ClientSession", + Tool], + last_msg: Message) -> list[Message]: + """ + Call the container tool. Expect this to be run in a stateful Docker + container with a command-line terminal. + The official container tool would at least + expect the following format: + - for tool name: exec + - args: + { + "cmd":List[str] "command to execute", + "workdir":optional[str] "current working directory", + "env":optional[object/dict] "environment variables", + "session_name":optional[str] "session name", + "timeout":optional[int] "timeout in seconds", + "user":optional[str] "user name", + } + """ + self.called_tools.add("container") + if isinstance(tool_session, Tool): + return await tool_session.get_result(self) + tool_name = last_msg.recipient.split(".")[1].split(" ")[0] + args = json.loads(last_msg.content[0].text) + result = await tool_session.call_tool(tool_name, args) + result_str = result.content[0].text + content = TextContent(text=result_str) + author = Author(role=Role.TOOL, name=last_msg.recipient) + return [ + Message(author=author, + content=[content], + recipient=Role.ASSISTANT, + channel=last_msg.channel) + ] + + async def cleanup_session(self, *args, **kwargs) -> None: + """Coroutine that can be used as an __aexit__ callback.""" + + async def cleanup_tool_session(tool_session): + if not isinstance(tool_session, Tool): + logger.info("Cleaning up tool session for %s", + tool_session._client_info) + with contextlib.suppress(Exception): + await tool_session.call_tool("cleanup_session", {}) + + await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool]) + for tool in self.called_tools)) class StreamingHarmonyContext(HarmonyContext): diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index d1ff06425fcb3..f7528ba81dce5 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -16,11 +16,13 @@ from openai.types.responses.response_function_web_search import ( from openai.types.responses.response_reasoning_item import ( Content as ResponseReasoningTextContent) from openai.types.responses.tool import Tool -from openai_harmony import (Author, Conversation, DeveloperContent, - HarmonyEncodingName, Message, ReasoningEffort, - Role, StreamableParser, SystemContent, TextContent, - ToolDescription, load_harmony_encoding) +from openai_harmony import (Author, ChannelConfig, Conversation, + DeveloperContent, HarmonyEncodingName, Message, + ReasoningEffort, Role, StreamableParser, + SystemContent, TextContent, ToolDescription, + load_harmony_encoding) +from vllm import envs from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam, ResponseInputOutputItem) from vllm.utils import random_uuid @@ -33,6 +35,20 @@ REASONING_EFFORT = { _harmony_encoding = None +# Builtin tools that should be included in the system message when +# they are available and requested by the user. 
+# Tool args are provided by MCP tool descriptions. Output +# of the tools is stringified. +BUILTIN_TOOLS = { + "web_search_preview", + "code_interpreter", + "container", +} + + +def has_custom_tools(tool_types: list[str]) -> bool: + return not set(tool_types).issubset(BUILTIN_TOOLS) + def get_encoding(): global _harmony_encoding @@ -48,10 +64,19 @@ def get_system_message( start_date: Optional[str] = None, browser_description: Optional[str] = None, python_description: Optional[str] = None, + container_description: Optional[str] = None, + instructions: Optional[str] = None, + with_custom_tools: bool = False, ) -> Message: sys_msg_content = SystemContent.new() if model_identity is not None: sys_msg_content = sys_msg_content.with_model_identity(model_identity) + if (instructions is not None + and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS): + current_identity = sys_msg_content.model_identity + new_identity = (f'{current_identity}\n{instructions}' + if current_identity else instructions) + sys_msg_content = sys_msg_content.with_model_identity(new_identity) if reasoning_effort is not None: sys_msg_content = sys_msg_content.with_reasoning_effort( REASONING_EFFORT[reasoning_effort]) @@ -63,6 +88,14 @@ def get_system_message( sys_msg_content = sys_msg_content.with_tools(browser_description) if python_description is not None: sys_msg_content = sys_msg_content.with_tools(python_description) + if container_description is not None: + sys_msg_content = sys_msg_content.with_tools(container_description) + if not with_custom_tools: + channel_config = sys_msg_content.channel_config + invalid_channel = "commentary" + new_config = ChannelConfig.require_channels( + [c for c in channel_config.valid_channels if c != invalid_channel]) + sys_msg_content = sys_msg_content.with_channel_config(new_config) sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) return sys_msg @@ -86,14 +119,17 @@ def get_developer_message( tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None, ) -> Message: dev_msg_content = DeveloperContent.new() - if instructions is not None: + if (instructions is not None + and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS): dev_msg_content = dev_msg_content.with_instructions(instructions) if tools is not None: function_tools: list[Union[Tool, ChatCompletionToolsParam]] = [] for tool in tools: - if tool.type in ("web_search_preview", "code_interpreter"): + if tool.type in ("web_search_preview", "code_interpreter", + "container"): # These are built-in tools that are added to the system message. 
+ pass + elif tool.type == "function": function_tools.append(tool) else: @@ -136,6 +172,8 @@ def parse_response_input( TextContent(text=text_prefix + c["text"]) for c in content ] msg = Message.from_role_and_contents(role, contents) + if role == "assistant": + msg = msg.with_channel("final") elif response_msg["type"] == "function_call_output": call_id = response_msg["call_id"] call_response: Optional[ResponseFunctionToolCall] = None @@ -273,7 +311,7 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: call_id=f"call_{random_id}", type="function_call", name=function_name, - id=f"ft_{random_id}", + id=f"fc_{random_id}", ) output_items.append(response_item) elif recipient is not None and (recipient.startswith("python") diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 4e852ba594930..887e277109240 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -95,7 +95,7 @@ async def serve_http(app: FastAPI, port = uvicorn_kwargs["port"] process = find_process_using_port(port) if process is not None: - logger.debug( + logger.warning( "port %s is used by process %s launched with command:\n%s", port, process, " ".join(process.cmdline())) logger.info("Shutting down FastAPI HTTP server.") diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9b2ad808eb03e..4b51dbcd8acb9 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -110,6 +110,14 @@ class LLM: values will increase the KV cache size and thus improve the model's throughput. However, if the value is too high, it may cause out-of- memory (OOM) errors. + kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By default, + this is set to None and vLLM automatically infers the KV cache + size based on gpu_memory_utilization. However, users may want to + specify the KV cache memory size manually. kv_cache_memory_bytes + allows more fine-grained control of how much memory gets used + compared with gpu_memory_utilization. Note that + kv_cache_memory_bytes (when not None) takes precedence over + gpu_memory_utilization. swap_space: The size (GiB) of CPU memory per GPU to use as swap space. This can be used for temporarily storing the states of the requests when their `best_of` sampling parameters are larger than 1. 
If all @@ -184,6 +192,7 @@ class LLM: hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, override_pooler_config: Optional[PoolerConfig] = None, + kv_cache_memory_bytes: Optional[int] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, logits_processors: Optional[list[Union[str, @@ -204,7 +213,7 @@ class LLM: if "kv_transfer_config" in kwargs and isinstance( kwargs["kv_transfer_config"], dict): - from vllm.config import KVTransferConfig + from vllm.config.kv_transfer import KVTransferConfig raw_config_dict = kwargs["kv_transfer_config"] try: kwargs["kv_transfer_config"] = KVTransferConfig( @@ -251,6 +260,7 @@ class LLM: tokenizer_revision=tokenizer_revision, seed=seed, gpu_memory_utilization=gpu_memory_utilization, + kv_cache_memory_bytes=kv_cache_memory_bytes, swap_space=swap_space, cpu_offload_gb=cpu_offload_gb, enforce_eager=enforce_eager, @@ -693,6 +703,106 @@ class LLM: return outputs + def preprocess_chat( + self, + messages: Union[list[ChatCompletionMessageParam], + list[list[ChatCompletionMessageParam]]], + lora_request: Optional[LoRARequest] = None, + chat_template: Optional[str] = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tools: Optional[list[dict[str, Any]]] = None, + chat_template_kwargs: Optional[dict[str, Any]] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, + ) -> list[TokensPrompt]: + """ + Generate prompts for a chat conversation. The pre-processed + prompts can then be used as input for the other LLM methods. + + Refer to `chat` for a complete description of the arguments. + Returns: + A list of `TokensPrompt` objects containing the tokenized + prompts after chat template interpolation, and the + pre-processed multi-modal inputs. + """ + list_of_messages: list[list[ChatCompletionMessageParam]] + + # Handle multi and single conversations + if is_list_of(messages, list): + # messages is list[list[...]] + list_of_messages = cast(list[list[ChatCompletionMessageParam]], + messages) + else: + # messages is list[...] + list_of_messages = [ + cast(list[ChatCompletionMessageParam], messages) + ] + + tokenizer = self.get_tokenizer(lora_request) + model_config = self.llm_engine.get_model_config() + resolved_content_format = resolve_chat_template_content_format( + chat_template, + tools, + chat_template_content_format, + tokenizer, + model_config=model_config, + ) + + _chat_template_kwargs: dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tools, + ) + _chat_template_kwargs.update(chat_template_kwargs or {}) + + prompts: list[TokensPrompt] = [] + + for msgs in list_of_messages: + # NOTE: _parse_chat_message_content_parts() currently doesn't + # handle mm_processor_kwargs, since there is no implementation in + # the chat message parsing for it. 
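A minimal usage sketch for the new constructor argument (the model choice and the byte budget are illustrative): pinning the KV cache to an explicit size rather than deriving it from `gpu_memory_utilization`:

```python
from vllm import LLM

# Give the KV cache an explicit 2 GiB per-GPU budget; per the docstring
# above, this takes precedence over gpu_memory_utilization.
llm = LLM(
    model="facebook/opt-125m",
    kv_cache_memory_bytes=2 * 1024**3,
)
```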
+ conversation, mm_data, mm_uuids = parse_chat_messages( + msgs, + model_config, + tokenizer, + content_format=resolved_content_format, + ) + + if isinstance(tokenizer, MistralTokenizer): + prompt_token_ids = apply_mistral_chat_template( + tokenizer, + messages=msgs, + **_chat_template_kwargs, + ) + else: + prompt_str = apply_hf_chat_template( + tokenizer=tokenizer, + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) + # Special tokens are already included in chat templates so + # should not be added by the tokenizer in this case. + prompt_token_ids = tokenizer.encode(prompt_str, + add_special_tokens=False) + + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) + + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + if mm_processor_kwargs is not None: + prompt["mm_processor_kwargs"] = mm_processor_kwargs + + prompts.append(prompt) + + return prompts + def chat( self, messages: Union[list[ChatCompletionMessageParam], @@ -759,77 +869,18 @@ class LLM: A list of `RequestOutput` objects containing the generated responses in the same order as the input messages. """ - list_of_messages: list[list[ChatCompletionMessageParam]] - # Handle multi and single conversations - if is_list_of(messages, list): - # messages is list[list[...]] - list_of_messages = cast(list[list[ChatCompletionMessageParam]], - messages) - else: - # messages is list[...] - list_of_messages = [ - cast(list[ChatCompletionMessageParam], messages) - ] - - tokenizer = self.get_tokenizer(lora_request) - model_config = self.llm_engine.get_model_config() - resolved_content_format = resolve_chat_template_content_format( - chat_template, - tools, - chat_template_content_format, - tokenizer, - model_config=model_config, - ) - - _chat_template_kwargs: dict[str, Any] = dict( + prompts = self.preprocess_chat( + messages=messages, + lora_request=lora_request, chat_template=chat_template, + chat_template_content_format=chat_template_content_format, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, tools=tools, + chat_template_kwargs=chat_template_kwargs, + mm_processor_kwargs=mm_processor_kwargs, ) - _chat_template_kwargs.update(chat_template_kwargs or {}) - - prompts: list[Union[TokensPrompt, TextPrompt]] = [] - - for msgs in list_of_messages: - # NOTE: _parse_chat_message_content_parts() currently doesn't - # handle mm_processor_kwargs, since there is no implementation in - # the chat message parsing for it. - conversation, mm_data = parse_chat_messages( - msgs, - model_config, - tokenizer, - content_format=resolved_content_format, - ) - - if isinstance(tokenizer, MistralTokenizer): - prompt_token_ids = apply_mistral_chat_template( - tokenizer, - messages=msgs, - **_chat_template_kwargs, - ) - else: - prompt_str = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - model_config=model_config, - **_chat_template_kwargs, - ) - # Special tokens are already included in chat templates so - # should not be added by the tokenizer in this case. 
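With `chat` now a thin wrapper over `preprocess_chat` plus `generate`, the pre-processing can be reused across calls. A sketch, assuming `llm` is an `LLM` instance as constructed above:

```python
from vllm import SamplingParams

# Tokenize and template the conversation once...
prompts = llm.preprocess_chat(
    messages=[{"role": "user", "content": "Hello!"}],
)

# ...then generate from the same TokensPrompt list with different
# sampling settings.
for params in (SamplingParams(temperature=0.0),
               SamplingParams(temperature=0.8)):
    outputs = llm.generate(prompts, params)
```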
- prompt_token_ids = tokenizer.encode(prompt_str, - add_special_tokens=False) - - prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) - - if mm_data is not None: - prompt["multi_modal_data"] = mm_data - - if mm_processor_kwargs is not None: - prompt["mm_processor_kwargs"] = mm_processor_kwargs - - prompts.append(prompt) return self.generate( prompts, @@ -1440,6 +1491,11 @@ class LLM: for i, prompt in enumerate(it): + if isinstance(prompt, dict): + self._validate_mm_data_and_uuids( + prompt.get("multi_modal_data"), + prompt.get("multi_modal_uuids")) + param = params[i] if isinstance(params, Sequence) else params tokenization_kwargs: dict[str, Any] = {} @@ -1456,6 +1512,41 @@ class LLM: priority=priority[i] if priority else 0, ) + def _validate_mm_data_and_uuids( + self, + multi_modal_data: Optional[Any], # MultiModalDataDict + multi_modal_uuids: Optional[Any], # MultiModalUUIDDict + ): + """ + Validate that if any multi-modal data is skipped (i.e. None), + then its corresponding UUID must be set. + """ + if multi_modal_data is None: + return + + for modality, data in multi_modal_data.items(): + if isinstance(data, list): + for i, d in enumerate(data): + if d is None: + if multi_modal_uuids is None or modality not in multi_modal_uuids or multi_modal_uuids[ # noqa: E501 + modality] is None: + raise ValueError( + f"Multi-modal data for {modality} is None " + f"but UUID is not provided") + else: + if len( + multi_modal_uuids[modality] + ) <= i or multi_modal_uuids[modality][i] is None: + raise ValueError( + f"Multi-modal data for {modality} is None " + f"but UUID is not provided") + else: + if data is None and (multi_modal_uuids is None + or modality not in multi_modal_uuids + or multi_modal_uuids[modality] is None): + raise ValueError(f"Multi-modal data for {modality} is None" + f" but UUID is not provided") + def _add_request( self, prompt: PromptType, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b6667ebf152e1..c159bcee315f2 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1717,6 +1717,8 @@ async def init_app_state( if args.tool_server == "demo": tool_server: Optional[ToolServer] = DemoToolServer() + assert isinstance(tool_server, DemoToolServer) + await tool_server.init_and_validate() elif args.tool_server: tool_server = MCPToolServer() await tool_server.add_tool_server(args.tool_server) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index d0b5d013eb9e5..1c2a6f58197d8 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -134,14 +134,13 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" """If specified, will run the OpenAI frontend server in the same process as the model serving engine.""" enable_request_id_headers: bool = False - """If specified, API server will add X-Request-Id header to responses. - Caution: this hurts performance at high QPS.""" + """If specified, API server will add X-Request-Id header to responses.""" enable_auto_tool_choice: bool = False - """If specified, exclude tool definitions in prompts when - tool_choice='none'.""" - exclude_tools_when_tool_choice_none: bool = False """Enable auto tool choice for supported models. 
Use `--tool-call-parser` to specify which parser to use.""" + exclude_tools_when_tool_choice_none: bool = False + """If specified, exclude tool definitions in prompts when + tool_choice='none'.""" tool_call_parser: Optional[str] = None """Select the tool call parser depending on the model that you're using. This is used to parse the model-generated tool call into OpenAI API format. @@ -172,8 +171,8 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" """Enable the /get_tokenizer_info endpoint. May expose chat templates and other tokenizer configuration.""" enable_log_outputs: bool = False - """If set to True, enable logging of model outputs (generations) - in addition to the input logging that is enabled by default.""" + """If True, log model outputs (generations). + Requires --enable-log-requests.""" h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT """Maximum size (bytes) of an incomplete HTTP event (header or body) for h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB).""" @@ -204,7 +203,7 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" frontend_kwargs["lora_modules"]["type"] = optional_type(str) frontend_kwargs["lora_modules"]["action"] = LoRAParserAction - # Special case: Middleware needs append action + # Special case: Middleware needs the "append" action frontend_kwargs["middleware"]["action"] = "append" frontend_kwargs["middleware"]["type"] = str if "nargs" in frontend_kwargs["middleware"]: @@ -274,6 +273,9 @@ def validate_parsed_serve_args(args: argparse.Namespace): if args.enable_auto_tool_choice and not args.tool_call_parser: raise TypeError("Error: --enable-auto-tool-choice requires " "--tool-call-parser") + if args.enable_log_outputs and not args.enable_log_requests: + raise TypeError("Error: --enable-log-outputs requires " + "--enable-log-requests") def create_parser_for_docs() -> FlexibleArgumentParser: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index c56c68cf76442..4dcb1f3f1c89f 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -242,7 +242,7 @@ def get_logits_processors(processors: Optional[LogitsProcessors], elif processors: raise ValueError( "The `logits_processors` argument is not supported by this " - "server. See --logits-processor-pattern engine argugment " + "server. See --logits-processor-pattern engine argument " "for more information.") return None @@ -1270,9 +1270,20 @@ class CompletionRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod def validate_prompt_and_prompt_embeds(cls, data): - if data.get("prompt") is None and data.get("prompt_embeds") is None: + prompt = data.get("prompt") + prompt_embeds = data.get("prompt_embeds") + + prompt_is_empty = (prompt is None + or (isinstance(prompt, str) and prompt == "")) + embeds_is_empty = (prompt_embeds is None + or (isinstance(prompt_embeds, list) + and len(prompt_embeds) == 0)) + + if prompt_is_empty and embeds_is_empty: raise ValueError( - "At least one of `prompt` or `prompt_embeds` must be set.") + "Either prompt or prompt_embeds must be provided and non-empty." 
+ ) + return data @model_validator(mode="before") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5c7adc53f49b2..579f6f537ee2d 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -186,7 +186,7 @@ class OpenAIServingChat(OpenAIServing): lora_request = self._maybe_get_adapters( request, supports_default_mm_loras=True) - model_name = self._get_model_name(request.model, lora_request) + model_name = self.models.model_name(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request) diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 98b7a206fa0cb..7e88424c169ce 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -20,6 +20,7 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext, OpenAIServing, ServeContext) from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.logger import init_logger from vllm.outputs import ClassificationOutput, PoolingRequestOutput from vllm.pooling_params import PoolingParams @@ -57,8 +58,7 @@ class ClassificationMixin(OpenAIServing): renderer = self._get_renderer(ctx.tokenizer) ctx.engine_prompts = await renderer.render_prompt( prompt_or_prompts=ctx.request.input, - max_length=self.max_model_len, - truncate_prompt_tokens=ctx.request.truncate_prompt_tokens) + config=self._build_render_config(ctx.request)) return None @@ -114,6 +114,12 @@ class ClassificationMixin(OpenAIServing): usage=usage, ) + def _build_render_config(self, + request: ClassificationRequest) -> RenderConfig: + return RenderConfig( + max_length=self.max_model_len, + truncate_prompt_tokens=request.truncate_prompt_tokens) + class ServingClassification(ClassificationMixin): request_id_prefix = "classify" @@ -140,7 +146,7 @@ class ServingClassification(ClassificationMixin): request: ClassificationRequest, raw_request: Request, ) -> Union[ClassificationResponse, ErrorResponse]: - model_name = self._get_model_name(request.model) + model_name = self.models.model_name() request_id = (f"{self.request_id_prefix}-" f"{self._base_request_id(raw_request)}") diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index b26140d4b9d7a..c2de449a96994 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -26,14 +26,11 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs, PromptTokenUsageInfo, RequestResponseMetadata, UsageInfo) -from vllm.entrypoints.openai.serving_engine import ( - EmbedsPrompt as ServingEngineEmbedsPrompt) from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - TextTokensPrompt, - clamp_prompt_logprobs, - is_text_tokens_prompt) + clamp_prompt_logprobs) # yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import get_max_tokens from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, is_tokens_prompt) @@ -132,12 +129,12 @@ class OpenAIServingCompletion(OpenAIServing): else: tokenizer = await self.engine_client.get_tokenizer(lora_request ) + renderer = self._get_renderer(tokenizer) - request_prompts, engine_prompts = await self._preprocess_completion( - request, - tokenizer, - 
request.prompt, - add_special_tokens=request.add_special_tokens, + engine_prompts = await renderer.render_prompt_and_embeds( + prompt_or_prompts=request.prompt, + prompt_embeds=request.prompt_embeds, + config=self._build_render_config(request), ) except ValueError as e: logger.exception("Error in preprocessing prompt inputs") @@ -198,7 +195,7 @@ class OpenAIServingCompletion(OpenAIServing): self._log_inputs( request_id_item, - request_prompts[i], + engine_prompt, params=sampling_params, lora_request=lora_request, ) @@ -235,7 +232,7 @@ class OpenAIServingCompletion(OpenAIServing): result_generator = merge_async_iterators(*generators) - model_name = self._get_model_name(request.model, lora_request) + model_name = self.models.model_name(lora_request) num_prompts = len(engine_prompts) # Similar to the OpenAI API, when n != best_of, we do not stream the @@ -249,7 +246,7 @@ class OpenAIServingCompletion(OpenAIServing): if stream: return self.completion_stream_generator( request, - request_prompts, + engine_prompts, result_generator, request_id, created_time, @@ -273,11 +270,9 @@ class OpenAIServingCompletion(OpenAIServing): # We did not pass it into vLLM engine to avoid being redundant # with the inputs token IDs if final_res.prompt is None: - request_prompt = request_prompts[i] - if is_text_tokens_prompt(request_prompt): - final_res.prompt = request_prompt["prompt"] - else: - final_res.prompt = None + engine_prompt = engine_prompts[i] + final_res.prompt = None if is_embeds_prompt( + engine_prompt) else engine_prompt.get("prompt") final_res_batch_checked = cast(list[RequestOutput], final_res_batch) @@ -313,8 +308,7 @@ class OpenAIServingCompletion(OpenAIServing): async def completion_stream_generator( self, request: CompletionRequest, - request_prompts: list[Union[TextTokensPrompt, - ServingEngineEmbedsPrompt]], + engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]], result_generator: AsyncIterator[tuple[int, RequestOutput]], request_id: str, created_time: int, @@ -350,14 +344,11 @@ class OpenAIServingCompletion(OpenAIServing): num_cached_tokens = res.num_cached_tokens first_iteration = False - if res.prompt is not None: - prompt_text = res.prompt - else: - request_prompt = request_prompts[prompt_idx] - if is_text_tokens_prompt(request_prompt): - prompt_text = request_prompt["prompt"] - else: - prompt_text = None + prompt_text = res.prompt + if prompt_text is None: + engine_prompt = engine_prompts[prompt_idx] + prompt_text = None if is_embeds_prompt( + engine_prompt) else engine_prompt.get("prompt") # Prompt details are excluded from later streamed outputs if prompt_token_ids is not None: @@ -378,6 +369,8 @@ class OpenAIServingCompletion(OpenAIServing): assert request.max_tokens is not None if request.echo and not has_echoed[i]: assert prompt_token_ids is not None + if request.return_token_ids: + prompt_text = "" assert prompt_text is not None if request.max_tokens == 0: # only return the prompt @@ -525,6 +518,8 @@ class OpenAIServingCompletion(OpenAIServing): for output in final_res.outputs: assert request.max_tokens is not None if request.echo: + if request.return_token_ids: + prompt_text = "" assert prompt_text is not None if request.max_tokens == 0: token_ids = prompt_token_ids @@ -676,3 +671,18 @@ class OpenAIServingCompletion(OpenAIServing): tokens=out_tokens, top_logprobs=out_top_logprobs, ) + + def _build_render_config( + self, + request: CompletionRequest, + max_input_length: Optional[int] = None, + ) -> RenderConfig: + max_input_tokens_len = self.max_model_len - (request.max_tokens 
or 0) + return RenderConfig( + max_length=max_input_tokens_len, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + cache_salt=request.cache_salt, + needs_detokenization=bool(request.echo + and not request.return_token_ids), + ) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index c375f9e7c5064..c0d1fe4b6e168 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -28,7 +28,7 @@ from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, TextTokensPrompt) # yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.entrypoints.renderer import RenderConfig from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, @@ -98,23 +98,28 @@ class EmbeddingMixin(OpenAIServing): add_special_tokens=ctx.request.add_special_tokens, ) else: - # Set max_length based on chunked processing capability - if self._should_use_chunked_processing(ctx.request): - max_length = None - else: - max_length = self.max_embed_len or self.max_model_len - ctx.engine_prompts = await renderer.render_prompt( prompt_or_prompts=ctx.request.input, - max_length=max_length, - truncate_prompt_tokens=ctx.request.truncate_prompt_tokens, - add_special_tokens=ctx.request.add_special_tokens, + config=self._build_render_config(ctx.request), ) return None except (ValueError, TypeError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) + def _build_render_config( + self, request: EmbeddingCompletionRequest) -> RenderConfig: + # Set max_length based on chunked processing capability + if self._should_use_chunked_processing(request): + max_length = None + else: + max_length = self.max_embed_len or self.max_model_len + + return RenderConfig( + max_length=max_length, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens) + @override def _build_response( self, @@ -290,7 +295,7 @@ class EmbeddingMixin(OpenAIServing): async def _create_single_prompt_generator( self, ctx: EmbeddingServeContext, - engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt: EngineTokensPrompt, pooling_params: PoolingParams, trace_headers: Optional[Mapping[str, str]], prompt_index: int, @@ -303,12 +308,6 @@ class EmbeddingMixin(OpenAIServing): params=pooling_params, lora_request=ctx.lora_request) - # Mypy has an existing bug related to inferring the variance - # of TypedDicts with `builtins.enumerate`: - # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 - engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt], - engine_prompt) - # Return the original generator without wrapping return self.engine_client.encode( engine_prompt, @@ -375,12 +374,8 @@ class EmbeddingMixin(OpenAIServing): continue # Normal processing for short prompts or non-token prompts - # Cast engine_prompt to the expected type for mypy - engine_prompt_typed = cast( - Union[EngineTokensPrompt, EngineEmbedsPrompt], - engine_prompt) generator = await self._create_single_prompt_generator( - ctx, engine_prompt_typed, pooling_params, trace_headers, i) + ctx, engine_prompt, pooling_params, trace_headers, i) generators.append(generator) from vllm.utils import 
merge_async_iterators @@ -604,7 +599,7 @@ class OpenAIServingEmbedding(EmbeddingMixin): See https://platform.openai.com/docs/api-reference/embeddings/create for the API specification. This API mimics the OpenAI Embedding API. """ - model_name = self._get_model_name(request.model) + model_name = self.models.model_name() request_id = ( f"{self.request_id_prefix}-" f"{self._base_request_id(raw_request, request.request_id)}") diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 1a2236de4fa42..d391cc50ad232 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio -import io import json import sys import time @@ -9,10 +7,8 @@ import traceback from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from http import HTTPStatus -from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, - TypeVar, Union, cast, overload) +from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union -import pybase64 import torch from fastapi import Request from pydantic import BaseModel, ConfigDict, Field @@ -62,12 +58,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, TranslationRequest) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser -from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer +from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer, + RenderConfig) # yapf: enable -from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt from vllm.inputs.data import PromptType from vllm.inputs.data import TokensPrompt as EngineTokensPrompt -from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest @@ -149,8 +144,7 @@ class RequestProcessingMixin(BaseModel): """ request_prompts: Optional[Sequence[RequestPrompt]] = [] - engine_prompts: Optional[Union[list[EngineTokensPrompt], - list[EngineEmbedsPrompt]]] = [] + engine_prompts: Optional[list[EngineTokensPrompt]] = [] model_config = ConfigDict(arbitrary_types_allowed=True) @@ -255,6 +249,19 @@ class OpenAIServing: tokenizer=tokenizer, async_tokenizer_pool=self._async_tokenizer_pool) + def _build_render_config( + self, + request: Any, + ) -> RenderConfig: + """ + Build and return a `RenderConfig` for an endpoint. + + Used by the renderer to control how prompts are prepared + (e.g., tokenization and length handling). Endpoints should + implement this with logic appropriate to their request type. 
+ """ + raise NotImplementedError + def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer: """ Return (and cache) an `AsyncMicrobatchTokenizer` bound to the @@ -368,13 +375,6 @@ class OpenAIServing: for i, engine_prompt in enumerate(ctx.engine_prompts): request_id_item = f"{ctx.request_id}-{i}" - # Mypy has an existing bug related to inferring the variance of - # TypedDicts with `builtins.enumerate`: - # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 - engine_prompt = cast( - Union[EngineTokensPrompt, EngineEmbedsPrompt], - engine_prompt) - self._log_inputs( request_id_item, engine_prompt, @@ -737,170 +737,6 @@ class OpenAIServing: tokenizer=tokenizer, ) - async def _tokenize_prompt_input_or_inputs_async( - self, - request: AnyRequest, - tokenizer: Optional[AnyTokenizer], - input_or_inputs: Optional[Union[str, list[str], list[int], - list[list[int]]]], - add_special_tokens: bool = True, - ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]: - """ - Tokenize/detokenize depending on the input format. - - According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_ - , each input can be a string or array of tokens. Note that each request - can pass one or more inputs. - """ - inputs_embeds = list[EmbedsPrompt]() - inputs_text = list[TextTokensPrompt]() - - truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", - None) - - if (truncate_prompt_tokens or 0) < 0: - truncate_prompt_tokens = self.max_model_len - - if (isinstance(request, CompletionRequest) - and request.prompt_embeds is not None): - inputs_embeds.extend( - self._load_prompt_embeds(request.prompt_embeds, - truncate_prompt_tokens)) - - # Empty prompts are okay as long as there are prompt embeddings - if input_or_inputs is None or (inputs_embeds - and input_or_inputs == ""): - return [], inputs_embeds - - # Although our type checking is based on mypy, - # VSCode Pyright extension should still work properly - # "is False" is required for Pyright to perform type narrowing - # See: https://github.com/microsoft/pyright/issues/7672 - - # Parse and batch the input prompts - batch_inputs = parse_and_batch_prompt(input_or_inputs) - - # Process each input in the batch concurrently - tasks = [] - for prompt_input in batch_inputs: - if prompt_input["is_tokens"] is False: - assert tokenizer is not None, ( - "Tokenizer is required for text prompts") - task = self._normalize_prompt_text_to_input( - request, - prompt_input["content"], - tokenizer=tokenizer, - add_special_tokens=add_special_tokens, - ) - else: - task = self._normalize_prompt_tokens_to_input( - request, prompt_input["content"], tokenizer=tokenizer) - tasks.append(task) - - # Wait for all tokenization tasks to complete - results = await asyncio.gather(*tasks) - inputs_text.extend(results) - - return inputs_text, inputs_embeds - - @overload - async def _preprocess_completion( - self, - request: Union[ - DetokenizeRequest, - EmbeddingCompletionRequest, - RerankRequest, - ClassificationRequest, - ScoreRequest, - TokenizeCompletionRequest, - ], - tokenizer: Optional[AnyTokenizer], - input_or_inputs: Union[str, list[str], list[int], list[list[int]]], - add_special_tokens: bool = ..., - ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]: - ... 
- - @overload - async def _preprocess_completion( - self, - request: CompletionRequest, - tokenizer: Optional[AnyTokenizer], - input_or_inputs: Optional[Union[str, list[str], list[int], - list[list[int]]]], - add_special_tokens: bool = ..., - ) -> tuple[ - list[Union[TextTokensPrompt, EmbedsPrompt]], - list[Union[EngineTokensPrompt, EngineEmbedsPrompt]], - ]: - ... - - async def _preprocess_completion( - self, - request: CompletionLikeRequest, - tokenizer: Optional[AnyTokenizer], - input_or_inputs: Optional[Union[str, list[str], list[int], - list[list[int]]]], - add_special_tokens: bool = True, - ) -> tuple[ - Union[list[TextTokensPrompt], list[Union[TextTokensPrompt, - EmbedsPrompt]]], - Union[ - list[EngineTokensPrompt], - list[Union[EngineTokensPrompt, EngineEmbedsPrompt]], - ], - ]: - if (not isinstance(request, CompletionRequest) - and input_or_inputs is None): - raise ValueError( - "Prompt embeds with non-completion requests is not" - " currently supported.") - - ( - request_prompts_text, - request_prompts_embeds, - ) = await self._tokenize_prompt_input_or_inputs_async( - request, - tokenizer, - input_or_inputs, - add_special_tokens=add_special_tokens, - ) - - engine_prompts_text = [ - EngineTokensPrompt( - prompt_token_ids=request_prompt_text["prompt_token_ids"]) - for request_prompt_text in request_prompts_text - ] - cache_salt = (request.cache_salt if - (hasattr(request, "cache_salt") - and request.cache_salt is not None) else None) - if cache_salt: - for prompt_text in engine_prompts_text: - prompt_text["cache_salt"] = cache_salt - - # This check is equivalent to simply checking if - # `request_prompts_embeds` is empty, but it's difficult to propagate - # overloads to the private helper functions to enable this check. - # This overload is needed because only TextPrompts are allowed for - # non-completion requests and if we don't add the overload here, - # everywhere this function is used outside of serving_completion will - # need logic asserting that only text prompts are in the request. 
- if (not isinstance(request, CompletionRequest) - and input_or_inputs is not None): - return request_prompts_text, engine_prompts_text - - engine_prompts_embeds = [ - EngineEmbedsPrompt( - prompt_embeds=request_prompt_embeds["prompt_embeds"]) - for request_prompt_embeds in request_prompts_embeds - ] - if cache_salt: - for prompt_embed in engine_prompts_embeds: - prompt_embed["cache_salt"] = cache_salt - - request_prompts = request_prompts_embeds + request_prompts_text - engine_prompts = engine_prompts_embeds + engine_prompts_text - return request_prompts, engine_prompts - async def _preprocess_chat( self, request: Union[ChatLikeRequest, ResponsesRequest], @@ -929,7 +765,7 @@ class OpenAIServing: tokenizer, model_config=model_config, ) - conversation, mm_data_future = parse_chat_messages_futures( + conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( messages, model_config, tokenizer, @@ -1006,6 +842,10 @@ class OpenAIServing: prompt_token_ids=prompt_inputs["prompt_token_ids"]) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data + + if mm_uuids is not None: + engine_prompt["multi_modal_uuids"] = mm_uuids + if request.mm_processor_kwargs is not None: engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs @@ -1069,41 +909,6 @@ class OpenAIServing: # OPTIMIZATION priority = orig_priority - 1 - @staticmethod - def _load_prompt_embeds( - prompt_embeds: Optional[Union[bytes, list[bytes]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, - ) -> list[EmbedsPrompt]: - - def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: - tensor = torch.load( - io.BytesIO(pybase64.b64decode(embed, validate=True)), - weights_only=True, - map_location=torch.device("cpu"), - ) - assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( - torch.float32, - torch.bfloat16, - torch.float16, - ) - tensor = tensor.to_dense() - if tensor.dim() > 2: - tensor = tensor.squeeze(0) - assert tensor.dim() == 2 - if truncate_prompt_tokens is not None: - tensor = tensor[-truncate_prompt_tokens:] - return {"prompt_embeds": tensor} - - if prompt_embeds: - if isinstance(prompt_embeds, list): - return [ - _load_and_validate_embed(embed) for embed in prompt_embeds - ] - else: - return [_load_and_validate_embed(prompt_embeds)] - else: - return [] - def _log_inputs( self, request_id: str, @@ -1175,17 +980,6 @@ class OpenAIServing: return True return self.models.is_base_model(model_name) - def _get_model_name( - self, - model_name: Optional[str] = None, - lora_request: Optional[LoRARequest] = None, - ) -> str: - if lora_request: - return lora_request.lora_name - if not model_name: - return self.models.base_model_paths[0].name - return model_name - def clamp_prompt_logprobs( prompt_logprobs: Union[PromptLogprobs, diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index c08c0743ffca6..cac1d1ba56839 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -28,6 +28,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, # yapf: enable from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput @@ -90,7 +91,7 @@ class OpenAIServingPooling(OpenAIServing): if error_check_ret is 
not None: return error_check_ret - model_name = self._get_model_name(request.model) + model_name = self.models.model_name() request_id = f"pool-{self._base_request_id(raw_request)}" created_time = int(time.time()) @@ -149,10 +150,7 @@ class OpenAIServingPooling(OpenAIServing): elif isinstance(request, PoolingCompletionRequest): engine_prompts = await renderer.render_prompt( prompt_or_prompts=request.input, - max_length=self.max_model_len, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - cache_salt=getattr(request, 'cache_salt', None), + config=self._build_render_config(request), ) else: raise ValueError( @@ -270,3 +268,11 @@ class OpenAIServingPooling(OpenAIServing): data=items, usage=usage, ) + + def _build_render_config( + self, request: PoolingCompletionRequest) -> RenderConfig: + return RenderConfig( + max_length=self.max_model_len, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + cache_salt=getattr(request, 'cache_salt', None)) diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index a102d4a4a5e68..401ba6c53331c 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -44,8 +44,9 @@ from vllm.entrypoints.context import (ConversationContext, HarmonyContext, SimpleContext, StreamingHarmonyContext) from vllm.entrypoints.harmony_utils import ( get_developer_message, get_stop_tokens_for_assistant_actions, - get_system_message, get_user_message, parse_output_message, - parse_remaining_state, parse_response_input, render_for_completion) + get_system_message, get_user_message, has_custom_tools, + parse_output_message, parse_remaining_state, parse_response_input, + render_for_completion) from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -236,7 +237,7 @@ class OpenAIServingResponses(OpenAIServing): try: lora_request = self._maybe_get_adapters(request) - model_name = self._get_model_name(request.model, lora_request) + model_name = self.models.model_name(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request) if self.use_harmony: @@ -266,6 +267,8 @@ class OpenAIServingResponses(OpenAIServing): builtin_tool_list.append("browser") if self.tool_server.has_tool("python"): builtin_tool_list.append("python") + if self.tool_server.has_tool("container"): + builtin_tool_list.append("container") if self.tool_server is not None: available_tools = builtin_tool_list @@ -448,7 +451,8 @@ class OpenAIServingResponses(OpenAIServing): async with AsyncExitStack() as exit_stack: try: - await context.init_tool_sessions(self.tool_server, exit_stack) + await context.init_tool_sessions(self.tool_server, exit_stack, + request.request_id) async for _ in result_generator: pass except asyncio.CancelledError: @@ -710,13 +714,21 @@ class OpenAIServingResponses(OpenAIServing): # New conversation.
reasoning_effort = (request.reasoning.effort if request.reasoning else None) + # Temporary: OpenAI's types don't define a container tool, + # so we use MCP to cover it; subject to change tool_types = [tool.type for tool in request.tools] + if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL: + tool_types.append("container") enable_browser = ("web_search_preview" in tool_types and self.tool_server is not None and self.tool_server.has_tool("browser")) enable_code_interpreter = ("code_interpreter" in tool_types and self.tool_server is not None and self.tool_server.has_tool("python")) + enable_container = ("container" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("container")) + with_custom_tools = has_custom_tools(tool_types) sys_msg = get_system_message( reasoning_effort=reasoning_effort, browser_description=self.tool_server.get_tool_description( "browser") if enable_browser and self.tool_server is not None else None, python_description=self.tool_server.get_tool_description( "python") if enable_code_interpreter and self.tool_server is not None else None, + container_description=self.tool_server.get_tool_description( + "container") + if enable_container and self.tool_server is not None else None, + instructions=request.instructions, + with_custom_tools=with_custom_tools, ) messages.append(sys_msg) - dev_msg = get_developer_message(request.instructions, - request.tools) - messages.append(dev_msg) + if with_custom_tools: + dev_msg = get_developer_message( + instructions=request.instructions, tools=request.tools) + messages.append(dev_msg) else: # Continue the previous conversation. # FIXME(woosuk): Currently, request params like reasoning and @@ -1613,7 +1631,8 @@ class OpenAIServingResponses(OpenAIServing): async with AsyncExitStack() as exit_stack: processer = None if self.use_harmony: - await context.init_tool_sessions(self.tool_server, exit_stack) + await context.init_tool_sessions(self.tool_server, exit_stack, + request.request_id) processer = self._process_harmony_streaming_events else: processer = self._process_simple_streaming_events diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 847c014a11dc3..24767ed66fc6a 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -353,7 +353,7 @@ class ServingScores(OpenAIServing): final_res_batch, request_id, created_time, - self._get_model_name(request.model), + self.models.model_name(), ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") @@ -399,7 +399,7 @@ class ServingScores(OpenAIServing): return self.request_output_to_rerank_response( final_res_batch, request_id, - self._get_model_name(request.model), + self.models.model_name(), documents, top_n, ) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 70cb6c21b2213..1efd9678571c4 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -22,6 +22,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest, # yapf: enable from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -72,7 +73,7 @@ class OpenAIServingTokenization(OpenAIServing): [tool.model_dump() for tool in request.tools]) ( _,
request_prompts, + _, engine_prompts, ) = await self._preprocess_chat( request, @@ -90,15 +91,14 @@ class OpenAIServingTokenization(OpenAIServing): else: engine_prompts = await renderer.render_prompt( prompt_or_prompts=request.prompt, - add_special_tokens=request.add_special_tokens, - cache_salt=getattr(request, 'cache_salt', None), + config=self._build_render_config(request), ) except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") input_ids: list[int] = [] - for i, engine_prompt in enumerate(engine_prompts): + for engine_prompt in engine_prompts: self._log_inputs(request_id, engine_prompt, params=None, @@ -157,6 +157,9 @@ class OpenAIServingTokenization(OpenAIServing): return self.create_error_response( f"Failed to get tokenizer info: {str(e)}") + def _build_render_config(self, request: TokenizeRequest) -> RenderConfig: + return RenderConfig(add_special_tokens=request.add_special_tokens) + @dataclass class TokenizerInfo: diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 2055393d7ec71..37c360145b04a 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -60,7 +60,7 @@ class Internlm2ToolParser(ToolParser): if '<|action_start|>' not in current_text: self.position = len(current_text) return DeltaMessage(content=delta_text) - # if the tool call is sended, return an empty delta message + # if the tool call is sent, return an empty delta message # to make sure the finish_reason will be sent correctly. if self.current_tool_id > 0: return DeltaMessage(content='') @@ -89,7 +89,7 @@ class Internlm2ToolParser(ToolParser): try: parsable_arr = action - # tool calls are generated in an object in inernlm2 + # tool calls are generated in an object in internlm2 # it's not support parallel tool calls try: tool_call_arr: dict = partial_json_parser.loads( diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py index 6bf44a4345a9d..9a9a19ce2188e 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -176,7 +176,7 @@ class Llama4PythonicToolParser(ToolParser): index] += delta.function.arguments # HACK: serving_chat.py inspects the internal state of tool parsers - # when determining it's final streaming delta, automatically + # when determining its final streaming delta, automatically # adding autocompleted JSON. # These two lines avoid that nonsense while ensuring finish_reason # is set to tool_calls when at least one tool is called. diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index c0691f122904e..e6b300fd84e94 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -143,7 +143,7 @@ class MistralToolParser(ToolParser): except json.JSONDecodeError: # use a regex to find the part corresponding to the tool call. # NOTE: This use case should not happen if the model is trained - # correctly. It's a easy possible fix so it's included, but + # correctly. 
It's an easy possible fix so it's included, but # can be brittle for very complex / highly nested tool calls raw_tool_call = self.tool_call_regex.findall(tool_content)[0] function_call_arr = json.loads(raw_tool_call) diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py index d3f3a8cfa5aa9..f0798afbcf212 100644 --- a/vllm/entrypoints/renderer.py +++ b/vllm/entrypoints/renderer.py @@ -2,18 +2,46 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import io from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import Annotated, Optional, Union +import pybase64 +import torch from pydantic import Field from vllm.config import ModelConfig +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import AsyncMicrobatchTokenizer +@dataclass(frozen=True) +class RenderConfig: + """Configuration to control how prompts are prepared.""" + + max_length: Optional[int] = None + """Maximum allowable total input token length. If provided, + token inputs longer than this raise ``ValueError``.""" + + truncate_prompt_tokens: Optional[int] = None + """Number of tokens to keep. ``None`` means no truncation. + ``0`` yields an empty list (and skips embeds). + ``-1`` maps to ``model_config.max_model_len``.""" + + add_special_tokens: Optional[bool] = True + """Whether to add model-specific special tokens during tokenization.""" + + cache_salt: Optional[str] = None + """String to disambiguate prefix cache entries.""" + + needs_detokenization: Optional[bool] = False + """If True, detokenize IDs back to text for inclusion in outputs.""" + + class BaseRenderer(ABC): """ Base class for unified input processing and rendering. @@ -44,42 +72,105 @@ class BaseRenderer(ABC): @abstractmethod async def render_prompt( self, + *, prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], - max_length: Optional[int] = None, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, - add_special_tokens: Optional[bool] = True, - cache_salt: Optional[str] = None, + config: "RenderConfig", ) -> list[EngineTokensPrompt]: """ - Convert input prompts into tokenized format for engine processing. - - This is the core method that transforms various input formats into - standardized TokensPrompt objects. Implementations should handle - tokenization, special token insertion, truncation, and validation - according to model requirements. - + Convert text or token inputs into engine-ready TokensPrompt objects. + + This method accepts text or token inputs and produces a + list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects + for the engine. + Args: - prompt_or_prompts: Input data in various formats: - - str: Single text prompt - - list[str]: Batch of text prompts - - list[int]: Pre-tokenized sequence - - list[list[int]]: Batch of pre-tokenized sequences - max_length: Maximum sequence length (endpoint-specific behavior) - truncate_prompt_tokens: Truncate to last N tokens - (None=no truncation, 0=empty) - add_special_tokens: Add model-specific tokens (e.g., [CLS], [SEP]) - to text inputs - cache_salt: Optional string to disambiguate cached prompts - + prompt_or_prompts: One of: + - ``str``: Single text prompt. + - ``list[str]``: Batch of text prompts. + - ``list[int]``: Single pre-tokenized sequence. 
+ - ``list[list[int]]``: Batch of pre-tokenized sequences. + config: Render configuration controlling how prompts are prepared + (e.g., tokenization and length handling). + Returns: - list[EngineTokensPrompt]: Tokenized prompts ready for engine - consumption - + list[EngineTokensPrompt]: Engine-ready token prompts. + Raises: - ValueError: If input format is invalid or length limits exceeded + ValueError: If input formats are invalid or length limits exceeded. """ raise NotImplementedError + @abstractmethod + async def render_prompt_and_embeds( + self, + *, + prompt_or_prompts: Optional[Union[str, list[str], list[int], + list[list[int]]]] = None, + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None, + config: "RenderConfig", + ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + """ + Convert text/token and/or base64-encoded embeddings inputs into + engine-ready prompt objects using a unified RenderConfig. + + At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be + provided and non-empty. If both are omitted or empty (e.g., empty + string and empty list), a ``ValueError`` is raised. + + Args: + prompt_or_prompts: Text or token inputs to include. + prompt_embeds: Base64-encoded bytes (or list thereof) containing a + torch-saved tensor to be used as prompt embeddings. + config: Render configuration controlling how prompts are prepared + (e.g., tokenization and length handling). + + Returns: + list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + Engine-ready prompt objects. + + Raises: + ValueError: If both ``prompt_or_prompts`` and ``prompt_embeds`` + are omitted or empty (decoder prompt cannot be empty), or if + length limits are exceeded. + """ + raise NotImplementedError + + @classmethod + def load_prompt_embeds( + cls, + prompt_embeds: Union[bytes, list[bytes]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None, + cache_salt: Optional[str] = None, + ) -> list[EngineEmbedsPrompt]: + """Load and validate base64-encoded embeddings into prompt objects.""" + + def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt: + tensor = torch.load( + io.BytesIO(pybase64.b64decode(embed, validate=True)), + weights_only=True, + map_location=torch.device("cpu"), + ) + assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( + torch.float32, + torch.bfloat16, + torch.float16, + ) + tensor = tensor.to_dense() + if tensor.dim() > 2: + tensor = tensor.squeeze(0) + assert tensor.dim() == 2 + if truncate_prompt_tokens is not None: + tensor = tensor[-truncate_prompt_tokens:] + embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor) + if cache_salt is not None: + embeds_prompt["cache_salt"] = cache_salt + return embeds_prompt + + if isinstance(prompt_embeds, list): + return [_load_and_validate_embed(embed) for embed in prompt_embeds] + + return [_load_and_validate_embed(prompt_embeds)] + class CompletionRenderer(BaseRenderer): @@ -91,60 +182,111 @@ class CompletionRenderer(BaseRenderer): AsyncMicrobatchTokenizer]] = None, ): super().__init__(model_config, tokenizer) - self.async_tokenizer_pool = async_tokenizer_pool or {} + self.async_tokenizer_pool = async_tokenizer_pool self.async_tokenizer: Optional[AsyncMicrobatchTokenizer] = None async def render_prompt( self, + *, prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], - max_length: Optional[int] = None, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, - add_special_tokens: Optional[bool] = True, - cache_salt: Optional[str] = None, + config: "RenderConfig", 
) -> list[EngineTokensPrompt]: """Implementation of prompt rendering for completion-style requests. Uses async tokenizer pooling for improved performance. See base class for detailed parameter documentation. """ - if truncate_prompt_tokens is not None: - if truncate_prompt_tokens == 0: - return [] - if truncate_prompt_tokens < 0: - truncate_prompt_tokens = self.model_config.max_model_len - if max_length is not None and truncate_prompt_tokens > max_length: - raise ValueError( - f"truncate_prompt_tokens ({truncate_prompt_tokens}) " - f"cannot be greater than max_length ({max_length}). " - f"Please select a smaller truncation size.") + truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens( + config.truncate_prompt_tokens, config.max_length) + if truncate_prompt_tokens == 0: + return [] # Parse and batch the input prompts batch_inputs = parse_and_batch_prompt(prompt_or_prompts) - rendered_prompts: list[EngineTokensPrompt] = [] - tokenize_tasks = [] + tasks = [] for prompt_input in batch_inputs: if prompt_input["is_tokens"] is True: # Token input - token_ids = self._maybe_apply_truncation( - prompt_input["content"], truncate_prompt_tokens) - rendered_prompts.append( - self._create_tokens_prompt(token_ids, max_length, - cache_salt)) + # Note: detokenization is needed when echo is enabled, + # where the input token IDs are decoded back to text. + task = self._maybe_detokenize(prompt_input["content"], + config.max_length, + truncate_prompt_tokens, + config.cache_salt, + config.needs_detokenization) else: # Text input - tokenize_task = asyncio.create_task( - self._tokenize(prompt_input["content"], max_length, - truncate_prompt_tokens, add_special_tokens, - cache_salt)) - tokenize_tasks.append(tokenize_task) + task = self._tokenize(prompt_input["content"], + config.max_length, + truncate_prompt_tokens, + config.add_special_tokens, + config.cache_salt) + tasks.append(task) # Wait for all text tokenization to finish - if tokenize_tasks: - tokenized_text_prompts = await asyncio.gather(*tokenize_tasks) - rendered_prompts.extend(tokenized_text_prompts) + if tasks: + tokenized_text_prompts = await asyncio.gather(*tasks) + return tokenized_text_prompts - return rendered_prompts + return [] + + async def render_prompt_and_embeds( + self, + *, + prompt_or_prompts: Optional[Union[str, list[str], list[int], + list[list[int]]]] = None, + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None, + config: "RenderConfig", + ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + """ + Render text/token prompts and/or precomputed embedding prompts. At + least one of `prompt_or_prompts` or `prompt_embeds` must be provided. 
+ """ + truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens( + config.truncate_prompt_tokens, config.max_length) + if truncate_prompt_tokens == 0: + return [] + + rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = [] + + if prompt_embeds is not None: + rendered.extend( + self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens, + config.cache_salt)) + if prompt_or_prompts is None or prompt_or_prompts == "": + return rendered + + token_prompts = await self.render_prompt( + prompt_or_prompts=prompt_or_prompts, + config=config, + ) + rendered.extend(token_prompts) + + return rendered + + def _validate_and_normalize_truncate_tokens( + self, + truncate_prompt_tokens: Optional[int], + max_length: Optional[int], + ) -> Optional[int]: + """Validate and normalize truncate_prompt_tokens parameter.""" + if truncate_prompt_tokens is None: + return None + + if truncate_prompt_tokens == 0: + return 0 + + if truncate_prompt_tokens < 0: + truncate_prompt_tokens = self.model_config.max_model_len + + if max_length is not None and truncate_prompt_tokens > max_length: + raise ValueError( + f"truncate_prompt_tokens ({truncate_prompt_tokens}) " + f"cannot be greater than max_length ({max_length}). " + f"Please select a smaller truncation size.") + + return truncate_prompt_tokens def _maybe_apply_truncation( self, token_ids: list[int], @@ -186,30 +328,57 @@ class CompletionRenderer(BaseRenderer): max_length=truncate_prompt_tokens) return self._create_tokens_prompt(encoded.input_ids, max_length, - cache_salt) + cache_salt, text) + + async def _maybe_detokenize( + self, + token_ids: list[int], + max_length: Optional[int], + truncate_prompt_tokens: Optional[int], + cache_salt: Optional[str], + needs_detokenization: Optional[bool] = False, + ) -> EngineTokensPrompt: + """Optionally detokenize token IDs and build a tokens prompt.""" + token_ids = self._maybe_apply_truncation(token_ids, + truncate_prompt_tokens) + + prompt = None + if needs_detokenization is True: + async_tokenizer = self._get_async_tokenizer() + prompt = await async_tokenizer.decode(token_ids) + + return self._create_tokens_prompt(token_ids=token_ids, + max_length=max_length, + cache_salt=cache_salt, + prompt=prompt) def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer: """Get or create async tokenizer using shared pool.""" - if self.async_tokenizer is not None: - return self.async_tokenizer + async_tokenizer = self.async_tokenizer + if async_tokenizer is not None: + return async_tokenizer + + tokenizer = self.tokenizer if self.tokenizer is None: raise ValueError( "No tokenizer available for text input processing") - # Check shared pool first - if self.tokenizer in self.async_tokenizer_pool: - return self.async_tokenizer_pool[self.tokenizer] - - # Create new async tokenizer and add to pool - self.async_tokenizer = AsyncMicrobatchTokenizer(self.tokenizer) - self.async_tokenizer_pool[self.tokenizer] = self.async_tokenizer - return self.async_tokenizer + if self.async_tokenizer_pool is None: + async_tokenizer = AsyncMicrobatchTokenizer(tokenizer) + else: + async_tokenizer = self.async_tokenizer_pool.get(tokenizer) + if async_tokenizer is None: + async_tokenizer = AsyncMicrobatchTokenizer(tokenizer) + self.async_tokenizer_pool[tokenizer] = async_tokenizer + self.async_tokenizer = async_tokenizer + return async_tokenizer def _create_tokens_prompt( self, token_ids: list[int], max_length: Optional[int] = None, cache_salt: Optional[str] = None, + prompt: Optional[str] = None, ) -> EngineTokensPrompt: """Create 
validated EngineTokensPrompt.""" if max_length is not None and len(token_ids) > max_length: @@ -221,4 +390,6 @@ class CompletionRenderer(BaseRenderer): tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids) if cache_salt is not None: tokens_prompt["cache_salt"] = cache_salt + if prompt is not None: + tokens_prompt["prompt"] = prompt return tokens_prompt diff --git a/vllm/entrypoints/tool.py b/vllm/entrypoints/tool.py index 758789a5e059d..f5f4d7d3b5565 100644 --- a/vllm/entrypoints/tool.py +++ b/vllm/entrypoints/tool.py @@ -4,6 +4,8 @@ import os from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from openai_harmony import Author, Message, Role, TextContent + from vllm.logger import init_logger if TYPE_CHECKING: @@ -99,6 +101,28 @@ class HarmonyPythonTool(Tool): return self.python_tool = PythonTool() + + async def validate(self): + if not self.enabled: + return + try: + message = Message( + author=Author(role=Role.ASSISTANT), + content=[TextContent(text="print('Hello, world!')")], + channel="analysis", + recipient="python", + content_type="code", + ) + msgs = [] + async for msg in self.python_tool.process(message): + msgs.append(msg) + assert msgs[0].content[0].text == "Hello, world!\n" + except Exception as e: + self.enabled = False + logger.warning_once( + "Code interpreter tool failed to initialize (%s), code " + "interpreter is disabled", e) + return logger.info_once("Code interpreter tool initialized") async def get_result(self, context: "ConversationContext") -> Any: diff --git a/vllm/entrypoints/tool_server.py b/vllm/entrypoints/tool_server.py index 2f28595f27c6a..056a571fb2fd1 100644 --- a/vllm/entrypoints/tool_server.py +++ b/vllm/entrypoints/tool_server.py @@ -86,7 +86,8 @@ class ToolServer(ABC): pass @abstractmethod - def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]: + def new_session(self, tool_name: str, + session_id: str) -> AbstractAsyncContextManager[Any]: """ Create a session for the tool. 
""" @@ -124,7 +125,8 @@ class MCPToolServer(ToolServer): description=tool.description, parameters=tool.inputSchema) for tool in list_tools_response.tools - ]) + ], + ) self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp if tool_from_mcp.name not in self.urls: self.urls[tool_from_mcp.name] = url @@ -142,14 +144,16 @@ class MCPToolServer(ToolServer): return self.harmony_tool_descriptions.get(tool_name) @asynccontextmanager - async def new_session(self, tool_name: str): + async def new_session(self, tool_name: str, session_id: str): from mcp import ClientSession from mcp.client.sse import sse_client url = self.urls.get(tool_name) + headers = {"x-session-id": session_id} if not url: raise KeyError(f"Tool '{tool_name}' is not supported") - async with sse_client(url=url) as streams, ClientSession( - *streams) as session: + async with sse_client(url=url, + headers=headers) as streams, ClientSession( + *streams) as session: await session.initialize() yield session @@ -158,10 +162,13 @@ class DemoToolServer(ToolServer): def __init__(self): self.tools: dict[str, Tool] = {} + + async def init_and_validate(self): browser_tool = HarmonyBrowserTool() + python_tool = HarmonyPythonTool() + await python_tool.validate() if browser_tool.enabled: self.tools["browser"] = browser_tool - python_tool = HarmonyPythonTool() if python_tool.enabled: self.tools["python"] = python_tool logger.info("DemoToolServer initialized with tools: %s", @@ -182,7 +189,7 @@ class DemoToolServer(ToolServer): raise ValueError(f"Unknown tool {tool_name}") @asynccontextmanager - async def new_session(self, tool_name: str): + async def new_session(self, tool_name: str, session_id: str): if tool_name not in self.tools: raise KeyError(f"Tool '{tool_name}' is not supported") yield self.tools[tool_name] diff --git a/vllm/envs.py b/vllm/envs.py index 50783eeb95a42..bb10c7cc2ac27 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -37,6 +37,7 @@ if TYPE_CHECKING: VLLM_CONFIGURE_LOGGING: int = 1 VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_PREFIX: str = "" + VLLM_LOGGING_STREAM: str = "ext://sys.stdout" VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None VLLM_LOG_STATS_INTERVAL: float = 10. 
@@ -162,13 +163,19 @@ if TYPE_CHECKING: VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None + VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False VLLM_HAS_FLASHINFER_CUBIN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False + VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False + VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False + VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False + VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True + VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" def get_default_cache_root(): @@ -432,6 +439,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_LOGGING_LEVEL": lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), + # this is used for configuring the default logging stream + "VLLM_LOGGING_STREAM": + lambda: os.getenv("VLLM_LOGGING_STREAM", "ext://sys.stdout"), + # if set, VLLM_LOGGING_PREFIX will be prepended to all log messages "VLLM_LOGGING_PREFIX": lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), @@ -465,6 +476,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # - "FLASHINFER": use flashinfer # - "FLASHMLA": use FlashMLA # - "FLASH_ATTN_MLA": use FlashAttention for MLA + # - "FLASHINFER_MLA": use FlashInfer for MLA + # - "CUTLASS_MLA": use CUTLASS for MLA "VLLM_ATTENTION_BACKEND": lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), @@ -996,6 +1009,15 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))), + # If set to 1, use the FlashInfer CUTLASS backend for + # MXFP8 (activation) x MXFP4 (weight) MoE. + # This is separate from the TRTLLMGEN path controlled by + # VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8. + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS": + lambda: bool(int( + os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0") + )), + # If set to 1, use the FlashInfer # BF16 (activation) x MXFP4 (weight) MoE backend. "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": @@ -1137,6 +1159,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_TRTLLM_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), + # If set to 1, when we use fp8 kv, we do not quantize Q to fp8 + "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION": + lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))), + # If set, it means we pre-downloaded cubin files and flashinfer will # read the cubin files directly. 
"VLLM_HAS_FLASHINFER_CUBIN": @@ -1201,9 +1227,29 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), + # Allows vllm use container tool + "VLLM_GPT_OSS_USE_CONTAINER_TOOL": + lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))), + + # Allows harmony instructions to be injected on system messages + "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": + lambda: bool( + int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))), + # Add optional custom scopes for profiling, disable to avoid overheads "VLLM_CUSTOM_SCOPES_FOR_PROFILING": lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))), + + # Represent block hashes in KV cache events as 64-bit integers instead of + # raw bytes. Defaults to True for backward compatibility. + "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES": + lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))), + + # Name of the shared memory buffer used for object storage. + # Only effective when mm_config.mm_processor_cache_type == "shm". + "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": + lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", + "VLLM_OBJECT_STORAGE_SHM_BUFFER"), } # --8<-- [end:env-vars-definition] @@ -1274,9 +1320,11 @@ def compute_hash() -> str: "VLLM_USE_FLASHINFER_MOE_FP8", "VLLM_USE_FLASHINFER_MOE_FP4", "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "VLLM_USE_CUDNN_PREFILL", "VLLM_USE_TRTLLM_ATTENTION", + "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "VLLM_ROCM_USE_AITER", "VLLM_ROCM_USE_AITER_PAGED_ATTN", "VLLM_ROCM_USE_AITER_LINEAR", diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 813232cd19281..a3c1d79a58b26 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -231,7 +231,7 @@ class ExecutorBase(ABC): def shutdown(self) -> None: """Shutdown the executor.""" - return + self.collective_rpc("shutdown") def __del__(self): self.shutdown() diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 37c3fe59c65dd..78d0ee6c1e3fc 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -117,10 +117,12 @@ class RayDistributedExecutor(DistributedExecutorBase): self.driver_worker.execute_method) def shutdown(self) -> None: - logger.info( - "Shutting down Ray distributed executor. If you see error log " - "from logging.cc regarding SIGTERM received, please ignore because " - "this is the expected termination process in Ray.") + if logger: + # Somehow logger can be None here. + logger.info( + "Shutting down Ray distributed executor. 
If you see error log " + "from logging.cc regarding SIGTERM received, please ignore " + "because this is the expected termination process in Ray.") if hasattr(self, "forward_dag") and self.forward_dag is not None: self.forward_dag.teardown() import ray diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index aabc9ed9b80a2..3b566e88a9ec2 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import os +from concurrent.futures import Future, ThreadPoolExecutor +from functools import cached_property +from multiprocessing import Lock from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -10,9 +12,13 @@ import torch.distributed as dist import vllm.envs as envs from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import worker_receiver_cache_from_config from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, run_method) from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.executor.utils import get_and_update_mm_cache +from vllm.v1.outputs import AsyncModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -27,15 +33,7 @@ class UniProcExecutor(ExecutorBase): """ self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - local_rank = 0 - # set local rank as the device index if specified - device_info = self.vllm_config.device_config.device.__str__().split( - ":") - if len(device_info) > 1: - local_rank = int(device_info[1]) - rank = 0 + distributed_init_method, rank, local_rank = self._distributed_args() is_driver_worker = True kwargs = dict( vllm_config=self.vllm_config, @@ -44,19 +42,58 @@ class UniProcExecutor(ExecutorBase): distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) + self.mm_receiver_cache = worker_receiver_cache_from_config( + self.vllm_config, MULTIMODAL_REGISTRY, Lock()) + + self.async_output_thread: Optional[ThreadPoolExecutor] = None + if self.max_concurrent_batches > 1: + self.async_output_thread = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="WorkerAsyncOutput") + self.collective_rpc("init_worker", args=([kwargs], )) self.collective_rpc("init_device") self.collective_rpc("load_model") + def _distributed_args(self) -> tuple[str, int, int]: + """Return (distributed_init_method, rank, local_rank).""" + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + # set local rank as the device index if specified + device_info = self.vllm_config.device_config.device.__str__().split( + ":") + local_rank = int(device_info[1]) if len(device_info) > 1 else 0 + return distributed_init_method, 0, local_rank + + @cached_property + def max_concurrent_batches(self) -> int: + return 2 if self.scheduler_config.async_scheduling else 1 + def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, args: Tuple = (), - kwargs: Optional[Dict] = None) -> List[Any]: + kwargs: Optional[Dict] = None, + non_block: bool = False) -> List[Any]: if kwargs is None: kwargs = {} - answer = run_method(self.driver_worker, method, args, kwargs) - return [answer] + if 
self.mm_receiver_cache is not None and method == "execute_model": + get_and_update_mm_cache(self.mm_receiver_cache, args) + + if not non_block: + return [run_method(self.driver_worker, method, args, kwargs)] + + try: + result = run_method(self.driver_worker, method, args, kwargs) + if isinstance(result, AsyncModelRunnerOutput): + if (async_thread := self.async_output_thread) is not None: + return [async_thread.submit(result.get_output)] + result = result.get_output() + future = Future[Any]() + future.set_result(result) + except Exception as e: + future = Future[Any]() + future.set_exception(e) + return [future] def check_health(self) -> None: # UniProcExecutor will always be healthy as long as @@ -71,6 +108,10 @@ class UniProcExecutor(ExecutorBase): self.shutdown() return + def shutdown(self) -> None: + if worker := self.driver_worker: + worker.shutdown() + UniProcExecutorAsync = UniProcExecutor @@ -104,8 +145,9 @@ class ExecutorWithExternalLauncher(UniProcExecutor): assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \ ("To get deterministic execution in V1, " "please set VLLM_ENABLE_V1_MULTIPROCESSING=0") - self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, - rpc_rank=0) + super()._init_executor() + + def _distributed_args(self) -> tuple[str, int, int]: # engines are launched in torchrun-compatible launchers # so we can use the env:// method. # required env vars: @@ -116,17 +158,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor): distributed_init_method = "env://" rank = int(os.environ["RANK"]) local_rank = int(os.environ["LOCAL_RANK"]) - is_driver_worker = True - kwargs = dict( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=is_driver_worker, - ) - self.collective_rpc("init_worker", args=([kwargs], )) - self.collective_rpc("init_device") - self.collective_rpc("load_model") + return distributed_init_method, rank, local_rank def determine_num_available_blocks(self) -> Tuple[int, int]: """ diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 065d0ab59291a..6a005aa634e85 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -52,6 +52,9 @@ class TokensPrompt(TypedDict): prompt_token_ids: list[int] """A list of token IDs to pass to the model.""" + prompt: NotRequired[str] + """The prompt text corresponding to the token IDs, if available.""" + token_type_ids: NotRequired[list[int]] """A list of token type IDs to pass to the cross encoder model.""" diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 094fcf021b619..22287aa6f41e0 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -258,8 +258,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: """ Apply the model's multi-modal processor to a multi-modal prompt, @@ -276,13 +275,23 @@ class InputPreprocessor: if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply( + mm_input = mm_processor.apply( prompt, mm_data, hf_processor_mm_kwargs=mm_processor_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) + mm_hashes = mm_input["mm_hashes"] + + # Validate that all mm items have a string as their hash + if not contains_only_strings(mm_hashes): + raise ValueError( + 
f"mm_hashes must contain only strings, got: {mm_hashes}. " + "This is likely due to an incorrect custom implementation of " + "MultiModalProcessor.apply method.") + + return mm_input async def _process_multimodal_async( self, @@ -292,8 +301,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: """ Async version of @@ -310,13 +318,23 @@ class InputPreprocessor: if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply( + mm_input = mm_processor.apply( prompt, mm_data, hf_processor_mm_kwargs=mm_processor_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) + mm_hashes = mm_input["mm_hashes"] + + # Validate that all mm items have a string as their hash + if not contains_only_strings(mm_hashes): + raise ValueError( + f"mm_hashes must contain only strings, got: {mm_hashes}. " + "This is likely due to an incorrect custom implementation of " + "MultiModalProcessor.apply method.") + + return mm_input def _process_embeds( self, @@ -370,8 +388,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_token_ids = self._truncate_inputs( parsed_content["prompt_token_ids"], tokenization_kwargs) @@ -384,7 +401,7 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) else: inputs = token_inputs(prompt_token_ids=prompt_token_ids) @@ -400,8 +417,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_token_ids = self._truncate_inputs( parsed_content["prompt_token_ids"], tokenization_kwargs) @@ -414,7 +430,7 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) else: inputs = token_inputs(prompt_token_ids=prompt_token_ids, ) @@ -430,8 +446,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -443,7 +458,7 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) else: prompt_token_ids = self._tokenize_prompt( @@ -467,8 +482,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: 
Optional[MultiModalUUIDDict] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -480,7 +494,7 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) else: prompt_token_ids = await self._tokenize_prompt_async( @@ -504,8 +518,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> SingletonInputs: """ Extract the singleton inputs from a prompt. @@ -527,21 +540,21 @@ class InputPreprocessor: return self._process_tokens( parsed["content"], lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if parsed["type"] == "text": return self._process_text( parsed["content"], tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if parsed["type"] == "str": return self._process_text( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) assert_never(parsed) @@ -552,8 +565,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> SingletonInputs: """ Async version of @@ -567,21 +579,21 @@ class InputPreprocessor: return await self._process_tokens_async( parsed["content"], lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if parsed["type"] == "text": return await self._process_text_async( parsed["content"], tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if parsed["type"] == "str": return await self._process_text_async( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) assert_never(parsed) @@ -692,8 +704,7 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> EncoderDecoderInputs: """ For encoder/decoder models only: @@ -735,7 +746,7 @@ class InputPreprocessor: encoder_inputs = self._prompt_to_llm_inputs( prompt["encoder_prompt"], tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if (decoder_input := prompt["decoder_prompt"]) is None: decoder_inputs = None @@ -751,7 +762,7 @@ class InputPreprocessor: inputs = self._prompt_to_llm_inputs( prompt, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model @@ -768,8 +779,7 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> 
EncoderDecoderInputs: """ Async version of @@ -782,7 +792,7 @@ class InputPreprocessor: encoder_task = self._prompt_to_llm_inputs_async( prompt["encoder_prompt"], tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if (decoder_input := prompt["decoder_prompt"]) is None: @@ -792,7 +802,7 @@ class InputPreprocessor: decoder_task = self._prompt_to_llm_inputs_async( decoder_input, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) encoder_inputs, decoder_inputs = await asyncio.gather( @@ -808,7 +818,7 @@ class InputPreprocessor: inputs = await self._prompt_to_llm_inputs_async( prompt, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model @@ -836,8 +846,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> DecoderOnlyInputs: """ For decoder-only models: @@ -858,7 +867,7 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -869,8 +878,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> DecoderOnlyInputs: """ Async version of @@ -880,7 +888,7 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -891,8 +899,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: @@ -901,7 +908,7 @@ class InputPreprocessor: return self._process_encoder_decoder_prompt( prompt, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if is_explicit_encoder_decoder_prompt(prompt): @@ -913,7 +920,7 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) async def preprocess_async( @@ -922,8 +929,7 @@ class InputPreprocessor: tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> ProcessorInputs: """ Async version of @@ -935,7 +941,7 @@ class InputPreprocessor: return await self._process_encoder_decoder_prompt_async( prompt, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if is_explicit_encoder_decoder_prompt(prompt): @@ -947,9 +953,21 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - 
mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) def clear_cache(self) -> None: if self.mm_processor_cache is not None: self.mm_processor_cache.clear_cache() + + +# Helper function to validate that a nested dictionary contains +# only strings or lists of strings as the leaf values. +def contains_only_strings(obj: object) -> bool: + if isinstance(obj, str): + return True + if isinstance(obj, list): + return all(isinstance(x, str) for x in obj) + if isinstance(obj, dict): + return all(contains_only_strings(v) for v in obj.values()) + return False diff --git a/vllm/logger.py b/vllm/logger.py index 8f06eb03c7f93..2861e0f1686c4 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -20,9 +20,10 @@ VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX +VLLM_LOGGING_STREAM = envs.VLLM_LOGGING_STREAM _FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " - "[%(filename)s:%(lineno)d] %(message)s") + "[%(fileinfo)s:%(lineno)d] %(message)s") _DATE_FORMAT = "%m-%d %H:%M:%S" DEFAULT_LOGGING_CONFIG = { @@ -38,7 +39,7 @@ DEFAULT_LOGGING_CONFIG = { "class": "logging.StreamHandler", "formatter": "vllm", "level": VLLM_LOGGING_LEVEL, - "stream": "ext://sys.stdout", + "stream": VLLM_LOGGING_STREAM, }, }, "loggers": { diff --git a/vllm/logging_utils/formatter.py b/vllm/logging_utils/formatter.py index 0affef10078dc..004b79f3ea6e2 100644 --- a/vllm/logging_utils/formatter.py +++ b/vllm/logging_utils/formatter.py @@ -2,16 +2,77 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging +from pathlib import Path + +from vllm import envs class NewLineFormatter(logging.Formatter): """Adds logging prefix to newlines to align multi-line messages.""" def __init__(self, fmt, datefmt=None, style="%"): - logging.Formatter.__init__(self, fmt, datefmt, style) + super().__init__(fmt, datefmt, style) + + self.use_relpath = envs.VLLM_LOGGING_LEVEL == "DEBUG" + if self.use_relpath: + self.root_dir = Path(__file__).resolve().parent.parent.parent def format(self, record): - msg = logging.Formatter.format(self, record) + + def shrink_path(relpath: Path) -> str: + """ + Shortens a file path for logging display: + - Removes leading 'vllm' folder if present. + - If path starts with 'v1', + keeps the first two and last two levels, + collapsing the middle as '...'. + - Otherwise, keeps the first and last two levels, + collapsing the middle as '...'. + - If the path is short, returns it as-is. + - Examples: + vllm/model_executor/layers/quantization/utils/fp8_utils.py -> + model_executor/.../quantization/utils/fp8_utils.py + vllm/model_executor/layers/quantization/awq.py -> + model_executor/layers/quantization/awq.py + vllm/v1/attention/backends/mla/common.py -> + v1/attention/backends/mla/common.py + + Args: + relpath (Path): The relative path to be shortened. + Returns: + str: The shortened path string for display.
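For the `mm_uuids` validation introduced in the preprocessor hunks above, here is a standalone sketch of the contract the new `contains_only_strings` guard enforces on `mm_hashes`; the hash values and modality keys below are illustrative, not taken from vLLM:

```python
# Standalone re-statement of the guard for illustration; sample hashes and
# modality keys are made up.
def contains_only_strings(obj: object) -> bool:
    if isinstance(obj, str):
        return True
    if isinstance(obj, list):
        return all(isinstance(x, str) for x in obj)
    if isinstance(obj, dict):
        return all(contains_only_strings(v) for v in obj.values())
    return False

assert contains_only_strings({"image": ["sha256:ab12", "sha256:cd34"]})
assert contains_only_strings("sha256:ab12")
assert not contains_only_strings({"image": [None]})     # e.g. a processor bug
assert not contains_only_strings({"image": [b"ab12"]})  # bytes, not str
```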
+ """ + parts = list(relpath.parts) + new_parts = [] + if parts and parts[0] == "vllm": + parts = parts[1:] + if parts and parts[0] == "v1": + new_parts += parts[:2] + parts = parts[2:] + elif parts: + new_parts += parts[:1] + parts = parts[1:] + if len(parts) > 2: + new_parts += ["..."] + parts[-2:] + else: + new_parts += parts + return "/".join(new_parts) + + if self.use_relpath: + abs_path = getattr(record, "pathname", None) + if abs_path: + try: + relpath = Path(abs_path).resolve().relative_to( + self.root_dir) + except Exception: + relpath = Path(record.filename) + else: + relpath = Path(record.filename) + record.fileinfo = shrink_path(relpath) + else: + record.fileinfo = record.filename + + msg = super().format(record) if record.message != "": parts = msg.split(record.message) msg = msg.replace("\n", "\r\n" + parts[0]) diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py deleted file mode 100644 index 7fc4cfe026aee..0000000000000 --- a/vllm/lora/fully_sharded_layers.py +++ /dev/null @@ -1,355 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# pylint: disable=unused-argument -from typing import TYPE_CHECKING, Optional, Union, cast - -import torch -import torch.nn as nn -from transformers import PretrainedConfig - -from vllm.config import LoRAConfig -from vllm.distributed.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.lora.layers import (ColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLoRA, - QKVParallelLinearWithLoRA, - RowParallelLinearWithLoRA) -from vllm.platforms import current_platform - -if TYPE_CHECKING: - pass - - -def _fully_sharded_can_replace(can_replace): - """ - decorator which adds the condition of fully sharded loras - intended to wrap can_replace_layer() - """ - - def dec(*args, **kwargs): - return (can_replace(*args, **kwargs) - and kwargs["lora_config"].fully_sharded_loras) - - return dec - - -def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): - """ - For `ColumnParallelLinearWithLoRA` or classes that inherit from - `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. - """ - assert (layer.n_slices == len(layer.lora_a_stacked) == len( - layer.lora_b_stacked) == len(layer.output_slices)) - if layer.lora_bias_stacked is not None: - assert layer.n_slices == len(layer.lora_bias_stacked) - - output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape - - # Since communication is needed, the buffer is directly initialized as a - # tensor rather than a tuple of tensor. 
- buffers = torch.zeros( - (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device, - ) - - shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( - buffers, x, layer.lora_a_stacked, 1.0) - - if not current_platform.can_update_inplace(): - buffers = shrunk_buffers - - buffers = tensor_model_parallel_all_gather(buffers) - - lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( - output, - buffers, - layer.lora_b_stacked, - layer.lora_bias_stacked, - layer.output_slices, - offset_start=0, - add_input=True) - - if not current_platform.can_update_inplace(): - output = lora_output - - output = output.view(*out_orig_shape) - # now have column partitioned and packed output - return output - - -# these layers are based on the tensor parallelism strategy given in -# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, -# https://arxiv.org/abs/2311.03285. - - -class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): - """ - Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, - # their `lora_a` and `lora_b` have different sharding patterns. After - # completing the `lora_a` GEMM , a gather operation is performed. - # Therefore, the sharding of `lora_a` only needs to correspond with the - # gather operation. - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class MergedColumnParallelLinearWithShardedLoRA( - MergedColumnParallelLinearWithLoRA): - """ - Differs from MergedColumnParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - #NOTE: lora_a contains 2 subloras, and each sublora could be None. 
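The rank-dimension slicing that the `slice_lora_a` overrides above perform can be checked in isolation; this sketch uses illustrative sizes and shows that the per-rank shards tile the rank dimension exactly:

```python
import torch

rank, in_features, tp_size = 16, 64, 4
lora_a = torch.randn(in_features, rank)

shard_size = rank // tp_size  # lora_a_stacked[0].shape[2] in the code above
shards = [lora_a[:, r * shard_size:(r + 1) * shard_size]
          for r in range(tp_size)]
assert torch.equal(torch.cat(shards, dim=1), lora_a)  # shards tile the rank dim
```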
- output_shard_size = self.lora_a_stacked[0].shape[2] - output_start_idx = self.tp_rank * output_shard_size - lora_a = [ - lora_a[0][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[0] is not None else None, - lora_a[1][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[1] is not None else None, - ] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): - """ - Differs from QKVParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig]) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): - """ - Differs from MergedQKVParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - # NOTE: lora_a contains 3 subloras, and each sublora could be None. 
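A minimal sketch of the per-slice handling above, where each sublora may independently be `None` (the adapter leaves that projection untouched); the helper name and sizes are illustrative:

```python
from typing import Optional

import torch

# Each sublora may be None; each slice is cut at its own rank-dim shard.
def slice_subloras(
    subloras: list[Optional[torch.Tensor]],
    shard_sizes: list[int],
    tp_rank: int,
) -> list[Optional[torch.Tensor]]:
    out: list[Optional[torch.Tensor]] = []
    for sub, size in zip(subloras, shard_sizes):
        start = tp_rank * size
        out.append(None if sub is None else sub[:, start:start + size])
    return out

q_a = torch.randn(64, 16)  # illustrative (in_features, rank)
sliced = slice_subloras([q_a, None, q_a.clone()], [4, 4, 4], tp_rank=1)
assert sliced[1] is None and sliced[0].shape == (64, 4)
```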
- shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] - start_idx = [self.tp_rank * shard_size[i] for i in range(3)] - lora_a = [ - lora_a[0][:, start_idx[0]:start_idx[0] + - shard_size[0]] if lora_a[0] is not None else None, - lora_a[1][:, start_idx[1]:start_idx[1] + - shard_size[1]] if lora_a[1] is not None else None, - lora_a[2][:, start_idx[2]:start_idx[2] + - shard_size[2]] if lora_a[2] is not None else None, - ] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): - """ - Differs from RowParallelLinearWithLoRA by slicing the - LoRA B's also. - - Based on S-LoRA, slicing happens along the output dim. - This yields a combined partial sum from the row parallel base - layer and column partitioned output from the LoRA. - """ - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - shard_size = self.lora_b_stacked[0].shape[2] - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - if bias is None: - return bias - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - shard_size = self.lora_bias_stacked[0].shape[2] - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, - output.shape[-1]), output.shape - buffer = torch.zeros( - (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device, - ) - - shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( - buffer, x, self.lora_a_stacked, 1.0) - if not current_platform.can_update_inplace(): - buffer = shrunk_buffer - - buffer = tensor_model_parallel_all_reduce(buffer) - - # following S-LoRA, allows the fusing of all_gather and all_reduce - # by adding the column partitioned lora output to a slice of output - # tensor, which is a partial sum due to row parallel. All that - # remains is a standard all_reduce. User should be aware though that - # the output is not the same as a normal row_parallel, it should be - # reduced before being used - # NOTE offset are based on the rank. 
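The NOTE above refers to rank-based offsets; a short sketch of the arithmetic, with an illustrative shard width, showing where each rank's column-partitioned LoRA-B output lands in the shared output before the final all_reduce:

```python
tp_size, shard_size = 4, 32  # illustrative LoRA-B shard width per rank
offsets = [rank * shard_size for rank in range(tp_size)]
assert offsets == [0, 32, 64, 96]  # offset_start handed to add_expand
```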
- shard_size = self.lora_b_stacked[0].shape[2] - offset_start = self.tp_rank * shard_size - lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( - output, - buffer, - self.lora_b_stacked, - self.lora_bias_stacked, - self.output_slices, - offset_start=offset_start, - add_input=True, - ) - - if not current_platform.can_update_inplace(): - output = lora_output - - output = output.view(*out_orig_shape) - return output - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py deleted file mode 100644 index 6e4b69c303254..0000000000000 --- a/vllm/lora/layers.py +++ /dev/null @@ -1,1192 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# pylint: disable=unused-argument -import math -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union, cast - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import PretrainedConfig - -from vllm.adapter_commons.layers import AdapterMapping -from vllm.config import LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) -from vllm.distributed.utils import divide -# yapf: disable -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -# yapf: enable -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.platforms import current_platform - -if TYPE_CHECKING: - from vllm.lora.punica_wrapper import PunicaWrapperBase - - -def _get_lora_device(base_layer: nn.Module) -> torch.device: - # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 - """Returns the device for where to place the LoRA tensors.""" - # unquantizedLinear - if hasattr(base_layer, "weight"): - return base_layer.weight.device - # Compressed Tensor - elif hasattr(base_layer, "weight_packed"): - return base_layer.weight_packed.device - # GPTQ/AWQ - elif hasattr(base_layer, "qweight"): - return base_layer.qweight.device - # HQQ marlin - elif hasattr(base_layer, "W_q"): - return base_layer.W_q.device - else: - raise ValueError(f"Unsupported base layer: {base_layer}") - - -def _not_fully_sharded_can_replace(can_replace): - """ - decorator which adds the condition of not using fully sharded loras - intended to wrap can_replace_layer() - """ - - def dec(*args, **kwargs): - decorate = kwargs.pop("decorate") if "decorate" in kwargs else True - condition = (not kwargs["lora_config"].fully_sharded_loras - if decorate else True) - return can_replace(*args, **kwargs) and condition - - return dec - - -@dataclass -class LoRAMapping(AdapterMapping): - is_prefill: bool = False - - -class BaseLayerWithLoRA(nn.Module): - - def slice_lora_a( - self, 
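The paired `can_replace_layer` decorators (`_fully_sharded_can_replace` above, `_not_fully_sharded_can_replace` here) split layer selection on the `fully_sharded_loras` flag, so exactly one variant family is eligible for a given config. A minimal sketch with a stubbed config object:

```python
class FakeLoRAConfig:  # stub standing in for vllm.config.LoRAConfig
    def __init__(self, fully_sharded_loras: bool):
        self.fully_sharded_loras = fully_sharded_loras

def _fully_sharded_can_replace(can_replace):
    def dec(*args, **kwargs):
        return (can_replace(*args, **kwargs)
                and kwargs["lora_config"].fully_sharded_loras)
    return dec

@_fully_sharded_can_replace
def can_replace(**kwargs) -> bool:
    return True  # stand-in for the real isinstance check on source_layer

assert can_replace(lora_config=FakeLoRAConfig(fully_sharded_loras=True))
assert not can_replace(lora_config=FakeLoRAConfig(fully_sharded_loras=False))
```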
lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: - """Slice lora a if splitting for tensor parallelism.""" - ... - - def slice_lora_b( - self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: - """Slice lora b if splitting with tensor parallelism.""" - ... - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """Initializes lora matrices.""" - ... - - def reset_lora(self, index: int): - """Resets the lora weights at index back to 0.""" - ... - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - """Overwrites lora tensors at index.""" - ... - - def set_mapping( - self, - punica_wrapper, - ): - self.punica_wrapper: PunicaWrapperBase = punica_wrapper - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - """Returns True if the layer can be replaced by this LoRA layer.""" - raise NotImplementedError - - -class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): - - def __init__(self, base_layer: VocabParallelEmbedding) -> None: - super().__init__() - self.base_layer = base_layer - self.embeddings_slice: Optional[tuple[int, int]] - self.embeddings_weights: Optional[torch.Tensor] - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: - - if self.base_layer.num_added_embeddings_per_partition > 0: - # We can start adding lora weights - self.embeddings_weights = self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:self. 
- base_layer.num_org_embeddings_per_partition + - self.base_layer.num_added_embeddings_per_partition] - self.embeddings_slice = ( - self.base_layer.shard_indices.added_vocab_start_index - - self.base_layer.org_vocab_size, - self.base_layer.shard_indices.added_vocab_end_index - - self.base_layer.org_vocab_size) - self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:].fill_(0) - else: - self.embeddings_slice = None - self.embeddings_weights = None - - self.embeddings_tensors = torch.zeros( - ( - max_loras, - lora_config.lora_extra_vocab_size, - self.base_layer.embedding_dim, - ), - dtype=self.base_layer.weight.dtype, - device=self.base_layer.weight.device, - ) - self.lora_a_stacked = torch.zeros( - ( - max_loras, - self.base_layer.org_vocab_size + - lora_config.lora_extra_vocab_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ) - self.lora_b_stacked = torch.zeros( - ( - max_loras, - 1, - self.base_layer.embedding_dim, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ) - self.lora_a_stacked_2d = self.lora_a_stacked.view( - self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], - self.lora_a_stacked.shape[2], - ) - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = 0 - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( - lora_a, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], - ].copy_(embeddings_tensor, non_blocking=True) - if self.embeddings_slice is not None: - # TODO(yard1): Optimize this copy, we don't need to copy - # everything, just the modified part - embeddings = self.embeddings_tensors.view( - self.embeddings_tensors.shape[0] * - self.embeddings_tensors.shape[1], - self.embeddings_tensors.shape[2], - )[self.embeddings_slice[0]:self.embeddings_slice[1]] - assert self.embeddings_weights is not None - self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, - 1, 0) - - # NB: Don't use torch.narrow here. 
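A shape sketch for the embedding LoRA buffers allocated above, with illustrative sizes; the A matrices are stacked along the extended vocab dimension and then viewed as 2-D so a single `F.embedding` lookup (with per-adapter index offsets) can serve all adapters at once:

```python
import torch

max_loras, vocab, extra_vocab, rank, dim = 4, 1000, 16, 8, 32
lora_a_stacked = torch.zeros(max_loras, vocab + extra_vocab, rank)
lora_b_stacked = torch.zeros(max_loras, 1, dim, rank)

# Flattened so token ids (offset per adapter) index directly into the rows.
lora_a_2d = lora_a_stacked.view(max_loras * (vocab + extra_vocab), rank)
assert lora_a_2d.shape == (4064, 8)
```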
torch.narrow triggers some - # Dynamic Shape specialization in torch.compile - num_tokens = x.shape[0] - indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens] - indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens] - - full_lora_a_embeddings = F.embedding( - x + indices_1, - self.lora_a_stacked_2d, - ) - full_output = self.base_layer.forward(x + - (indices_0 * added_tokens_mask)) - - full_output_org = full_output - if full_output.ndim == 3: - full_output = full_output.view( - full_output.shape[0] * full_output.shape[1], -1) - if full_lora_a_embeddings.ndim == 3: - full_lora_a_embeddings = full_lora_a_embeddings.view( - full_lora_a_embeddings.shape[0] * - full_lora_a_embeddings.shape[1], - -1, - ) - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_embedding( - full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) - - if not current_platform.can_update_inplace(): - full_output = lora_output - - return full_output.view_as(full_output_org) - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is VocabParallelEmbedding - - @property - def weight(self): - return self.base_layer.weight - - -class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): - - def __init__(self, base_layer: LinearBase): - super().__init__() - self.base_layer = base_layer - self.input_size = self.base_layer.input_size - self.device = _get_lora_device(self.base_layer) - self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None - - self.output_slices: tuple[int, ...] - self.tp_size: int - self.output_size: int - self.n_slices: int - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - self.lora_config = lora_config - # - if isinstance(self.base_layer, ReplicatedLinear): - lora_a_out_size = lora_config.max_lora_rank - lora_b_out_size = self.output_size - - elif isinstance(self.base_layer, ColumnParallelLinear): - lora_a_out_size = (lora_config.max_lora_rank if - not lora_config.fully_sharded_loras else divide( - lora_config.max_lora_rank, self.tp_size)) - lora_b_out_size = self.output_size - - elif isinstance(self.base_layer, RowParallelLinear): - lora_a_out_size = lora_config.max_lora_rank - lora_b_out_size = (self.output_size if - not lora_config.fully_sharded_loras else divide( - self.output_size, self.tp_size)) - else: - raise NotImplementedError - - self.lora_a_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_a_out_size, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.lora_b_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_b_out_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - if lora_config.bias_enabled: - lora_bias_out_size = lora_b_out_size - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_bias_out_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.output_slices = (self.lora_b_stacked[0].shape[2], ) - - def reset_lora(self, index: int): - for s_index in range(self.n_slices): - self.lora_a_stacked[s_index][index] = 0 - self.lora_b_stacked[s_index][index] = 0 - if self.lora_config.bias_enabled: - # Make mypy happy - self.lora_bias_stacked = 
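The buffer sizing in `create_lora_weights` above can be summarized as: fully sharded LoRA shrinks A's rank dimension for column-parallel layers and B's output dimension for row-parallel layers, while replicated layers keep full sizes. A hedged sketch (the helper name is ours, not vLLM's):

```python
# Helper name is illustrative; mirrors the branches in create_lora_weights.
def lora_out_sizes(kind: str, max_rank: int, out: int, tp: int,
                   fully_sharded: bool) -> tuple[int, int]:
    if kind == "replicated":
        return max_rank, out
    if kind == "column":  # A's rank dim is divided across ranks
        return (max_rank // tp if fully_sharded else max_rank), out
    if kind == "row":     # B's output dim is divided across ranks
        return max_rank, (out // tp if fully_sharded else out)
    raise NotImplementedError(kind)

assert lora_out_sizes("column", 16, 1024, 4, fully_sharded=True) == (4, 1024)
assert lora_out_sizes("row", 16, 1024, 4, fully_sharded=True) == (16, 256)
assert lora_out_sizes("replicated", 16, 1024, 4, fully_sharded=True) == (16, 1024)
```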
cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - self.lora_bias_stacked[s_index][index] = 0 - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, - ): - # Except for QKVParallelLinearWithLoRA and - # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers - # store weights in a tuple of size 1. These two layers will - # override this function. - assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == - self.n_slices == 1) - - self.reset_lora(index) - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) - - self.lora_a_stacked[0][index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[0][index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if lora_bias is not None: - - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - assert len(self.lora_bias_stacked) - self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( - lora_bias.T, non_blocking=True) - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - - # In transformers backend, x and output have extra batch dimension like - # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), - # therefore we need to flatten the batch dimensions. - if x.ndim == 3 and output.ndim == 3: - output = output.flatten(0, 1) - x = x.flatten(0, 1) - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_linear( - output, x, self.lora_a_stacked, self.lora_b_stacked, - self.lora_bias_stacked, 1.0, self.output_slices) - if not current_platform.can_update_inplace(): - output = lora_output - - return output - - @property - def weight(self) -> torch.Tensor: - - # unquantizedLinear - if hasattr(self.base_layer, "weight"): - return self.base_layer.weight - # Compressed Tensor - elif hasattr(self.base_layer, "weight_packed"): - return self.base_layer.weight_packed - # GPTQ/AWQ - elif hasattr(self.base_layer, "qweight"): - return self.base_layer.qweight - # marlin - elif hasattr(self.base_layer, "B"): - return self.base_layer.B - # HQQ marlin - elif hasattr(self.base_layer, "W_q"): - return self.base_layer.W_q - else: - raise ValueError(f"Unsupported base layer: {self.base_layer}") - - @property - def bias(self) -> Optional[torch.Tensor]: - if hasattr(self.base_layer, "bias"): - return self.base_layer.bias - else: - return None - - -class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): - - def __init__(self, base_layer: ReplicatedLinear) -> None: - super().__init__(base_layer, ) - # To ensure interface compatibility, set to 1 always. - self.tp_size = 1 - self.output_size = self.base_layer.output_size - self.n_slices = 1 - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of ReplicatedLinearWithLoRA - - Args: - input_: Tensor whose last dimension is `input_size`. - - Returns: - - output - - bias - """ - bias = (self.base_layer.bias - if not self.base_layer.skip_bias_add else None) - - # Matrix multiply. 
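As a dense reference for what `apply` above delegates to the punica kernels: the LoRA path computes `base(x) + ((x @ A^T) @ B^T) * scale`, using the stacked layouts implied by the transposed copies in `set_lora` (A stored as `(rank, in)`, B as `(out, rank)`). Sizes below are illustrative:

```python
import torch

seq, d_in, d_out, rank, scale = 3, 8, 6, 4, 1.0
x = torch.randn(seq, d_in)
w = torch.randn(d_out, d_in)           # base layer weight
lora_a = torch.randn(rank, d_in)       # stacked layout stores A as (rank, in)
lora_b = torch.randn(d_out, rank)      # and B as (out, rank)

# Dense equivalent: base output + shrink (x @ A^T) then expand (@ B^T).
y = x @ w.T + ((x @ lora_a.T) @ lora_b.T) * scale
assert y.shape == (seq, d_out)
```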
- output = self.apply(input_, bias) - - output_bias = (self.base_layer.bias - if self.base_layer.skip_bias_add else None) - - if not self.base_layer.return_bias: - return output - - return output, output_bias - - # ReplicatedLinear should always be replaced, regardless of the fully - # sharded LoRAs setting, because it is, by definition, copied per GPU. - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is ReplicatedLinear - - -class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): - """ - LoRA on top of ColumnParallelLinear layer. - LoRA B is sliced for tensor parallelism. - There are two types for the `base_layer`: - 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. - 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. - """ - - def __init__(self, base_layer: ColumnParallelLinear) -> None: - super().__init__(base_layer) - # The base_layer type is ColumnParallelLinear or - # MergedColumnParallelLinear, their weight sharding logic is - # inconsistent when TP is greater than 1. - self.is_merged_col_linear = type( - base_layer) is MergedColumnParallelLinear - self.tp_size = get_tensor_model_parallel_world_size() - self.output_size = self.base_layer.output_size_per_partition - # There is only one LoRA layer - self.n_slices = 1 - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - return lora_a - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - # Applicable to cases where the base_layer is - # MergedColumnParallelLinear. - if self.is_merged_col_linear: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size // 2 - offset = lora_b.shape[-1] // 2 - - left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * - shard_size] - right_weight = lora_b[:, offset + tp_rank * shard_size:offset + - (tp_rank + 1) * shard_size] - lora_b = torch.cat([left_weight, right_weight], dim=1) - # Applicable to cases where the base_layer is - # ColumnParallelLinear. - else: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - # TODO: Fix the slicing logic of bias. - if bias is None: - return bias - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of ColumnParallelLinear - - Args: - input_: Tensor whose last dimension is `input_size`. - - Returns: - - output - - bias - """ - bias = (self.base_layer.bias - if not self.base_layer.skip_bias_add else None) - - # Matrix multiply. - output_parallel = self.apply(input_, bias) - if self.base_layer.gather_output: - # All-gather across the partitions. 
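A numeric sketch of the merged-column `slice_lora_b` above: the LoRA-B columns span `[gate | up]`, and each TP rank takes its shard from both halves before re-concatenating. Sizes illustrative:

```python
import torch

rank, gate_up_width, tp_size, tp_rank = 4, 8, 2, 1
lora_b = torch.randn(rank, gate_up_width)
per_rank_out = gate_up_width // tp_size  # output_size_per_partition
shard = per_rank_out // 2                # shard_size in the code above
offset = gate_up_width // 2              # where the "up" half begins

left = lora_b[:, tp_rank * shard:(tp_rank + 1) * shard]
right = lora_b[:, offset + tp_rank * shard:offset + (tp_rank + 1) * shard]
sliced = torch.cat([left, right], dim=1)
assert sliced.shape == (rank, per_rank_out)
```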
- output = tensor_model_parallel_all_gather(output_parallel) - else: - output = output_parallel - - if not self.base_layer.return_bias: - return output - - output_bias = (self.base_layer.bias - if self.base_layer.skip_bias_add else None) - return output, output_bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is ColumnParallelLinear or ( - type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 1) - - -class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - """ColumnParallelLinear layer that is composed of 2 sublayers (slices) - packed together (e.g. gate_proj + up_proj -> gate_up_proj). - - This means we have 2 LoRAs, each applied to one half of the layer. - - Both slices must have the same size. - """ - - def __init__( - self, base_layer: Union[MergedColumnParallelLinear, - QKVParallelLinear]) -> None: - super().__init__(base_layer) - # There are two LoRA layers - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - # the output_sizes in MergedColumnParallelLinear is not sharded by tp - # we need to divide it by the tp_size to get correct slices size - output_sizes = self.base_layer.output_sizes - self.output_slices = tuple( - divide(output_size, self.tp_size) for output_size in output_sizes) - self.n_slices = len(self.output_slices) - self.output_ids = (self.tp_rank, ) * self.n_slices - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """ - The main reason for overriding this function is to enhance code - maintainability. 
- """ - self.lora_config = lora_config - - lora_a_output_size_per_partition = ( - lora_config.max_lora_rank if not lora_config.fully_sharded_loras - else divide(lora_config.max_lora_rank, self.tp_size)) - - self.lora_a_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.lora_b_stacked = tuple( - torch.zeros( - max_loras, - 1, - output_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) for output_size in self.output_slices) - if lora_config.bias_enabled: - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - output_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for output_size in self.output_slices) - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - return lora_a - - def slice_lora_b( - self, lora_b: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - sliced_lora_b = [None] * self.n_slices - for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices)): - if (lora_b_i := lora_b[i]) is not None: - sliced_lora_b[i] = lora_b_i[:, - shard_size * shard_id:shard_size * - (shard_id + 1)] - return sliced_lora_b - - def slice_bias( - self, bias: list[Union[torch.Tensor, - None]]) -> list[Union[torch.Tensor, None]]: - for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices)): - if (bias_i := bias[i]) is not None: - bias[i] = bias_i[shard_size * shard_id:shard_size * - (shard_id + 1)] - return bias - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) - - for i in range(self.n_slices): - if (lora_a_i := lora_a[i]) is not None: - self.lora_a_stacked[i][ - index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( - lora_a_i.T, non_blocking=True) - if (lora_b_i := lora_b[i]) is not None: - self.lora_b_stacked[i][ - index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( - lora_b_i.T, non_blocking=True) - - if lora_bias is not None: - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - for i in range(self.n_slices): - if (lora_bias_i := lora_bias[i]) is not None: - self.lora_bias_stacked[i][index, - 0, :lora_bias_i.shape[0]].copy_( - lora_bias_i.T, - non_blocking=True) - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return (type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 2) - - -class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - """ - ColumnParallelLinear layer that is specifically designed for - qkv_proj. Certain models, such as chatglm3 and baichuan-7b, - only contains a single LoRA within their qkv_proj layer. - - During inference with Tensor Parallel, the weights of lora_b - must be accurately partitioned according to the respective ranks. - - Q slice may have different shape than K and V slices (which both have - the same shape). 
- """ - - def __init__(self, base_layer: QKVParallelLinear) -> None: - super().__init__(base_layer) - self.q_proj_total_size = (self.base_layer.total_num_heads * - self.base_layer.head_size) - self.q_proj_shard_size = (self.base_layer.num_heads * - self.base_layer.head_size) - self.kv_proj_shard_size = (self.base_layer.num_kv_heads * - self.base_layer.head_size) - self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * - self.base_layer.head_size) - # There is only one LoRA layer - self.n_slices = 1 - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - self.q_shard_id = tp_rank - self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas - lora_b_q = lora_b[:, self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - k_offset = self.q_proj_total_size - lora_b_k = lora_b[:, k_offset + - self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - v_offset = k_offset + self.kv_proj_total_size - lora_b_v = lora_b[:, v_offset + - self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - bias_q = bias[self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - k_offset = self.q_proj_total_size - bias_k = bias[k_offset + - self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - v_offset = k_offset + self.kv_proj_total_size - bias_v = bias[v_offset + - self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - bias = torch.cat([bias_q, bias_k, bias_v], dim=1) - return bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is QKVParallelLinear and len( - packed_modules_list) == 1 - - -class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): - """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) - packed together in qkv proj fashion - (q_proj + k_proj + v_proj -> qkv_proj). - - This means we have 3 LoRAs, each applied to one slice of the layer. - - Q slice may have different shape than K and V slices (which both have - the same shape). - """ - - def __init__(self, base_layer: QKVParallelLinear) -> None: - super().__init__(base_layer) - # There are three LoRA layer. 
- self.n_slices = len(self.base_layer.output_sizes) - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - self.q_proj_shard_size = (self.base_layer.num_heads * - self.base_layer.head_size) - self.kv_proj_shard_size = (self.base_layer.num_kv_heads * - self.base_layer.head_size) - self.q_shard_id = self.tp_rank - self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas - - self.output_slices = ( - self.q_proj_shard_size, - self.kv_proj_shard_size, - self.kv_proj_shard_size, - ) - self.output_ids = ( - self.q_shard_id, - self.kv_shard_id, - self.kv_shard_id, - ) - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """ - The main reason for overloading this function is to handle inconsistent - weight dimensions in qkv lora. - """ - super().create_lora_weights(max_loras, lora_config, model_config) - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return (type(source_layer) is QKVParallelLinear - and len(packed_modules_list) == 3) - - -#TODO: Implement this -class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): - pass - - -class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): - - def __init__(self, base_layer: RowParallelLinear) -> None: - super().__init__(base_layer) - - self.tp_size = get_tensor_model_parallel_world_size() - # reset input_size - self.input_size = self.base_layer.input_size_per_partition - self.output_size = self.base_layer.output_size - - self.tp_rank = get_tensor_model_parallel_rank() - # There is only one LoRA layer. - self.n_slices = 1 - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - - shard_size = self.input_size - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - lora_a = lora_a[start_idx:end_idx, :] - return lora_a - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - return bias - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of RowParallelLinear - - Args: - input_: tensor whose last dimension is `input_size`. If - `input_is_parallel` is set, then the last dimension - is `input_size // tp_size`. - - Returns: - - output - - bias - """ - # set up backprop all-reduce. - if self.base_layer.input_is_parallel: - input_parallel = input_ - else: - # TODO: simplify code below - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.base_layer.tp_size) - input_parallel = splitted_input[self.tp_rank].contiguous() - - # Matrix multiply. 
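For the row-parallel case above, the base layer shards the *input* dimension across ranks, so `slice_lora_a` slices A's input rows to match while LoRA B stays whole; a quick sketch with illustrative sizes:

```python
import torch

tp_size, tp_rank, in_total, rank = 4, 2, 32, 8
lora_a = torch.randn(in_total, rank)

shard = in_total // tp_size  # input_size_per_partition on each rank
lora_a_local = lora_a[tp_rank * shard:(tp_rank + 1) * shard, :]
assert lora_a_local.shape == (8, 8)
```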
- output_parallel = self.apply(input_parallel) - if self.base_layer.reduce_results and self.base_layer.tp_size > 1: - output_ = tensor_model_parallel_all_reduce(output_parallel) - else: - output_ = output_parallel - - if not self.base_layer.skip_bias_add: - output = (output_ + self.base_layer.bias - if self.base_layer.bias is not None else output_) - output_bias = None - else: - output = output_ - output_bias = self.base_layer.bias - - if not self.base_layer.return_bias: - return output - - return output, output_bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is RowParallelLinear - - -class LogitsProcessorWithLoRA(BaseLayerWithLoRA): - """ - LoRA wrapper for LogitsProcessor, with extra logic to handle the - application of the LoRA adapter and added LoRA vocabulary. - - Args: - base_layer: LogitsProcessor layer - hidden_size: hidden size of the model - dtype: data type of the model - device: device of the model - sharded_to_full_mapping: index mapping from sharded vocab to full vocab - received from base_layer.get_sharded_to_full_mapping(). If None, - no reindexing will be done. - """ - - def __init__(self, base_layer: LogitsProcessor, hidden_size: int, - dtype: torch.dtype, device: torch.device, - sharded_to_full_mapping: Optional[list[int]]) -> None: - super().__init__() - self.base_layer = base_layer - self.hidden_size = hidden_size - self.dtype = dtype - self.device = device - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.sharded_to_full_mapping = sharded_to_full_mapping - - @property - def logits_as_input(self): - return self.base_layer.logits_as_input - - @property - def vocab_size(self): - return self.base_layer.vocab_size - - @property - def scale(self): - return self.base_layer.scale - - @property - def soft_cap(self): - return self.base_layer.soft_cap - - @property - def use_all_gather(self): - return self.base_layer.use_all_gather - - @property - def org_vocab_size(self): - return self.base_layer.org_vocab_size - - @property - def include_gpu_probs_tensor(self): - return self.base_layer.include_gpu_probs_tensor - - @property - def should_modify_greedy_probs_inplace(self): - return self.base_layer.should_modify_greedy_probs_inplace - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - # TODO: Verify if this condition can be further relaxed - if 32000 < self.base_layer.vocab_size > 257024: - raise ValueError("When using LoRA, vocab size must be " - "32000 >= vocab_size <= 257024") - self.lora_a_stacked = torch.zeros( - ( - max_loras, - 1, - lora_config.max_lora_rank, - self.hidden_size, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.lora_b_stacked = torch.zeros( - ( - max_loras, - 1, - # Pad for kernel compatibility - math.ceil(self.base_layer.vocab_size / - lora_config.lora_vocab_padding_size) * - lora_config.lora_vocab_padding_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.embeddings_tensors = torch.full( - (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), - fill_value=float("-inf"), - dtype=self.dtype, - device=self.device, - ) - if self.sharded_to_full_mapping is not None: - self.sharded_to_full_mapping_gpu = torch.tensor( 
- self.sharded_to_full_mapping, - device=self.device, - dtype=torch.long) - else: - self.sharded_to_full_mapping_gpu = None - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = float("-inf") - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], - ] = embeddings_tensor - - def _get_logits( - self, - hidden_states: torch.Tensor, - lm_head: VocabParallelEmbedding, - embedding_bias: Optional[torch.Tensor] = None, - ) -> Optional[torch.Tensor]: - # Get the logits for the next tokens. - logits = lm_head.quant_method.apply(lm_head, hidden_states) - if embedding_bias is not None: - logits += embedding_bias - - # Gather logits for TP - logits = self.base_layer._gather_logits(logits) - - if logits is None: - return None - - if self.sharded_to_full_mapping_gpu is not None: - # Reindex full logits tensor to ensure 1:1 mapping between - # index and token_id - # Example for: - # org_vocab_size = 4 - # added_vocab_size = 2 - # pad_to_size = 8 - # tp_size = 2 - - # indices: [0, 1, 2, 3, 4, 5, 6, 7] - # token_id: [0, 1, 4, -1, 2, 3, 5, -1] - - # Therefore, the mapping is expected to be: - # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, - # we get: - # indices: [0, 1, 2, 3, 4, 5, 6, 7] - # token_id: [0, 1, 2, 3, 4, 5, -1, -1] - logits = logits[:, self.sharded_to_full_mapping_gpu] - - lora_logits = torch.empty( - self.embeddings_tensors.shape[0] + 1, - self.embeddings_tensors.shape[1], - hidden_states.shape[0], - dtype=self.embeddings_tensors.dtype, - device=self.embeddings_tensors.device, - ) - torch.matmul(self.embeddings_tensors, - hidden_states.T, - out=lora_logits[:-1]) - - neg_inf, pos_inf = current_platform.get_infinity_values( - lora_logits.dtype) - - lora_logits[-1] = neg_inf - lora_logits = lora_logits.mT - indices_padded = self.punica_wrapper.sampler_indices_padded - - if current_platform.is_tpu() or current_platform.is_xpu(): - indices_padded = indices_padded[:logits.size(0)] - - lora_logits = (lora_logits.reshape( - lora_logits.shape[0] * lora_logits.shape[1], - lora_logits.shape[2], - ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, - posinf=pos_inf, - neginf=neg_inf)) - - logits[:, - self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + - lora_logits.shape[1]] = lora_logits - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_logits( - logits, hidden_states, self.lora_a_stacked, - self.lora_b_stacked, 1.0) - - if not current_platform.can_update_inplace(): - logits = lora_output - - # Remove paddings in vocab (if any). - logits = logits[:, :self.base_layer.vocab_size] - return logits - - def forward(self, *args, **kwargs): - return type(self.base_layer).forward(self, *args, **kwargs) - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # Special handling for the LogitsProcessor. 
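A toy gather-style reindex matching the comment's example above (org vocab 4, added vocab 2, padded to 8, tp 2); the mapping below is derived by hand for illustration and is not the exact tensor `get_sharded_to_full_mapping` returns:

```python
import torch

# Gathered TP logits interleave each rank's org slice, added slice, and pad.
gathered = torch.tensor([[0., 1., 4., -1., 2., 3., 5., -1.]])
# Position of tokens 0..5 in `gathered`, with the two pad slots last.
position_of_token = torch.tensor([0, 1, 4, 5, 2, 6, 3, 7])
reindexed = gathered[:, position_of_token]
assert reindexed.tolist() == [[0., 1., 2., 3., 4., 5., -1., -1.]]
```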
- return False diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py new file mode 100644 index 0000000000000..d3bb145dc7bf8 --- /dev/null +++ b/vllm/lora/layers/__init__.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.lora.layers.base import BaseLayerWithLoRA +from vllm.lora.layers.column_parallel_linear import ( + ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA) +from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA +from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA +from vllm.lora.layers.row_parallel_linear import ( + RowParallelLinearWithLoRA, RowParallelLinearWithShardedLoRA) +from vllm.lora.layers.utils import LoRAMapping +from vllm.lora.layers.vocal_parallel_embedding import ( + VocabParallelEmbeddingWithLoRA) + +__all__ = [ + "BaseLayerWithLoRA", + "VocabParallelEmbeddingWithLoRA", + "LogitsProcessorWithLoRA", + "ColumnParallelLinearWithLoRA", + "ColumnParallelLinearWithShardedLoRA", + "MergedColumnParallelLinearWithLoRA", + "MergedColumnParallelLinearWithShardedLoRA", + "MergedQKVParallelLinearWithLoRA", + "MergedQKVParallelLinearWithShardedLoRA", + "QKVParallelLinearWithLoRA", + "QKVParallelLinearWithShardedLoRA", + "RowParallelLinearWithLoRA", + "RowParallelLinearWithShardedLoRA", + "ReplicatedLinearWithLoRA", + "LoRAMapping", +] diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py new file mode 100644 index 0000000000000..a80a033e39b40 --- /dev/null +++ b/vllm/lora/layers/base.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING, Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig + +if TYPE_CHECKING: + from vllm.lora.punica_wrapper import PunicaWrapperBase + + +class BaseLayerWithLoRA(nn.Module): + + def slice_lora_a( + self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b( + self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + """Slice lora b if splitting with tensor parallelism.""" + ... + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + """Overwrites lora tensors at index.""" + ... 
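Since `vllm/lora/layers.py` is now a package whose `__init__` re-exports everything in `__all__`, existing import sites should keep working unchanged; e.g., assuming this revision of vLLM is installed:

```python
# The monolithic module was split, but the public surface is unchanged:
from vllm.lora.layers import (ColumnParallelLinearWithLoRA, LoRAMapping,
                              RowParallelLinearWithShardedLoRA)
```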
+ + def set_mapping( + self, + punica_wrapper, + ): + self.punica_wrapper: PunicaWrapperBase = punica_wrapper + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + raise NotImplementedError diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py new file mode 100644 index 0000000000000..85a1f86ce6bf2 --- /dev/null +++ b/vllm/lora/layers/base_linear.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, cast + +import torch +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed.utils import divide +# yapf: disable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, ReplicatedLinear, + RowParallelLinear) +from vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA +from .utils import _get_lora_device + + +class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: LinearBase): + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + self.device = _get_lora_device(self.base_layer) + self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None + + self.output_slices: tuple[int, ...] + self.tp_size: int + self.output_size: int + self.n_slices: int + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + # + if isinstance(self.base_layer, ReplicatedLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, ColumnParallelLinear): + lora_a_out_size = (lora_config.max_lora_rank if + not lora_config.fully_sharded_loras else divide( + lora_config.max_lora_rank, self.tp_size)) + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, RowParallelLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = (self.output_size if + not lora_config.fully_sharded_loras else divide( + self.output_size, self.tp_size)) + else: + raise NotImplementedError + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_out_size, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_b_out_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + if lora_config.bias_enabled: + lora_bias_out_size = lora_b_out_size + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_bias_out_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.output_slices = (self.lora_b_stacked[0].shape[2], ) + + def reset_lora(self, index: int): + for s_index in range(self.n_slices): + self.lora_a_stacked[s_index][index] = 0 + self.lora_b_stacked[s_index][index] = 0 + if self.lora_config.bias_enabled: + # Make mypy happy + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[s_index][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + 
embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + # Except for QKVParallelLinearWithLoRA and + # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers + # store weights in a tuple of size 1. These two layers will + # override this function. + assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == + self.n_slices == 1) + + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + self.lora_a_stacked[0][index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[0][index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if lora_bias is not None: + + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + assert len(self.lora_bias_stacked) + self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( + lora_bias.T, non_blocking=True) + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + # In transformers backend, x and output have extra batch dimension like + # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), + # therefore we need to flatten the batch dimensions. + if x.ndim == 3 and output.ndim == 3: + output = output.flatten(0, 1) + x = x.flatten(0, 1) + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.lora_bias_stacked, 1.0, self.output_slices) + if not current_platform.can_update_inplace(): + output = lora_output + + return output + + @property + def weight(self) -> torch.Tensor: + + # unquantizedLinear + if hasattr(self.base_layer, "weight"): + return self.base_layer.weight + # Compressed Tensor + elif hasattr(self.base_layer, "weight_packed"): + return self.base_layer.weight_packed + # GPTQ/AWQ + elif hasattr(self.base_layer, "qweight"): + return self.base_layer.qweight + # marlin + elif hasattr(self.base_layer, "B"): + return self.base_layer.B + # HQQ marlin + elif hasattr(self.base_layer, "W_q"): + return self.base_layer.W_q + else: + raise ValueError(f"Unsupported base layer: {self.base_layer}") + + @property + def bias(self) -> Optional[torch.Tensor]: + if hasattr(self.base_layer, "bias"): + return self.base_layer.bias + else: + return None diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py new file mode 100644 index 0000000000000..658fd23165da0 --- /dev/null +++ b/vllm/lora/layers/column_parallel_linear.py @@ -0,0 +1,622 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union, cast + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) +from vllm.distributed.utils import divide +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear) +from vllm.platforms import current_platform + +from .base_linear import BaseLinearLayerWithLoRA +from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace + + +def _mcp_apply(x, bias, layer: 
"ColumnParallelLinearWithLoRA"): + """ + For `ColumnParallelLinearWithLoRA` or classes that inherit from + `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. + """ + assert (layer.n_slices == len(layer.lora_a_stacked) == len( + layer.lora_b_stacked) == len(layer.output_slices)) + if layer.lora_bias_stacked is not None: + assert layer.n_slices == len(layer.lora_bias_stacked) + + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + + # Since communication is needed, the buffer is directly initialized as a + # tensor rather than a tuple of tensor. + buffers = torch.zeros( + (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( + buffers, x, layer.lora_a_stacked, 1.0) + + if not current_platform.can_update_inplace(): + buffers = shrunk_buffers + + buffers = tensor_model_parallel_all_gather(buffers) + + lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( + output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + if not current_platform.can_update_inplace(): + output = lora_output + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + LoRA B is sliced for tensor parallelism. + There are two types for the `base_layer`: + 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. + 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. + """ + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__(base_layer) + # The base_layer type is ColumnParallelLinear or + # MergedColumnParallelLinear, their weight sharding logic is + # inconsistent when TP is greater than 1. + self.is_merged_col_linear = type( + base_layer) is MergedColumnParallelLinear + self.tp_size = get_tensor_model_parallel_world_size() + self.output_size = self.base_layer.output_size_per_partition + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + # Applicable to cases where the base_layer is + # MergedColumnParallelLinear. + if self.is_merged_col_linear: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size // 2 + offset = lora_b.shape[-1] // 2 + + left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * + shard_size] + right_weight = lora_b[:, offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size] + lora_b = torch.cat([left_weight, right_weight], dim=1) + # Applicable to cases where the base_layer is + # ColumnParallelLinear. + else: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + # TODO: Fix the slicing logic of bias. 
+        if bias is None:
+            return bias
+        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        shard_size = self.output_size
+        start_idx = tensor_model_parallel_rank * shard_size
+        end_idx = (tensor_model_parallel_rank + 1) * shard_size
+        bias = bias[start_idx:end_idx]
+        return bias
+
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
+        """Forward of ColumnParallelLinear
+
+        Args:
+            input_: Tensor whose last dimension is `input_size`.
+
+        Returns:
+            - output
+            - bias
+        """
+        bias = (self.base_layer.bias
+                if not self.base_layer.skip_bias_add else None)
+
+        # Matrix multiply.
+        output_parallel = self.apply(input_, bias)
+        if self.base_layer.gather_output:
+            # All-gather across the partitions.
+            output = tensor_model_parallel_all_gather(output_parallel)
+        else:
+            output = output_parallel
+
+        if not self.base_layer.return_bias:
+            return output
+
+        output_bias = (self.base_layer.bias
+                       if self.base_layer.skip_bias_add else None)
+        return output, output_bias
+
+    @classmethod
+    @_not_fully_sharded_can_replace
+    def can_replace_layer(
+        cls,
+        source_layer: nn.Module,
+        lora_config: LoRAConfig,
+        packed_modules_list: list,
+        model_config: Optional[PretrainedConfig],
+    ) -> bool:
+        return type(source_layer) is ColumnParallelLinear or (
+            type(source_layer) is MergedColumnParallelLinear
+            and len(packed_modules_list) == 1)
+
+
+class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
+    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
+    packed together (e.g. gate_proj + up_proj -> gate_up_proj).
+
+    This means we have 2 LoRAs, each applied to one half of the layer.
+
+    Both slices must have the same size.
+    """
+
+    def __init__(
+        self, base_layer: Union[MergedColumnParallelLinear,
+                                QKVParallelLinear]) -> None:
+        super().__init__(base_layer)
+        # There are two LoRA layers
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        # output_sizes in MergedColumnParallelLinear is not sharded by tp, so
+        # we divide it by tp_size to get the correct per-slice sizes.
+        output_sizes = self.base_layer.output_sizes
+        self.output_slices = tuple(
+            divide(output_size, self.tp_size) for output_size in output_sizes)
+        self.n_slices = len(self.output_slices)
+        self.output_ids = (self.tp_rank, ) * self.n_slices
+
+    def create_lora_weights(
+        self,
+        max_loras: int,
+        lora_config: LoRAConfig,
+        model_config: Optional[PretrainedConfig] = None,
+    ) -> None:
+        """
+        The main reason for overriding this function is to enhance code
+        maintainability.
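+
+        Unlike the base implementation, each packed slice gets its own
+        stacked lora_b (and optional bias) buffer sized to that slice's
+        per-rank output.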
+ """ + self.lora_config = lora_config + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + if lora_config.bias_enabled: + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + sliced_lora_b = [None] * self.n_slices + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (lora_b_i := lora_b[i]) is not None: + sliced_lora_b[i] = lora_b_i[:, + shard_size * shard_id:shard_size * + (shard_id + 1)] + return sliced_lora_b + + def slice_bias( + self, bias: list[Union[torch.Tensor, + None]]) -> list[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (bias_i := bias[i]) is not None: + bias[i] = bias_i[shard_size * shard_id:shard_size * + (shard_id + 1)] + return bias + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + if lora_bias is not None: + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + for i in range(self.n_slices): + if (lora_bias_i := lora_bias[i]) is not None: + self.lora_bias_stacked[i][index, + 0, :lora_bias_i.shape[0]].copy_( + lora_bias_i.T, + non_blocking=True) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2) + + +class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chatglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. + + During inference with Tensor Parallel, the weights of lora_b + must be accurately partitioned according to the respective ranks. + + Q slice may have different shape than K and V slices (which both have + the same shape). 
+ """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + self.q_proj_total_size = (self.base_layer.total_num_heads * + self.base_layer.head_size) + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * + self.base_layer.head_size) + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[:, k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[:, v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + bias_q = bias[self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + bias_k = bias[k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + bias_v = bias[v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + bias = torch.cat([bias_q, bias_k, bias_v], dim=1) + return bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: list, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is QKVParallelLinear and len( + packed_modules_list) == 1 + + +class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + # There are three LoRA layer. 
+ self.n_slices = len(self.base_layer.output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ + super().create_lora_weights(max_loras, lora_config, model_config) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3) + + +# These following layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, + # their `lora_a` and `lora_b` have different sharding patterns. After + # completing the `lora_a` GEMM , a gather operation is performed. + # Therefore, the sharding of `lora_a` only needs to correspond with the + # gather operation. + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + #NOTE: lora_a contains 2 subloras, and each sublora could be None. 
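+        # Slicing is along the rank dim of lora_a; e.g. (illustrative) with
+        # max_lora_rank=16 and tp_size=4, each rank keeps a 4-column slice.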
+ output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[0][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[0] is not None else None, + lora_a[1][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[1] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): + """ + Differs from QKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: list, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): + """ + Differs from MergedQKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + # NOTE: lora_a contains 3 subloras, and each sublora could be None. 
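+        # All three stacks share the same rank-shard size; indexing them per
+        # slice just keeps the code uniform with the merged-column case.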
+ shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[0][:, start_idx[0]:start_idx[0] + + shard_size[0]] if lora_a[0] is not None else None, + lora_a[1][:, start_idx[1]:start_idx[1] + + shard_size[1]] if lora_a[1] is not None else None, + lora_a[2][:, start_idx[2]:start_idx[2] + + shard_size[2]] if lora_a[2] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py new file mode 100644 index 0000000000000..a50dcfa748f2f --- /dev/null +++ b/vllm/lora/layers/logits_processor.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA + + +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): + """ + LoRA wrapper for LogitsProcessor, with extra logic to handle the + application of the LoRA adapter and added LoRA vocabulary. + + Args: + base_layer: LogitsProcessor layer + hidden_size: hidden size of the model + dtype: data type of the model + device: device of the model + sharded_to_full_mapping: index mapping from sharded vocab to full vocab + received from base_layer.get_sharded_to_full_mapping(). If None, + no reindexing will be done. 
+ """ + + def __init__(self, base_layer: LogitsProcessor, hidden_size: int, + dtype: torch.dtype, device: torch.device, + sharded_to_full_mapping: Optional[list[int]]) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.sharded_to_full_mapping = sharded_to_full_mapping + + @property + def logits_as_input(self): + return self.base_layer.logits_as_input + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def scale(self): + return self.base_layer.scale + + @property + def soft_cap(self): + return self.base_layer.soft_cap + + @property + def use_all_gather(self): + return self.base_layer.use_all_gather + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + @property + def should_modify_greedy_probs_inplace(self): + return self.base_layer.should_modify_greedy_probs_inplace + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # TODO: Verify if this condition can be further relaxed + if 32000 < self.base_layer.vocab_size > 257024: + raise ValueError("When using LoRA, vocab size must be " + "32000 >= vocab_size <= 257024") + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + if self.sharded_to_full_mapping is not None: + self.sharded_to_full_mapping_gpu = torch.tensor( + self.sharded_to_full_mapping, + device=self.device, + dtype=torch.long) + else: + self.sharded_to_full_mapping_gpu = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ] = embeddings_tensor + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. 
+ logits = lm_head.quant_method.apply(lm_head, hidden_states) + if embedding_bias is not None: + logits += embedding_bias + + # Gather logits for TP + logits = self.base_layer._gather_logits(logits) + + if logits is None: + return None + + if self.sharded_to_full_mapping_gpu is not None: + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size = 4 + # added_vocab_size = 2 + # pad_to_size = 8 + # tp_size = 2 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] + logits = logits[:, self.sharded_to_full_mapping_gpu] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + + neg_inf, pos_inf = current_platform.get_infinity_values( + lora_logits.dtype) + + lora_logits[-1] = neg_inf + lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded + + if current_platform.is_tpu() or current_platform.is_xpu(): + indices_padded = indices_padded[:logits.size(0)] + + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, + posinf=pos_inf, + neginf=neg_inf)) + + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + if not current_platform.can_update_inplace(): + logits = lora_output + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # Special handling for the LogitsProcessor. 
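+        # LogitsProcessorWithLoRA is constructed explicitly (see
+        # from_layer_logits_processor in vllm/lora/utils.py) rather than via
+        # this generic dispatch, so it never matches here.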
+ return False diff --git a/vllm/lora/layers/qkv_x_parallel_linear.py b/vllm/lora/layers/qkv_x_parallel_linear.py new file mode 100644 index 0000000000000..367482d0ee078 --- /dev/null +++ b/vllm/lora/layers/qkv_x_parallel_linear.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .base import BaseLayerWithLoRA + + +#TODO: Implement this +class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): + pass diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py new file mode 100644 index 0000000000000..3356297c1537a --- /dev/null +++ b/vllm/lora/layers/replicated_linear.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.model_executor.layers.linear import ReplicatedLinear + +from .base_linear import BaseLinearLayerWithLoRA + + +class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__(base_layer, ) + # To ensure interface compatibility, set to 1 always. + self.tp_size = 1 + self.output_size = self.base_layer.output_size + self.n_slices = 1 + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ReplicatedLinearWithLoRA + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output = self.apply(input_, bias) + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + # ReplicatedLinear should always be replaced, regardless of the fully + # sharded LoRAs setting, because it is, by definition, copied per GPU. 
+ @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ReplicatedLinear diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py new file mode 100644 index 0000000000000..18ef6fd1ddd78 --- /dev/null +++ b/vllm/lora/layers/row_parallel_linear.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union, cast + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_reduce) +# yapf: disable +from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.platforms import current_platform + +from .base_linear import BaseLinearLayerWithLoRA +from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace + + +class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__(base_layer) + + self.tp_size = get_tensor_model_parallel_world_size() + # reset input_size + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + + self.tp_rank = get_tensor_model_parallel_rank() + # There is only one LoRA layer. + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + + shard_size = self.input_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + return bias + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Matrix multiply. 
+ output_parallel = self.apply(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is RowParallelLinear + + + +# The following layer is based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + if bias is None: + return bias + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + shard_size = self.lora_bias_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros( + (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( + buffer, x, self.lora_a_stacked, 1.0) + if not current_platform.can_update_inplace(): + buffer = shrunk_buffer + + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + # NOTE offset are based on the rank. 
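+        # e.g. (illustrative) with tp_rank=1 and a 2048-wide lora_b shard,
+        # the LoRA partial sum lands in output[:, 2048:4096].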
+ shard_size = self.lora_b_stacked[0].shape[2] + offset_start = self.tp_rank * shard_size + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( + output, + buffer, + self.lora_b_stacked, + self.lora_bias_stacked, + self.output_slices, + offset_start=offset_start, + add_input=True, + ) + + if not current_platform.can_update_inplace(): + output = lora_output + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py new file mode 100644 index 0000000000000..27dcd720fbdea --- /dev/null +++ b/vllm/lora/layers/utils.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from vllm.adapter_commons.layers import AdapterMapping + + +@dataclass +class LoRAMapping(AdapterMapping): + is_prefill: bool = False + + +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear + if hasattr(base_layer, "weight"): + return base_layer.weight.device + # Compressed Tensor + elif hasattr(base_layer, "weight_packed"): + return base_layer.weight_packed.device + # GPTQ/AWQ + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # HQQ marlin + elif hasattr(base_layer, "W_q"): + return base_layer.W_q.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") + + +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = (not kwargs["lora_config"].fully_sharded_loras + if decorate else True) + return can_replace(*args, **kwargs) and condition + + return dec + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return (can_replace(*args, **kwargs) + and kwargs["lora_config"].fully_sharded_loras) + + return dec diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py new file mode 100644 index 0000000000000..4d6218d970977 --- /dev/null +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA + + +class 
VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.embeddings_slice: Optional[tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + if self.base_layer.num_added_embeddings_per_partition > 0: + # We can start adding lora weights + self.embeddings_weights = self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:self. + base_layer.num_org_embeddings_per_partition + + self.base_layer.num_added_embeddings_per_partition] + self.embeddings_slice = ( + self.base_layer.shard_indices.added_vocab_start_index - + self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index - + self.base_layer.org_vocab_size) + self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:].fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2], + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, + 1, 0) + + # NB: Don't use torch.narrow here. 
torch.narrow triggers some + # Dynamic Shape specialization in torch.compile + num_tokens = x.shape[0] + indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens] + indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens] + + full_lora_a_embeddings = F.embedding( + x + indices_1, + self.lora_a_stacked_2d, + ) + full_output = self.base_layer.forward(x + + (indices_0 * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], + -1, + ) + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_embedding( + full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + + if not current_platform.can_update_inplace(): + full_output = lora_output + + return full_output.view_as(full_output_org) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is VocabParallelEmbedding + + @property + def weight(self): + return self.base_layer.weight diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 3072047a2606c..7712438054914 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -16,7 +16,7 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, get_adapter, list_adapters, remove_adapter, set_adapter_mapping) -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index 8b8e5cb7d5fae..dc7249c386021 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -9,7 +9,7 @@ import os from dataclasses import MISSING, dataclass, field, fields from typing import Literal, Optional, Union -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.tensorizer import TensorizerConfig diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 1fc214c12b5d1..10ba390bffd9e 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -11,23 +11,23 @@ from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, from torch import nn from transformers import PretrainedConfig -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger -from vllm.lora.fully_sharded_layers import ( - ColumnParallelLinearWithShardedLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, - RowParallelLinearWithShardedLoRA) # being imported for _all_lora_classes below # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA, + 
QKVParallelLinearWithShardedLoRA, ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, VocabParallelEmbeddingWithLoRA) from vllm.model_executor.layers.linear import LinearBase diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 248d2954f1ef4..3a807b1e161d2 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -11,7 +11,7 @@ from vllm.adapter_commons.utils import (add_adapter_worker, list_adapters_worker, set_active_adapters_worker) from vllm.adapter_commons.worker_manager import AbstractWorkerManager -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.lora.models import (LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 319fa938d400e..235df1a77c5ce 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -454,7 +454,7 @@ class XIELU(CustomOp): ) return result.view(original_shape) - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward_native(self, input: torch.Tensor) -> torch.Tensor: if self._xielu_cuda_obj is not None and input.is_cuda: if not torch._dynamo.is_compiling(): return self._xielu_cuda_fn(input) @@ -464,6 +464,9 @@ class XIELU(CustomOp): ) return self._xielu_python(input) + def forward_cuda(self, input: torch.Tensor) -> torch.Tensor: + return self.forward_native(input) + class ScaledActivation(nn.Module): """An activation function with post-scale parameters. diff --git a/vllm/model_executor/layers/fla/__init__.py b/vllm/model_executor/layers/fla/__init__.py new file mode 100644 index 0000000000000..0e89cf9f79439 --- /dev/null +++ b/vllm/model_executor/layers/fla/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang diff --git a/vllm/model_executor/layers/fla/ops/__init__.py b/vllm/model_executor/layers/fla/ops/__init__.py new file mode 100644 index 0000000000000..c19cc14ba6928 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from .chunk import chunk_gated_delta_rule +from .fused_recurrent import fused_recurrent_gated_delta_rule +from .layernorm_guard import RMSNormGated + +__all__ = [ + "RMSNormGated", + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", +] diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py new file mode 100644 index 0000000000000..e7d295aff2392 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch +from einops import rearrange + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .l2norm import l2norm_fwd +from .solve_tril import solve_tril +from .utils import SUPPRESS_LEVEL, input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. 
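+    # Rough pipeline (names as used below): A is the lower-triangular system
+    # built from the scaled K*K^T within each chunk, solve_tril inverts it,
+    # recompute_w_u_fwd turns that into the WY factors (w, u), and the
+    # inter-chunk recurrence then produces h and the rewritten values v_new.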
+ A = chunk_scaled_dot_kkt_fwd(k=k, + beta=beta, + g_cumsum=g, + cu_seqlens=cu_seqlens, + output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @torch.amp.custom_fwd(device_type='cuda') + def forward(ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
+        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
+        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
+        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
+        >>> o, ht = chunk_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = chunk_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True,
+            cu_seqlens=cu_seqlens
+        )
+    """
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
+    assert len(
+        beta.shape
+    ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
+
+    if head_first:
+        raise DeprecationWarning(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead.")
+        q, k, v, beta, g = map(
+            lambda x: rearrange(x, 'b h t ... -> b t h ...'),
+            (q, k, v, beta, g))
+    if not head_first and q.shape[1] < q.shape[2]:
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+            stacklevel=2)
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                "Please flatten variable-length inputs before processing.")
+        if initial_state is not None and initial_state.shape[0] != len(
+                cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+            )
+    if scale is None:
+        scale = k.shape[-1]**-0.5
+    o, final_state = ChunkGatedDeltaRuleFunction.apply(
+        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
+        use_qk_l2norm_in_kernel)
+    if head_first:
+        o = rearrange(o, 'b t h ...
-> b h t ...') + return o, final_state diff --git a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py new file mode 100644 index 0000000000000..34006f87f457b --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices, prepare_chunk_offsets +from .op import exp +from .utils import is_nvidia_hopper, use_cuda_graph + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] for num_stages in [2, 3, 4] for BV in [32, 64] + ], + key=['H', 'K', 'V', 'BT', 'USE_G'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += (boh * H + i_h) * K * V + v += (bos * H + i_h) * V + k += (bos * Hg + i_h // (H // Hg)) * K + w += (bos * H + i_h) * K + if SAVE_NEW_VALUE: + v_new += (bos * H + i_h) * V + stride_v = H * V + stride_h = H * K * V + stride_k = Hg * K + stride_w = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), + (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), + (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), + (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: 
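+            # fourth 64-row tile of the K dimension: head dims up to 256 are
+            # covered by the four 64-wide accumulators b_h1..b_h4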
+ p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), + (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, + b_h2.to(p_h2.dtype.element_ty), + boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, + b_h3.to(p_h3.dtype.element_ty), + boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, + b_h4.to(p_h4.dtype.element_ty), + boundary_check=(0, 1)) + + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), + (i_t * BT, i_v * BV), (BT, BV), + (1, 0)) if SAVE_NEW_VALUE else None + b_v_new = tl.zeros([BT, BV], dtype=tl.float32) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) + b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) + + if SAVE_NEW_VALUE: + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), + (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v_new, + b_v_new.to(p_v_new.dtype.element_ty), + boundary_check=(0, 1)) + + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ), + (i_t * BT, ), (BT, ), (0, )) + b_g = tl.load(p_g, boundary_check=(0, )) + b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + b_v_new = b_v_new.to(k.dtype.element_ty) + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v_new) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v_new) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v_new) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v_new) + + # epilogue + if STORE_FINAL_STATE: + p_ht = 
tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), + (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), + (64, BV), (1, 0)) + tl.store(p_ht, + b_h2.to(p_ht.dtype.element_ty), + boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), + (64, BV), (1, 0)) + tl.store(p_ht, + b_h3.to(p_ht.dtype.element_ty), + boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), + (64, BV), (1, 0)) + tl.store(p_ht, + b_h4.to(p_ht.dtype.element_ty), + boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = prepare_chunk_indices( + cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len( + chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty( + N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + def grid(meta): + return (triton.cdiv(V, meta['BV']), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT) + return h, v_new, final_state diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py new file mode 100644 index 0000000000000..332751a1860a9 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk_o.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .op import exp +from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({ + 'BK': BK, + 'BV': BV + }, + num_warps=num_warps, + num_stages=num_stages) for BK in BKV_LIST + for BV in BKV_LIST for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K * V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), + (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), + (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), + (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) + b_g = tl.load(p_g, boundary_check=(0, )) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * exp(b_g[:, None] - b_g[None, :]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: 
Optional[torch.Tensor] = None, # cumsum of log decay + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + if FLA_GDN_FIX_BT: + BT = 64 + else: + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices( + cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1]**-0.5 + + o = torch.empty_like(v) + + def grid(meta): + return (triton.cdiv(V, meta['BV']), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + ) + return o diff --git a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py new file mode 100644 index 0000000000000..d1adc6978f245 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .op import exp + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_G': lambda args: args['g_cumsum'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'BT', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g_cumsum, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ), + (i_t * BT, ), (BT, ), (0, )) + b_beta = tl.load(p_beta, boundary_check=(0, )) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), + (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), + (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_beta[:, None] + b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T, ), (H, ), + (i_t * BT, ), (BT, ), (0, )) + b_g = tl.load(p_g, boundary_check=(0, )) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A = b_A * exp(b_g_diff) + + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) 
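+    # keep only the strictly lower-triangular entries within the chunk
+    # (note the strict `>` in m_A); the diagonal and upper triangle are
+    # zeroed so that solve_tril can invert (I + A) downstream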
+ p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), + (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + beta: torch.Tensor, + g_cumsum: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g_cumsum (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. + Default: None + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + + B, T, Hg, K = k.shape + + H = beta.shape[-1] + BT = chunk_size + chunk_indices = prepare_chunk_indices( + cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + beta=beta, + g_cumsum=g_cumsum, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + ) + return A diff --git a/vllm/model_executor/layers/fla/ops/cumsum.py b/vllm/model_executor/layers/fla/ops/cumsum.py new file mode 100644 index 0000000000000..370a45fe16358 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/cumsum.py @@ -0,0 +1,226 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .utils import check_shared_mem, input_guard + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune(configs=[ + triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8] +], + key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE']) +@triton.jit(do_not_specialize=['T']) +def chunk_local_cumsum_scalar_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T, ), (1, ), + (i_t * BT, ), (BT, ), (0, )) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T, ), (1, ), + (i_t * BT, ), (BT, ), (0, )) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), + (BT, ), (0, )) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), + (BT, ), (0, )) + # [BT] + b_s = tl.load(p_s, boundary_check=(0, )).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, )) + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune(configs=[ + triton.Config({'BS': BS}, num_warps=num_warps) for BS in BS_LIST + for num_warps in [2, 4, 8] +], + key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE']) +@triton.jit(do_not_specialize=['T']) +def chunk_local_cumsum_vector_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) 
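+    # m_s is an inclusive lower-triangular summation mask (flipped to
+    # upper-triangular in the REVERSE case), so the tl.dot(m_s, b_s) below
+    # realizes the intra-chunk (reverse) cumulative sum as a single matmul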
+ + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), + (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), + (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), + (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), + (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2**(chunk_size.bit_length() - + 1), "chunk_size must be a power of 2" + BT = chunk_size + chunk_indices = prepare_chunk_indices( + cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid](g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse) + return g + + +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices( + cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2**(chunk_size.bit_length() - + 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + def grid(meta): + return (triton.cdiv(meta['S'], meta['BS']), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid](g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse) + return g + + +@input_guard +def chunk_local_cumsum(g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs) -> torch.Tensor: + if not head_first and g.shape[1] < g.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. 
" + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2) + if cu_seqlens is not None: + assert g.shape[ + 0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, + head_first, output_dtype) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, + head_first, output_dtype) + else: + raise ValueError(f"Unsupported input shape {g.shape}. " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise") diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py new file mode 100644 index 0000000000000..b278e37415748 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .op import exp + + +@triton.heuristics({ + 'USE_INITIAL_STATE': + lambda args: args['h0'] is not None, + 'IS_VARLEN': + lambda args: args['cu_seqlens'] is not None, + "IS_CONTINUOUS_BATCHING": + lambda args: args['ssm_state_indices'] is not None, + "IS_SPEC_DECODING": + lambda args: args['num_accepted_tokens'] is not None, +}) +@triton.jit(do_not_specialize=['N', 'T']) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.constexpr, # num of sequences + T: tl.constexpr, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl. 
+ constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + p_g = g + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_init_state_token + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_final_state_token + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_g += HV + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not 
supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. 
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        ssm_state_indices (Optional[torch.Tensor]):
+            Indices to map the input sequences to the initial/final states.
+        num_accepted_tokens (Optional[torch.Tensor]):
+            Number of accepted tokens for each sequence during decoding.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, HV, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, HV, K, V]`.
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
+        >>> q = torch.randn(B, T, H, K, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, HV, V, device='cuda')
+        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
+        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
+        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
+        >>> o, ht = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            cu_seqlens=cu_seqlens
+        )
+    """
+    if cu_seqlens is not None and q.shape[0] != 1:
+        raise ValueError(
+            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+            "Please flatten variable-length inputs before processing.")
+    if scale is None:
+        scale = k.shape[-1]**-0.5
+    else:
+        assert scale > 0, "scale must be positive"
+    if beta is None:
+        beta = torch.ones_like(q[..., 0])
+    o, final_state = FusedRecurrentFunction.apply(
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        inplace_final_state,
+        cu_seqlens,
+        ssm_state_indices,
+        num_accepted_tokens,
+        use_qk_l2norm_in_kernel,
+    )
+    return o, final_state
diff --git a/vllm/model_executor/layers/fla/ops/index.py b/vllm/model_executor/layers/fla/ops/index.py
new file mode 100644
index 0000000000000..9eca32bc31a04
--- /dev/null
+++ b/vllm/model_executor/layers/fla/ops/index.py
@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +from vllm.triton_utils import triton + +from .utils import tensor_cache + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, + chunk_size: int) -> torch.LongTensor: + indices = torch.cat([ + torch.arange(n) + for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() + ]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], + 1).to(cu_seqlens) + + +@tensor_cache +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, + chunk_size: int) -> torch.LongTensor: + return torch.cat([ + cu_seqlens.new_tensor([0]), + triton.cdiv(prepare_lens(cu_seqlens), chunk_size) + ]).cumsum(-1) diff --git a/vllm/model_executor/layers/fla/ops/l2norm.py b/vllm/model_executor/layers/fla/ops/l2norm.py new file mode 100644 index 0000000000000..ef9788ceaf20e --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/l2norm.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +BT_LIST = [8, 16, 32, 64, 128] + +USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) + + +@triton.autotune(configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16, 32] +], + key=['D']) +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # tl.store(Rstd + i_t, rstd) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +@triton.autotune(configs=[ + triton.Config({'BT': BT}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST +], + key=['D']) +@triton.jit(do_not_specialize=["NB"]) +def l2norm_fwd_kernel( + x, + y, + eps, + NB, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=1) + b_y = b_x / tl.sqrt(b_var + eps)[:, None] + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * MBLOCK + row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] + xmask = row_idx < M + rindex = tl.arange(0, N)[None, :] + xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) + square = tl.broadcast_to(xs * xs, [MBLOCK, N]) + square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] + rsqrt = 
tl.rsqrt(square_sum + eps) + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + + +def l2norm_fwd(x: torch.Tensor, + eps: float = 1e-6, + output_dtype: Optional[torch.dtype] = None): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + if not USE_DEFAULT_FLA_NORM: + MBLOCK = 32 + # M, N = x.shape + l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )]( + x, + y, + eps, + T, + D, + MBLOCK, + ) + else: + if D <= 512: + NB = triton.cdiv(T, 2048) + + def grid(meta): + return (triton.cdiv(T, meta['BT']), ) + + l2norm_fwd_kernel[grid]( + x, + y, + eps, + NB=NB, + T=T, + D=D, + BD=BD, + ) + else: + l2norm_fwd_kernel1[(T, )]( + x, + y, + eps=eps, + D=D, + BD=BD, + ) + + return y.view(x_shape_og) diff --git a/vllm/model_executor/layers/fla/ops/layernorm_guard.py b/vllm/model_executor/layers/fla/ops/layernorm_guard.py new file mode 100644 index 0000000000000..a733c6c81e369 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/layernorm_guard.py @@ -0,0 +1,337 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Tri Dao +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2024, Tri Dao. + +# ruff: noqa: E501 +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from vllm.triton_utils import tl, triton + +from .utils import input_guard + + +def rms_norm_ref(x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + upcast=True): + dtype = x.dtype + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + z = z.float() if z is not None else z + if z is not None and not norm_before_gate: + x = x * F.silu(z) + if group_size is None: + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * + weight) + else: + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + + eps) + out = rearrange(x_group * rstd, "... g d -> ... 
(g d)") * weight + if bias is not None: + out = out + bias + if z is not None and norm_before_gate: + out *= F.silu(z) + return out.to(dtype) + + +@triton.heuristics({ + "HAS_BIAS": lambda args: args["B"] is not None, + "HAS_Z": lambda args: args["Z"] is not None, +}) +@triton.jit +def layer_norm_fwd_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + if HAS_Z: + Z += row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def layer_norm_fwd( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + z: torch.Tensor = None, + out: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + is_rms_norm: bool = False, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N, ) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N, ) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = torch.empty((ngroups * M, ), dtype=torch.float32, + device=x.device) if not is_rms_norm else None + rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + 
layer_norm_fwd_kernel[grid](x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps) + return out, mean, rstd + + +class LayerNormFn(torch.autograd.Function): + + @input_guard + @staticmethod + def forward(ctx, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + ctx.save_for_backward(x, weight, bias, mean, rstd, z) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.group_size = group_size + ctx.norm_before_gate = norm_before_gate + ctx.is_rms_norm = is_rms_norm + return y.reshape(x_shape_og) + + +def layernorm_fn(x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, + norm_before_gate, is_rms_norm) + + +def rmsnorm_fn(x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, + norm_before_gate, True) + + +class LayerNormGated(nn.Module): + + def __init__( + self, + hidden_size, + eps: float = 1e-5, + group_size: Optional[int] = None, + norm_before_gate: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return layernorm_fn(x, + self.weight, + self.bias, + z=z, + group_size=self.group_size, + eps=self.eps, + norm_before_gate=self.norm_before_gate) + + +class RMSNormGated(nn.Module): + + def __init__( + self, + hidden_size, + eps: float = 1e-5, + group_size: Optional[int] = None, + norm_before_gate: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). 
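+
+        Unlike `LayerNormGated`, this variant keeps no bias and defaults to
+        norm_before_gate=False, i.e. the SiLU gate is applied before the
+        normalization.
+
+        A minimal usage sketch (shapes are illustrative; a CUDA device is
+        assumed since the forward runs a Triton kernel):
+            >>> norm = RMSNormGated(64).cuda()
+            >>> x = torch.randn(2, 8, 64, device='cuda')
+            >>> z = torch.randn(2, 8, 64, device='cuda')
+            >>> y = norm(x, z)  # rmsnorm(x * silu(z)) * weight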
+ """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return rmsnorm_fn(x, + self.weight, + self.bias, + z=z, + eps=self.eps, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate) diff --git a/vllm/model_executor/layers/fla/ops/op.py b/vllm/model_executor/layers/fla/ops/op.py new file mode 100644 index 0000000000000..8c29434ca106a --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/op.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os + +from vllm.triton_utils import tl, tldevice, triton + +if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': + div = tldevice.fast_dividef + exp = tldevice.fast_expf + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + + @triton.jit + def div_normal(x, y): + return x / y + + div = div_normal + exp = tl.exp + log = tl.log + log2 = tl.log2 + + +if not hasattr(tl, 'gather'): + + @triton.jit + def gather(src, index, axis, _builder=None): + # This is a fallback implementation when tl.gather is not supported + # In order to pass triton compiler, there is no actual gather operation + return src +else: + gather = tl.gather diff --git a/vllm/model_executor/layers/fla/ops/solve_tril.py b/vllm/model_executor/layers/fla/ops/solve_tril.py new file mode 100644 index 0000000000000..97cb0d800d411 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/solve_tril.py @@ -0,0 +1,365 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .utils import input_guard + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5] + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def solve_tril_16x16_kernel( + A, + Ad, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A = A + (bos * H + i_h) * BT + Ad = Ad + (bos * H + i_h) * 16 + + offset = (i_t * 16) % BT + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), + (16, 16), (1, 0)) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), + (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + b_A = -tl.where( + tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + + o_i = tl.arange(0, 16) + for i in range(1, min(16, T - i_t * 16)): + b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + mask = o_i == i + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += o_i[:, None] == o_i[None, :] + tl.store(p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5] + ], + key=['H', 'BT', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def merge_16x16_to_32x32_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices, + T, H: tl.constexpr, BT: tl.constexpr, + IS_VARLEN: tl.constexpr): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 32 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 32 + + p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), + (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), + (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), + (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + Ai_11 = tl.load(p_Ad_11, 
boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), + Ai_11, + input_precision='ieee') + tl.store(p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] for num_stages in [2, 3, 4, 5] + ], + key=['H', 'BT', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def merge_16x16_to_64x64_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices, + T, H: tl.constexpr, BT: tl.constexpr, + IS_VARLEN: tl.constexpr): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 64 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 64 + + p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), + (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), + (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), + (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), + (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), + (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), + (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), + (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), + (16, 16), (1, 0)) + p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), + (16, 16), (1, 0)) + p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), + (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32) + Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32) + + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), + Ai_11, + input_precision='ieee') + Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'), + Ai_22, + input_precision='ieee') + Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'), + Ai_33, + input_precision='ieee') + + Ai_31 = -tl.dot(Ai_33, + tl.dot(A_31, Ai_11, input_precision='ieee') + + tl.dot(A_32, Ai_21, input_precision='ieee'), + input_precision='ieee') + Ai_42 = -tl.dot(Ai_44, + tl.dot(A_42, Ai_22, 
input_precision='ieee') + + tl.dot(A_43, Ai_32, input_precision='ieee'), + input_precision='ieee') + Ai_41 = -tl.dot(Ai_44, + tl.dot(A_41, Ai_11, input_precision='ieee') + + tl.dot(A_42, Ai_21, input_precision='ieee') + + tl.dot(A_43, Ai_31, input_precision='ieee'), + input_precision='ieee') + + p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), + (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), + (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), + (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), + (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), + (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), + (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), + (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), + (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), + (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), + (16, 16), (1, 0)) + tl.store(p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_33, + Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_44, + Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_31, + Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_32, + Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_41, + Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_42, + Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_43, + Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + + fill_zeros = tl.zeros((16, 16), dtype=tl.float32) + p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), + (16, 16), (1, 0)) + p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), + (16, 16), (1, 0)) + p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), + (16, 16), (1, 0)) + p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), + (16, 16), (1, 0)) + p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), + (16, 16), (1, 0)) + p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), + (16, 16), (1, 0)) + tl.store(p_Ai_12, + fill_zeros.to(p_Ai_12.dtype.element_ty, + fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_13, + fill_zeros.to(p_Ai_13.dtype.element_ty, + fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_14, + fill_zeros.to(p_Ai_14.dtype.element_ty, + fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_23, + fill_zeros.to(p_Ai_23.dtype.element_ty, + fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_24, + fill_zeros.to(p_Ai_24.dtype.element_ty, + 
fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_34, + fill_zeros.to(p_Ai_34.dtype.element_ty, + fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + + +@input_guard +def solve_tril(A: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + output_dtype: torch.dtype = torch.float) -> torch.Tensor: + """ + Compute the inverse of the lower triangular matrix + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, K] + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. + Default: None. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float` + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + + B, T, H, BT = A.shape + Ad = torch.empty(B, + T, + H, + 16, + device=A.device, + dtype=torch.float if BT != 16 else output_dtype) + + chunk_indices = prepare_chunk_indices( + cu_seqlens, 16) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16) + solve_tril_16x16_kernel[NT, B * H]( + A=A, + Ad=Ad, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + if BT == 16: + return Ad + + Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) + merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel + chunk_indices = prepare_chunk_indices( + cu_seqlens, BT) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + merge_fn[NT, B * H]( + A=A, + Ad=Ad, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + return Ai diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py new file mode 100644 index 0000000000000..7fd90cee45d0e --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/utils.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import contextlib +import functools +import logging +import os +from enum import Enum +from typing import Any, Callable, Literal, Optional + +import torch + +from vllm.triton_utils import triton + +logger = logging.getLogger(__name__) + +COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" +FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" +FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1" + +SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) + + +def tensor_cache( + fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent results of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed. + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. 
+ """ + + cache_entries: tuple[Optional[tuple], Optional[dict], Any] = [] + cache_size = 4 + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal cache_entries, cache_size + for i, entry in enumerate(cache_entries): + last_args, last_kwargs, last_result = entry + if len(args) == len(last_args) and len(kwargs) == len(last_kwargs) \ + and all(a is b for a, b in zip(args, last_args)) \ + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()): + cache_entries = cache_entries[:i] + cache_entries[i + 1:] + [ + (args, kwargs, last_result) + ] + return last_result + + result = fn(*args, **kwargs) + + if len(cache_entries) >= cache_size: + cache_entries = cache_entries[1:] + cache_entries.append((args, kwargs, result)) + return result + + return wrapper + + +def input_guard( + fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else + i.contiguous() for i in args) + contiguous_kwargs = { + k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) + for k, v in kwargs.items() + } + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = torch.cuda.device(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + return 'cpu' + + +@functools.cache +def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']: + device = get_available_device() + mapping = { + "cuda": "nvidia", + "hip": "amd", + "xpu": "intel", + } + # return the mapped value, or the original if not found + return mapping.get(device, device) + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. +# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. 
+device = get_available_device() if get_available_device() != 'hip' else 'cuda' +device_torch_lib = getattr(torch, device) +device_platform = _check_platform() + +is_amd = (device_platform == 'amd') +is_intel = (device_platform == 'intel') +is_nvidia = (device_platform == 'nvidia') +is_intel_alchemist = (is_intel + and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0)) +is_nvidia_hopper = (is_nvidia + and ('NVIDIA H' in torch.cuda.get_device_name(0) + or torch.cuda.get_device_capability()[0] >= 9)) +use_cuda_graph = (is_nvidia + and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1') + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i) + ['max_shared_mem'] for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/vllm/model_executor/layers/fla/ops/wy_fast.py b/vllm/model_executor/layers/fla/ops/wy_fast.py new file mode 100644 index 0000000000000..70374eb650642 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/wy_fast.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
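The `Backend` enum above records per-SM shared-memory budgets in bytes (H100 about 232448, A100 about 166912, RTX 4090 about 101376), and `check_shared_mem(arch)` is a cached test of whether the current device meets a given budget. A hedged sketch of the intended pattern, ahead of the wy_fast.py kernels below; the tile choice here is illustrative, not taken from this PR:

```python
# Hypothetical tile selection gated on shared-memory capacity (sketch);
# assumes check_shared_mem from the utils module above is in scope.
def pick_chunk_size() -> int:
    # Hopper-class devices (~227 KiB smem per SM) can afford larger chunks.
    return 64 if check_shared_mem("hopper") else 32
```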
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices, + T, H: tl.constexpr, Hg: tl.constexpr, + K: tl.constexpr, V: tl.constexpr, + BT: tl.constexpr, BK: tl.constexpr, + BV: tl.constexpr, IS_VARLEN: tl.constexpr): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ), + (i_t * BT, ), (BT, ), (0, )) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T, ), (H, ), (i_t * BT, ), + (BT, ), (0, )) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), + (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0, )) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0, ))) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), + (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), + (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), + (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), + (1, 0)) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), + (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = prepare_chunk_indices( + cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py 
b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index a5326dfe84f6d..0ab6355f41565 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from math import log2 from typing import Optional import torch @@ -10,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, is_deep_gemm_e8m0_used) @@ -24,35 +26,28 @@ def _silu_mul_fp8_quant_deep_gemm( y_q_ptr, # fp8 quantized activations (E, T, H) y_s_ptr, # 16-bit scales (E, T, G) counts_ptr, # int32 num tokens per expert (E) - # Sizes --------------------------------------------------------------- H: tl.constexpr, # hidden dimension (per output) GROUP_SIZE: tl.constexpr, # elements per group (usually 128) - # Strides for input (elements) --------------------------------------- stride_i_e, stride_i_t, stride_i_h, - # Strides for y_q (elements) ----------------------------------------- stride_yq_e, stride_yq_t, stride_yq_h, - # Strides for y_s (elements) ----------------------------------------- stride_ys_e, stride_ys_t, stride_ys_g, - # Stride for counts (elements) stride_counts_e, - # Numeric params ------------------------------------------------------ eps: tl.constexpr, fp8_min: tl.constexpr, fp8_max: tl.constexpr, use_ue8m0: tl.constexpr, - # Meta --------------------------------------------------------------- BLOCK: tl.constexpr, NUM_STAGES: tl.constexpr, @@ -101,17 +96,15 @@ def _silu_mul_fp8_quant_deep_gemm( tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) -def silu_mul_fp8_quant_deep_gemm( +def silu_mul_fp8_quant_deep_gemm_cuda( y: torch.Tensor, # (E, T, 2*H) tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + num_parallel_tokens=16, group_size: int = 128, - eps: float = 1e-10, ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales - - y has shape (E, T, 2*H). The first half of the last dimension is + y has shape (E, T, 2*H). The first half of the last dimension is silu-activated, multiplied by the second half, then quantized into FP8. 
- Returns `(y_q, y_s)` where * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) @@ -120,22 +113,17 @@ def silu_mul_fp8_quant_deep_gemm( E, T, H2 = y.shape assert H2 % 2 == 0, "last dim of y must be even (2*H)" H = H2 // 2 - G = H // group_size - assert H % group_size == 0, "H must be divisible by group_size" - assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \ - "tokens_per_expert must be shape (E,)" + G = (H + group_size - 1) // group_size + assert H % 8 == 0, "H must be divisible by 8" + assert group_size == 128, "group_size must be 128" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) - # allocate outputs fp8_dtype = torch.float8_e4m3fn y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) - # strides (elements) - stride_i_e, stride_i_t, stride_i_h = y.stride() - stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() - - # desired scale strides (elements): (T*G, 1, T) stride_ys_e = T * G stride_ys_t = 1 stride_ys_g = T @@ -144,47 +132,86 @@ silu_mul_fp8_quant_deep_gemm( dtype=torch.float32, device=y.device) - stride_cnt_e = tokens_per_expert.stride()[0] + use_ue8m0 = is_deep_gemm_e8m0_used() - # Static grid over experts and H-groups. - # A loop inside the kernel handles the token dim - grid = (E * G, ) + if E <= 16: + max_empirical_parallelism = 64 + elif E <= 32: + max_empirical_parallelism = 16 + else: + max_empirical_parallelism = 4 - f_info = torch.finfo(fp8_dtype) - fp8_max = f_info.max - fp8_min = f_info.min + # Never launch more parallel token-workers than T: clip num_parallel_tokens + # to a power of two bounded by the empirical limit and T. + num_parallel_tokens = max( + 1, + min(max_empirical_parallelism, 2**int(log2(min(num_parallel_tokens, + T))))) + cuda_arch = current_platform.get_device_capability( + device_id=y.device.index).to_int() - _silu_mul_fp8_quant_deep_gemm[grid]( - y, - y_q, - y_s, - tokens_per_expert, - H, - group_size, - stride_i_e, - stride_i_t, - stride_i_h, - stride_yq_e, - stride_yq_t, - stride_yq_h, - stride_ys_e, - stride_ys_t, - stride_ys_g, - stride_cnt_e, - eps, - fp8_min, - fp8_max, - is_deep_gemm_e8m0_used(), - BLOCK=group_size, - NUM_STAGES=4, - num_warps=1, - ) + if cuda_arch >= 80: + torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(y, tokens_per_expert, + y_q, y_s, group_size, + use_ue8m0, + num_parallel_tokens) + else: + # Default to triton if not on cuda or if arch is too old + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + stride_cnt_e = tokens_per_expert.stride()[0] + + # Static grid over experts and H-groups.
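For reference, the computation implemented by the CUDA fast path above and the Triton fallback that follows (silu-gate, multiply, then per-group FP8 quantization) can be written in plain PyTorch. A sketch only: token masking via `tokens_per_expert` and the kernel's (T*G, 1, T) scale layout are omitted, and `group_size` is assumed to divide H.

```python
# Unfused reference for silu_mul_fp8_quant_deep_gemm_cuda (sketch only).
import torch
import torch.nn.functional as F

def silu_mul_fp8_ref(y: torch.Tensor, group_size: int = 128):
    E, T, H2 = y.shape
    H = H2 // 2
    # silu(first half) * second half, computed in fp32 for accuracy
    x = F.silu(y[..., :H].float()) * y[..., H:].float()
    g = x.view(E, T, H // group_size, group_size)
    fp8 = torch.float8_e4m3fn
    fmax = torch.finfo(fp8).max
    # one scale per contiguous group of group_size channels
    scales = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / fmax
    y_q = (g / scales).clamp(-fmax, fmax).view(E, T, H).to(fp8)
    return y_q, scales.squeeze(-1)  # (E, T, H) fp8, (E, T, G) fp32
```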
+ # A loop inside the kernel handles the token dim + grid = (E * G, ) + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided( + (E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device, + ) + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + eps: float = 1e-10 + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + is_deep_gemm_e8m0_used(), + BLOCK=group_size, + NUM_STAGES=4, + num_warps=1, + ) return y_q, y_s class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - # The Deep Gemm kernels only support block size of 128 DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128] @@ -297,8 +324,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), workspace1, expert_num_tokens, expected_m) - a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, - expert_num_tokens) + a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda( + workspace1, expert_num_tokens) fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output, expert_num_tokens, expected_m) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..cc853947c19f5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..b4e736bec9b65 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, 
+ "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..bb71005a72bc5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..ac53df14ce846 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..f1ed617d6308f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, 
+ "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..e72282dc5bcd9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..d104aa5167b22 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + 
}, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..94408e279b656 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..9f4c3cbc9b8a9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H200.json 
b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000..20146f53a6eba --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..d0140252594f5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..cc1427c139e3e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, 
+ "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..68649395a23ed --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000..2f0b45014e863 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..5d69efe9ed5f9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..564ff499d43c4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..a68c83147eeb3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000..e55df46b40269 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + 
"128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..a0855a921f3f6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..5dd1a8e19c2ce --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H200.json new file mode 
100644 index 0000000000000..d5b6d02123d71 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 2bbe523b4bf98..2a3ae478f3eab 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -49,14 +49,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return torch.int64 def _get_dispatch_config(self) -> Optional[deep_ep.Config]: - if self.dp_size not in self.available_rank_configs: + if self.num_dispatchers_ not in self.available_rank_configs: return None - return deep_ep.Buffer.get_dispatch_config(self.dp_size) + return deep_ep.Buffer.get_dispatch_config(self.num_dispatchers_) def _get_combine_config(self) -> Optional[deep_ep.Config]: - if self.dp_size not in self.available_rank_configs: + 
if self.num_dispatchers_ not in self.available_rank_configs: return None - return deep_ep.Buffer.get_combine_config(self.dp_size) + return deep_ep.Buffer.get_combine_config(self.num_dispatchers_) def _do_dispatch( self, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 272ad39565375..a90a71159f721 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, +from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up) if current_platform.is_cuda_alike(): @@ -755,7 +755,7 @@ class FusedMoE(CustomOp): intermediate_size: Intermediate size of the experts params_dtype: Data type for the parameters. reduce_results: Whether to all all_reduce on the output of the layer - renomalize: Whether to renormalize the logits in the fused_moe kernel + renormalize: Whether to renormalize the logits in the fused_moe kernel quant_config: Quantization configure. enable_eplb: Whether to enable expert parallelism load balancer. """ @@ -786,6 +786,7 @@ class FusedMoE(CustomOp): enable_eplb: bool = False, num_redundant_experts: int = 0, has_bias: bool = False, + is_sequence_parallel=False, ): super().__init__() if params_dtype is None: @@ -797,6 +798,10 @@ class FusedMoE(CustomOp): dp_size_ = (dp_size if dp_size is not None else get_dp_group().world_size) + self.is_sequence_parallel = is_sequence_parallel + if self.is_sequence_parallel: + self.sp_size = tp_size_ + vllm_config = get_current_vllm_config() self.moe_parallel_config: FusedMoEParallelConfig = ( FusedMoEParallelConfig.make( @@ -808,9 +813,16 @@ class FusedMoE(CustomOp): # we are padding globally so EP buffer allocation works if quant_config and quant_config.get_name() == "mxfp4": - from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501 - should_use_flashinfer_mxfp4) - if current_platform.is_rocm() or should_use_flashinfer_mxfp4(): + from vllm.model_executor.layers.quantization.mxfp4 import ( + Mxfp4Backend, get_mxfp4_backend) + current_mxfp4_backend = get_mxfp4_backend() + if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + or current_mxfp4_backend + == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS): + hidden_size = round_up(hidden_size, 128) + elif (current_platform.is_rocm() or current_mxfp4_backend + == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or + current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): hidden_size = round_up(hidden_size, 256) # For smuggling this layer into the fused moe custom op @@ -1588,7 +1600,7 @@ class FusedMoE(CustomOp): else: return tensor_model_parallel_all_reduce(final_hidden_states) - def forward( + def forward_native( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -1622,6 +1634,13 @@ class FusedMoE(CustomOp): return (shared_output[..., :og_hidden_states], fused_output[..., :og_hidden_states]) + def forward_cuda( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + return self.forward_native(hidden_states, router_logits) + def forward_impl_chunked( self, full_hidden_states: torch.Tensor, @@ -1699,14 +1718,22 @@ class FusedMoE(CustomOp): ctx = 
get_forward_context() # flashinfer_cutlass_kernels can handle: optional DP + TP/EP - max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu + max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens + + # If the input to the MoE is sequence parallel then divide by sp_size + # to find the maximum number of tokens for any individual dispatcher. + if self.is_sequence_parallel: + max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers, + self.sp_size) + num_tokens = full_hidden_states.size(0) for chunk_idx, chunk_start_ in enumerate( - range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)): + range(0, max_tokens_across_dispatchers, + moe_dp_chunk_size_per_rank)): chunk_start = chunk_start_ chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, - max_tokens_across_dp) + max_tokens_across_dispatchers) # clamp start and end chunk_start = min(chunk_start, num_tokens - 1) chunk_end = min(chunk_end, num_tokens) @@ -1742,9 +1769,6 @@ class FusedMoE(CustomOp): self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels and not self.moe_config.use_flashinfer_cutlass_kernels) - if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits) # If there are shared experts but we are not using a modular kernel, the # shared experts must be called here @@ -1755,6 +1779,10 @@ class FusedMoE(CustomOp): else: shared_output = None + if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits) + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -1787,8 +1815,9 @@ class FusedMoE(CustomOp): final_hidden_states, ) - def reduce_output(states: torch.Tensor) -> torch.Tensor: - if do_naive_dispatch_combine: + def reduce_output(states: torch.Tensor, + do_combine: bool = True) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: states = get_ep_group().combine(states) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): @@ -1797,10 +1826,11 @@ class FusedMoE(CustomOp): return states if self.shared_experts is None: + assert not isinstance(final_hidden_states, tuple) return reduce_output(final_hidden_states) else: return ( - reduce_output(final_hidden_states[0]), + reduce_output(final_hidden_states[0], do_combine=False), reduce_output(final_hidden_states[1]), ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 7a8c6f8571deb..281563c3bfca2 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -302,7 +302,7 @@ class FusedMoEPrepareAndFinalize(ABC): def max_num_tokens_per_rank(self) -> Optional[int]: """ Some PrepareFinalize All2All implementations are batched. Meaning, - they can processes only as set of tokens at a time. This + they can process only a set of tokens at a time. This function returns the batch size i.e the maximum number of tokens the implementation can process at a time. Return None if there are no such restrictions.
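As a note on the sequence-parallel change to forward_impl_chunked above: a minimal sketch of the chunk-loop arithmetic, assuming made-up token counts and chunk size (cdiv and the variable names mirror the diff; none of the concrete numbers come from vLLM):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as in vllm.utils.cdiv.
    return -(-a // b)

max_tokens_across_dp = 7000        # worst case across DP ranks (example value)
sp_size = 4                        # sequence-parallel world size (example value)
moe_dp_chunk_size_per_rank = 1024  # stand-in for moe_config.max_num_tokens

# With sequence parallelism each dispatcher sees at most
# cdiv(max_tokens_across_dp, sp_size) tokens, so the chunk loop is
# bounded by that smaller number rather than the full DP maximum.
max_tokens_across_dispatchers = cdiv(max_tokens_across_dp, sp_size)  # 1750

for chunk_start in range(0, max_tokens_across_dispatchers,
                         moe_dp_chunk_size_per_rank):
    chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
                    max_tokens_across_dispatchers)
    print(chunk_start, chunk_end)  # prints: 0 1024, then 1024 1750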
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index f14f13e2ade9d..13c3ab4f06dd1 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -420,9 +420,8 @@ def shuffle_weights( Args: *tensors: Variable number of torch.Tensor objects. - layout: A pair of integers specifying the - block sizes used to divide the tensors during shuffling. - Default is (16, 16). + layout: A pair of integers specifying the block sizes used to divide + the tensors during shuffling. Default is (16, 16). Returns: A Tuple of shuffled tensors. diff --git a/vllm/model_executor/layers/fused_moe/routing_simulator.py b/vllm/model_executor/layers/fused_moe/routing_simulator.py index c8b107f13cd0d..8758a570b3c63 100644 --- a/vllm/model_executor/layers/fused_moe/routing_simulator.py +++ b/vllm/model_executor/layers/fused_moe/routing_simulator.py @@ -10,7 +10,7 @@ like uniform random routing. """ from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Optional import torch @@ -50,7 +50,9 @@ class DistributionBasedRouting(RoutingStrategy): distributions for testing different routing patterns. """ - def __init__(self, distribution: str = "uniform", **distribution_params): + def __init__(self, + distribution: str = "uniform", + **distribution_params: Any): """ Initialize distribution-based routing. @@ -244,7 +246,7 @@ class RoutingSimulator: cls._routing_strategies[name] = strategy @classmethod - def get_available_strategies(cls): + def get_available_strategies(cls) -> list[str]: """ Get list of available routing strategy names. diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a5fc1db2dc10f..f875f712ba9c9 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -9,11 +9,11 @@ import torch.nn as nn import vllm.envs as envs from vllm.model_executor.custom_op import CustomOp from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op def is_rocm_aiter_rmsnorm_enabled() -> bool: - return current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER_RMSNORM \ + return envs.VLLM_ROCM_USE_AITER_RMSNORM \ and envs.VLLM_ROCM_USE_AITER @@ -43,8 +43,22 @@ def fused_add_rms_norm( return x, residual -def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> torch.Tensor: +def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + from vllm import _custom_ops as ops + out = torch.empty_like(x) + ops.poly_norm( + out, + x, + weight, + bias, + variance_epsilon, + ) + return out + + +def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: import aiter as rocm_aiter if x.dim() > 2: x_original_shape = x.shape @@ -55,7 +69,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, return rocm_aiter.rms_norm(x, weight, variance_epsilon) -def rocm_aiter_fused_add_rms_norm( +def rocm_aiter_rmsnorm2d_fwd_with_add_impl( x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: @@ -74,14 +88,48 @@ def rocm_aiter_fused_add_rms_norm( return output, residual_out -def dispatch_cuda_rmsnorm_func(add_residual: bool): - if add_residual: - if is_rocm_aiter_rmsnorm_enabled(): - return 
rocm_aiter_fused_add_rms_norm - return fused_add_rms_norm +def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + return torch.empty_like(x) - if is_rocm_aiter_rmsnorm_enabled(): - return rocm_aiter_rms_norm + +def rocm_aiter_rmsnorm2d_fwd_with_add_fake( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(x), torch.empty_like(residual) + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_aiter_rms_norm", + op_func=rocm_aiter_rms_norm_impl, + mutates_args=[], + fake_impl=rocm_aiter_rms_norm_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm2d_fwd_with_add", + op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl, + mutates_args=[], + fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype): + use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [ + torch.float16, torch.bfloat16 + ] + + if use_aiter and with_fused_add: + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + if use_aiter: + return torch.ops.vllm.rocm_aiter_rms_norm + + # fall back to CUDA implementation + if with_fused_add: + return fused_add_rms_norm return rms_norm @@ -114,6 +162,13 @@ class RMSNorm(CustomOp): self.weight = torch.ones(hidden_size) if self.has_weight: self.weight = nn.Parameter(self.weight) + weight_dtype = self.weight.data.dtype + + if current_platform.is_rocm(): + self.rocm_norm_func = dispatch_rocm_rmsnorm_func( + with_fused_add=False, dtype=weight_dtype) + self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func( + with_fused_add=True, dtype=weight_dtype) def forward_native( self, @@ -162,13 +217,27 @@ class RMSNorm(CustomOp): return self.forward_native(x, residual) add_residual = residual is not None - norm_func = dispatch_cuda_rmsnorm_func(add_residual) - if add_residual: - return norm_func(x, residual, self.weight.data, - self.variance_epsilon) + return fused_add_rms_norm(x, residual, self.weight.data, + self.variance_epsilon) else: - return norm_func(x, self.weight.data, self.variance_epsilon) + return rms_norm(x, self.weight.data, self.variance_epsilon) + + def forward_hip( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: + return self.forward_native(x, residual) + + add_residual = residual is not None + if add_residual: + return self.rocm_norm_func_with_add(x, residual, self.weight.data, + self.variance_epsilon) + else: + return self.rocm_norm_func(x, self.weight.data, + self.variance_epsilon) def forward_xpu( self, @@ -265,3 +334,48 @@ class GemmaRMSNorm(CustomOp): self.forward_static) self._is_compiled = True return self.forward_native(x, residual) + + +@CustomOp.register("poly_norm") +class PolyNorm(CustomOp): + """Polynomial normalization. + + Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b + where w_n is the learned weight and b is the bias. 
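+    For example, with the default initialization used below (weight = (1/3,
+    1/3, 1/3) and bias = 0), the output is just the average of the three
+    RMS-normalized powers of x.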
+ Refer to https://arxiv.org/html/2411.03884v1 + """ + + def __init__( + self, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1)) + self.variance_epsilon = eps + + def _norm(self, x): + return x / torch.sqrt( + x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon) + + def forward_native( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward(). + + Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md + """ + + orig_dtype = x.dtype + x_float = x.to(torch.float32) + output = (self.weight[0] * self._norm(x_float**3) + + self.weight[1] * self._norm(x_float**2) + + self.weight[2] * self._norm(x_float) + self.bias) + return output.to(orig_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return poly_norm(x, self.weight, self.bias, self.variance_epsilon) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index fd88eac55cb51..773dfeae25d93 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -740,7 +740,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): """ Handle special case for models where MLP layers are already fused on disk. In this case, we have no shard id. This function - determmines the shard id by splitting these layers and then calls + determines the shard id by splitting these layers and then calls the weight loader using the shard id. An example of a model with these fused layers: @@ -914,7 +914,7 @@ class QKVParallelLinear(ColumnParallelLinear): """ Handle special case for models where QKV layers are already fused on disk. In this case, we have no shard id. This function - determmines the shard id by splitting these layers and then calls + determines the shard id by splitting these layers and then calls the weight loader using the shard id. An example of a model with these fused layers: diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index e93be9bfb1657..8a4ac214443eb 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional import torch -import torch.nn as nn import vllm.envs as envs from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -22,7 +22,8 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: envs.VLLM_LOGITS_PROCESSOR_THREADS) -class LogitsProcessor(nn.Module): +@CustomOp.register("logits_processor") +class LogitsProcessor(CustomOp): """Process logits and apply logits processors from sampling metadata. This layer does the following: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index bb3fdd38dbef3..1623a2fd562c7 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -291,6 +291,7 @@ class MambaMixer2(MambaBase, CustomOp): output_size=self.conv_dim, bias=use_conv_bias, quant_config=None, + prefix=f"{prefix}.conv1d", ) # unsqueeze to fit conv1d weights shape into the linear weights shape. 
# Can't do this in `weight_loader` since it already exists in @@ -303,6 +304,7 @@ class MambaMixer2(MambaBase, CustomOp): output_size=intermediate_size + self.conv_dim + self.num_heads, bias=use_bias, quant_config=quant_config, + prefix=f"{prefix}.in_proj", ) # - because in_proj is a concatenation of 3 weights, we @@ -322,7 +324,7 @@ class MambaMixer2(MambaBase, CustomOp): # - the weight already has a "weight_loader" attribute # which set_weight_attrs will raise if we do not # delete before trying to override it - # - ditto for the otther two weights below + # - ditto for the other two weights below delattr(self.conv1d.bias, "weight_loader") set_weight_attrs( self.conv1d.bias, @@ -402,6 +404,7 @@ class MambaMixer2(MambaBase, CustomOp): bias=use_bias, input_is_parallel=True, quant_config=quant_config, + prefix=f"{prefix}.out_proj", ) self.norm = Mixer2RMSNormGated(intermediate_size, diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 1dc46639640b0..a6c1af91de421 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -70,6 +70,15 @@ class MambaStateDtypeCalculator: model_dtype) return (conv_state_dtype, ) + @classmethod + def gated_delta_net_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, torch.dtype]: + state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) + return (state_dtype, state_dtype) + class MambaStateShapeCalculator: @@ -163,3 +172,31 @@ class MambaStateShapeCalculator: # for n_groups == 1, this is exactly tp_size - n_groups return tp_size - ngroups + + @classmethod + def gated_delta_net_state_shape( + cls, + tp_world_size: int, + num_k_heads: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + conv_kernel_size: int, + num_spec: int = 0, + use_v1: bool = True, + ): + conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads) + conv_state_shape = ( + divide(conv_dim, tp_world_size), + conv_kernel_size - 1 + num_spec, + ) + + # In V0, the conv_state shape was swapped during allocation in + # MambaCacheManager, but in V1 it needs to be determined here at the + # calculation level + if use_v1: + conv_state_shape = conv_state_shape[1], conv_state_shape[0] + + temporal_state_shape = (divide(num_v_heads, + tp_world_size), head_k_dim, head_v_dim) + return conv_state_shape, temporal_state_shape diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index b8d4bbc37105d..a0478a359f91b 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -464,7 +464,9 @@ def causal_conv1d_fn( # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] # 4. 
computation can be skipped if cache_indices[idx] == pad_slot_id num_cache_lines = conv_states.size(0) - assert (num_cache_lines, dim, width - 1) == conv_states.shape + assert (num_cache_lines == conv_states.shape[0] + and dim == conv_states.shape[1] + and width - 1 <= conv_states.shape[2]) stride_istate_seq = conv_states.stride(0) stride_istate_dim = conv_states.stride(1) stride_istate_token = conv_states.stride(2) @@ -623,6 +625,7 @@ def _causal_conv1d_update_kernel( conv_state_ptr, cache_seqlens_ptr, # circular buffer conv_state_indices_ptr, + num_accepted_tokens_ptr, o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -639,6 +642,7 @@ def _causal_conv1d_update_kernel( stride_conv_state_seq: tl.constexpr, stride_conv_state_dim: tl.constexpr, stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, @@ -649,6 +653,7 @@ def _causal_conv1d_update_kernel( KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, BLOCK_N: tl.constexpr, @@ -663,8 +668,9 @@ def _causal_conv1d_update_kernel( if IS_CONTINUOUS_BATCHING: # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to( - tl.int64) + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices).to( + tl.int64) else: conv_state_batch_coord = idx_seq if USE_PAD_SLOT: # noqa @@ -672,13 +678,32 @@ def _causal_conv1d_update_kernel( # not processing as this is not the actual sequence return + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 token: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) - + 1) + else: + conv_state_token_offset = 0 + # STEP 1: READ init_state data conv_states_base = (conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)) mask_w = idx_feats < dim - prior_tokens = conv_states_base + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok if KERNEL_WIDTH >= 2: conv_states_ptrs = prior_tokens # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0) @@ -695,11 +720,15 @@ def _causal_conv1d_update_kernel( # STEP 2: assume state_len > seqlen idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + # With speculative decoding, the conv_state update works in a sliding + # window manner: at each forward pass the tokens shift by 1, so we + # load starting from idx_tokens + 1.
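+    # Illustrative example (values not from the kernel): with M = 4 history
+    # slots and N = 2 draft tokens, accepting 2 tokens gives
+    # conv_state_token_offset = 1, so the reads below start one slot later
+    # and the window seen is [history3, history4, draft1, draft2].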
conv_state_ptrs_source = ( conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * + stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N] mask = ((conv_state_batch_coord < num_cache_lines) & ((idx_tokens + seqlen) < state_len)[:, None] & (idx_feats < dim)[None, :]) @@ -820,6 +849,7 @@ def causal_conv1d_update( activation: Union[bool, str, None] = None, cache_seqlens: Optional[torch.Tensor] = None, conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, pad_slot_id: int = PAD_SLOT_ID, metadata=None, validate_data=False, @@ -890,10 +920,14 @@ def causal_conv1d_update( ) # X (batch, dim, seqlen) stride_o_seq, stride_o_dim, stride_o_token = out.stride() - stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( ) - state_len = width - 1 + stride_state_indices = conv_state_indices.stride( + 0) if conv_state_indices is not None else 0 + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 np2_statelen = triton.next_power_of_2(state_len) def grid(META): @@ -910,6 +944,7 @@ def causal_conv1d_update( conv_state, cache_seqlens, conv_state_indices, + num_accepted_tokens, out, # Matrix dimensions batch, @@ -926,6 +961,7 @@ def causal_conv1d_update( stride_istate_seq, stride_istate_dim, stride_istate_token, + stride_state_indices, stride_o_seq, stride_o_dim, stride_o_token, @@ -936,6 +972,7 @@ def causal_conv1d_update( KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, BLOCK_N=256, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 365139e237c66..fb8350e191c94 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -289,6 +289,9 @@ def _chunk_scan_fwd_kernel( # get the cs at the offset boundary # - c_off == 0 is a passthrough + # - We need dA_cs at the boundary, defined by c_off - no need + # to increase pointer by pid_m (it is a constant offset, + # i.e. 
the same for all blocks) dA_cs_m_boundary = tl.load( dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index ad58a9918f03c..a7b3c814859ce 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -502,7 +502,7 @@ def _chunk_state_varlen_kernel( dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk - # If HAS_INITSTATES==True need to consider two possiblties + # If HAS_INITSTATES==True need to consider two possibilities # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs # - if state_idx >= pid * chunk_size, then we need to insert initstates if ((start_idx < pid_c * chunk_size) # first chunk diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index d0b3e9e5235bf..fcc5c905bf77f 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -106,21 +106,24 @@ def _mamba_chunk_scan_combined_fwd(x, # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states - # ii) seq_idx and iii) is_cont_batched to be all specified. + # ii) seq_idx, iii) is_cont_batched and iv) chunk_offsets to all be specified. # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. + # - We will also make sure that the dA_cumsum is taken only from the start of the + # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries) # - this will ensure that states will be updated with the rightmost flushed seq_idx # of the previous chunk. This implies that the first chunk of states is either 0 # or equal to init_states of the first example. states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), - dA_cumsum[:, :, :, -1], + dA_cumsum, initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=state_dtype if state_dtype is not None else C.dtype, - is_cont_batched=cu_seqlens is not None) + is_cont_batched=cu_seqlens is not None, + chunk_offsets=chunk_offsets) states, final_states = (rearrange(t, "... (p n) -> ...
p n", n=dstate) for t in [states, final_states]) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index a28fc9ffad71b..d61c3a8cdbe9c 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -31,6 +31,8 @@ def _state_passing_fwd_kernel( dA_cs_ptr, initstates_ptr, seq_idx_ptr, + chunk_offsets_ptr, + chunk_meta_num, # Matrix dimensions dim, nchunks, @@ -51,6 +53,7 @@ def _state_passing_fwd_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_dA_cs_csize, stride_initstates_batch, stride_initstates_head, stride_initstates_dim, @@ -66,7 +69,8 @@ def _state_passing_fwd_kernel( pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + ( + chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head if HAS_INITSTATES: @@ -95,35 +99,62 @@ def _state_passing_fwd_kernel( tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk - seq_idx = 0 + prev_seq_idx_chunk_end = 0 + logical_chunk_idx = 0 for c in range(nchunks): new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) + scale_mask = True if HAS_SEQ_IDX: # - the seq to pass forward is the one that is flushed to the right # boundary. - # - that is given by seq_idx_new below. - seq_idx_new = tl.load(seq_idx_ptr + - (min((c + 1) * chunk_size, seqlen) - 1) * - stride_seq_idx_seqlen) + # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. + seq_idx_chunk_end = tl.load(seq_idx_ptr + (min( + (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) if HAS_INITSTATES: - if IS_CONT_BATCHED and seq_idx != seq_idx_new: + if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: # this means in the current chunk the rightmost flushed seq # has changed. 
# - so we do not propagate the state from previous chunk # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch # - update state with seq_idx_chunk_end's init state states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - else: - scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) - seq_idx = seq_idx_new + # - we need to consider the cumsum only of the last sequence in the chunk + # - find its starting position (given by c_off of the logical chunk index) + # - and subtract the cumsum just before that position from the total cumsum + # - first, update the logical chunk index (add the number of sequences in the current physical chunk): + # sequence index at the start of the current chunk + seq_idx_chunk_start = tl.load(seq_idx_ptr + + min(c * chunk_size, seqlen) * + stride_seq_idx_seqlen) + logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start + # - load the chunk offset: + c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, + mask=logical_chunk_idx < chunk_meta_num, + other=0) + # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything + if c_off > 0: + # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset + dA_cs_boundary = tl.load( + dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize + + (c_off - 1) * stride_dA_cs_csize, + mask=(c_off - 1) > -1 and c_off < chunk_size, + other=0.0) + dA_cs -= dA_cs_boundary + + # - increment logical chunk index for every physical chunk + logical_chunk_idx += 1 + else: + scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end + prev_seq_idx_chunk_end = seq_idx_chunk_end + + scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) states = scale * states + new_states if c < nchunks - 1: tl.store(out_ptrs, states, mask=offs_m < dim) @@ -136,28 +167,36 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, - dA_chunk_cumsum, + dA_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None, is_cont_batched=False, + chunk_offsets=None, ): batch, nchunks, nheads, dim = states.shape - assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if chunk_size is None: + chunk_size = dA_cumsum.shape[-1] + else: + assert chunk_size == dA_cumsum.shape[-1] + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) if initial_states is not None: if is_cont_batched: # - if cu_seqlens is provided, then the initial states # are used for continuous batching, in which case we # require seq_idx to be provided - assert seq_idx is not None, "" + assert seq_idx is not None, "seq_idx must be provided for continuous batching" + # - we also need chunk_offsets to be provided, to account + # for computation of dA_cumsum from the start of the + # sequence + assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching" else: # - this is the regular batching case, where initial # states are used for each example of the batch. 
assert initial_states.shape == (batch, nheads, dim) if seq_idx is not None: - assert chunk_size is not None seqlen = seq_idx.shape[-1] assert seq_idx.shape == (batch, seqlen) out_dtype = states.dtype if out_dtype is None else out_dtype @@ -173,13 +212,15 @@ def _state_passing_fwd( states, out, final_states, - dA_chunk_cumsum, + dA_cumsum, initial_states, seq_idx, + chunk_offsets, + len(chunk_offsets) if chunk_offsets is not None else 0, dim, nchunks, seqlen if seq_idx is not None else 0, - chunk_size if seq_idx is not None else 0, + chunk_size, states.stride(0), states.stride(1), states.stride(2), @@ -191,9 +232,10 @@ def _state_passing_fwd( final_states.stride(0), final_states.stride(1), final_states.stride(2), - dA_chunk_cumsum.stride(0), - dA_chunk_cumsum.stride(2), - dA_chunk_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) if initial_states is not None else (0, 0, 0)), diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index afe7ea7b83924..b571a8f866990 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -5,7 +5,7 @@ from collections.abc import Mapping, Set from dataclasses import dataclass from enum import IntEnum from itertools import groupby -from typing import Callable, Optional, TypeVar, Union, cast +from typing import Callable, Optional, TypeVar, Union import torch import torch.nn as nn @@ -362,14 +362,13 @@ class PoolerIdentity(PoolerActivation): class PoolerNormalize(PoolerActivation): def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - x = F.normalize(pooled_data.float(), p=2, dim=-1) - return x.to(pooled_data.dtype) + return F.normalize(pooled_data, p=2, dim=-1) class PoolerMultiLabelClassify(PoolerActivation): def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + return F.sigmoid(pooled_data) class PoolerClassify(PoolerActivation): @@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation): pooled_data.shape[-1]) if num_labels < 2: - return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + return F.sigmoid(pooled_data) - return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype) + return F.softmax(pooled_data, dim=-1) class LambdaPoolerActivation(PoolerActivation): @@ -432,8 +431,9 @@ class EmbeddingPoolerHead(PoolerHead): from vllm.model_executor.models.adapters import _load_st_projector vllm_config = get_current_vllm_config() - self.projector = _load_st_projector( + self.projector: Optional[nn.Module] = _load_st_projector( vllm_config.model_config) if vllm_config else None + self.head_dtype = vllm_config.model_config.head_dtype def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): @@ -442,16 +442,11 @@ class EmbeddingPoolerHead(PoolerHead): pooled_data = torch.stack(pooled_data) # pooled_data shape: [batchsize, hidden_dimension] + pooled_data = pooled_data.to(self.head_dtype) + # Apply ST projector if self.projector is not None: - projector = cast(nn.Module, self.projector) - - def _proj(x: torch.Tensor) -> torch.Tensor: - orig_dtype = x.dtype - y = projector(x.to(torch.float32)) - return y.to(orig_dtype) - - pooled_data = _proj(pooled_data) + pooled_data = self.projector(pooled_data) # pooled_data shape: [batchsize, embedding_dimension] pooling_params = get_pooling_params(pooling_metadata) 
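The pooler changes above replace the per-activation `.float()`/`.to(orig_dtype)` round-trips with a single cast to a configurable `head_dtype` at the boundary of the pooler head. A minimal sketch of that pattern, assuming a `head_dtype` taken from the model config as in the diff; the class name and projector are illustrative stand-ins, not the exact vLLM types:

```python
# Sketch of the "cast once at the head boundary" pattern, assuming a
# configured head_dtype (cf. model_config.head_dtype above). The class
# and projector here are illustrative, not the exact vLLM types.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchEmbeddingHead(nn.Module):

    def __init__(self, projector: Optional[nn.Module],
                 head_dtype: torch.dtype):
        super().__init__()
        self.projector = projector
        self.head_dtype = head_dtype

    def forward(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # One explicit cast replaces the per-activation
        # .float()/.to(orig_dtype) round-trips removed above; the
        # projector and activation then run in head_dtype directly.
        pooled_data = pooled_data.to(self.head_dtype)
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        return F.normalize(pooled_data, p=2, dim=-1)
```

The design point is that the dtype policy now lives in one place (the head) instead of being re-decided inside every activation.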
@@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead): def __init__(self) -> None: super().__init__(activation=PoolerClassify(static_num_labels=False)) + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + self.head_dtype = vllm_config.model_config.head_dtype + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): + + if isinstance(pooled_data, list): + pooled_data = [p.to(self.head_dtype) for p in pooled_data] + else: + pooled_data = pooled_data.to(self.head_dtype) + pooling_params = get_pooling_params(pooling_metadata) # for softmax @@ -641,6 +646,7 @@ class ClassifierPooler(Pooler): self.act_fn = act_fn or PoolerClassify() self.logit_bias: Optional[ float] = vllm_config.model_config.pooler_config.logit_bias + self.head_dtype = vllm_config.model_config.head_dtype def get_supported_tasks(self) -> Set[PoolingTask]: return {"classify", "score"} @@ -655,6 +661,8 @@ class ClassifierPooler(Pooler): pooled_data = torch.stack(pooled_data) # pooled_data shape: [batchsize, hidden_size] + pooled_data = pooled_data.to(self.head_dtype) + if self.classifier is not None: pooled_data = self.classifier(pooled_data) # pooled_data shape: [batchsize, num_labels] diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index fb285413ba9ef..1ca92273430dd 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -327,6 +327,8 @@ class AutoRoundConfig(QuantizationConfig): if isinstance(layer, FusedMoE): if use_marlin: + return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe) + else: from vllm.model_executor.layers.quantization.moe_wna16 import ( MoeWNA16Config) @@ -339,7 +341,6 @@ class AutoRoundConfig(QuantizationConfig): } return MoeWNA16Config.from_config(config).get_quant_method( layer, prefix) - return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe) if isinstance(layer, (LinearBase, ParallelLMHead)): if use_marlin: diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 39bd34d351f61..d05c0c0d5473c 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -202,7 +202,7 @@ class BitBLASLinearMethod(LinearMethodBase): output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, - ): + ) -> None: """Creates quantized weights for use in linear operations. The function initializes and returns a dictionary containing quantized @@ -211,7 +211,7 @@ class BitBLASLinearMethod(LinearMethodBase): Args: input_size_per_partition: The size of the input partition. - output_size_per_partition: The size of the output partition. + output_partition_sizes: List of output partition sizes. input_size: The total size of the input (unused). output_size: The total size of the output (unused). params_dtype: @@ -222,9 +222,9 @@ class BitBLASLinearMethod(LinearMethodBase): scales ('scales'), and zeros ('zeros'). Raises: - ValueError: If `params_dtype` is not `torch.float16` or if the - input size per partition is not divisible by the group size in - `quant_config`. + ValueError: If `params_dtype` is not `torch.float16` or if the input + size per partition is not divisible by the group size + in `quant_config`. """ del input_size, output_size # Unused arguments. 
weight_loader = extra_weight_attrs["weight_loader"] diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 65e0b70621532..3d94626e5d8c6 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -30,7 +30,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) + get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace, + should_use_deepgemm_for_fp8_linear) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin) @@ -449,10 +450,10 @@ class Fp8LinearMethod(LinearMethodBase): # Activations not quantized for marlin. del layer.input_scale - # On B200, if E8M0 for DeepGemm is used, we need to + # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to # requantize the weight and input to the specific scale # at the same time. - if is_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used() and self.block_quant: assert layer.weight_block_size is not None block_sz = tuple(layer.weight_block_size) requant_weight_ue8m0_inplace( @@ -462,6 +463,15 @@ class Fp8LinearMethod(LinearMethodBase): block_sz, ) + # SM90 Block FP8 CUTLASS requires row-major weight scales + if (self.block_quant and current_platform.is_device_capability(90) + and self.cutlass_block_fp8_supported + and not should_use_deepgemm_for_fp8_linear( + torch.bfloat16, layer.weight)): + layer.weight_scale_inv = Parameter( + layer.weight_scale_inv.data.T.contiguous(), + requires_grad=False) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, @@ -757,10 +767,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - # DeepGemm scales need to be transposed and aligned. We try to do + # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): - # Lazy import to avoid CUDA initialization problems. if _is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = \ get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() @@ -896,7 +905,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): del layer.w13_input_scale del layer.w2_input_scale - if is_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used() and self.block_quant: assert layer.weight_block_size is not None # Re-quantise the expert weights so their scales are UE8M0. block_sz = tuple(layer.weight_block_size) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index d03074f861848..6462292586482 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -265,9 +265,9 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): scales ('scales'), and zeros ('zeros'). Raises: - ValueError: If `params_dtype` is not `torch.float16` or - if the input size per partition is not divisible by the - group size in `quant_config`. + ValueError: If `params_dtype` is not `torch.float16` or if the input + size per partition is not divisible by the group size + in `quant_config`. 
""" if params_dtype != torch.float16: raise ValueError("Parameter data type must be torch.float16, " diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 4bcfcd04b3d8b..f10d20999bee3 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -46,11 +46,11 @@ def choose_mp_linear_kernel( performance. Args: - config (MPLinearLayerConfig): Description of the linear layer to be - implemented. + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. compute_capability (Optional[int], optional): The compute capability of - the target device, if None uses `current_platform` to get the compute - capability. Defaults to None. + the target device, if None uses `current_platform` to get + the compute capability. Defaults to None. Raises: ValueError: If no kernel can implement the given config. diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index e5604670fb4c1..4c6fcda893a03 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -88,6 +88,7 @@ class BaseKVCacheMethod(QuantizeMethodBase): "Setting it to k_scale. This only matters for " "the flash-attn backend.") layer._q_scale.copy_(k_scale) + layer._q_scale_float = k_scale # These are used in the final Attention.forward() layer._k_scale.copy_(k_scale) @@ -124,6 +125,7 @@ class BaseKVCacheMethod(QuantizeMethodBase): # These are used in the final Attention.forward() layer._q_scale.copy_(q_scale) + layer._q_scale_float = q_scale layer._prob_scale.copy_(prob_scale) if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 or prob_scale == 1.0): diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index e140807879177..9b99931e7b43f 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from torch.nn import Module @@ -45,6 +45,9 @@ from vllm.utils import next_power_of_2 from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer, has_flashinfer_moe) +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + logger = init_logger(__name__) QUANT_ALGOS = ["FP8", "NVFP4"] @@ -63,7 +66,7 @@ class ModelOptFp8Config(QuantizationConfig): super().__init__() self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.kv_cache_quant_method = kv_cache_quant_method - self.exclude_modules = exclude_modules + self.exclude_modules = exclude_modules or [] if is_checkpoint_fp8_serialized: logger.warning("Detected ModelOpt fp8 checkpoint. 
Please note that" " the format is experimental and could change.") @@ -84,6 +87,11 @@ class ModelOptFp8Config(QuantizationConfig): def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.exclude_modules is not None: + self.exclude_modules = hf_to_vllm_mapper.apply_list( + self.exclude_modules) + @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: @@ -170,7 +178,9 @@ class ModelOptFp8Config(QuantizationConfig): prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): - if self.is_layer_excluded(prefix): + if (is_layer_skipped(prefix, self.exclude_modules, + self.packed_modules_mapping) + or self.is_layer_excluded(prefix)): return UnquantizedLinearMethod() return ModelOptFp8LinearMethod(self) elif isinstance(layer, Attention): @@ -615,6 +625,11 @@ class ModelOptNvFp4Config(QuantizationConfig): def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.exclude_modules is not None: + self.exclude_modules = hf_to_vllm_mapper.apply_list( + self.exclude_modules) + @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: @@ -763,7 +778,8 @@ class ModelOptNvFp4Config(QuantizationConfig): prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): - if (is_layer_skipped(prefix, self.exclude_modules) + if (is_layer_skipped(prefix, self.exclude_modules, + self.packed_modules_mapping) or self.is_layer_excluded(prefix, self.exclude_modules)): return UnquantizedLinearMethod() return ModelOptNvFp4LinearMethod(self) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 889c15df3c878..f935bdd84124a 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum from typing import Callable, Optional, Union import torch @@ -33,33 +34,72 @@ from vllm.utils.flashinfer import has_flashinfer logger = init_logger(__name__) -def _should_use_flashinfer_mxfp4_bf16(): - """Determine if FlashInfer MXFP4 BF16 should be used.""" - # If explicitly set, respect the setting - if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): - return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 +# enum for mxfp4 backend +class Mxfp4Backend(Enum): + NONE = 0 - # Enable by default on SM100 if MXFP8 is not explicitly enabled - if (current_platform.is_device_capability(100) and has_flashinfer() - and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): - logger.info_once( - "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. 
" - "For faster performance, consider setting " - "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, " - "though this may impact accuracy.") - return True + # FlashInfer Backend + SM100_FI_MXFP4_MXFP8_TRTLLM = 1 + SM100_FI_MXFP4_MXFP8_CUTLASS = 2 + SM100_FI_MXFP4_BF16 = 3 + SM90_FI_MXFP4_BF16 = 4 - return False + # Marlin Backend + MARLIN = 5 + + # Triton Backend + TRITON = 6 -def _should_use_flashinfer_mxfp4_mxfp8(): - """Determine if FlashInfer MXFP4 MXFP8 should be used.""" - return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 +def get_mxfp4_backend(): + # Backend Selection + if current_platform.is_cuda(): + if (current_platform.is_device_capability(90) and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") + return Mxfp4Backend.SM90_FI_MXFP4_BF16 + elif (current_platform.is_device_capability(100) and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS): + logger.info_once( + "Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") + return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + elif (current_platform.is_device_capability(100) and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8): + logger.info_once( + "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, " + "for high concurrency throughput workloads consider setting " + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better " + "performance") + return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + elif current_platform.is_device_capability(100) and has_flashinfer(): + logger.info_once( + "Using FlashInfer MXFP4 BF16 backend for SM100, " + "For faster performance on SM100, consider setting " + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact " + "accuracy.") + return Mxfp4Backend.SM100_FI_MXFP4_BF16 + elif ((current_platform.is_device_capability(100) + or current_platform.is_device_capability(90)) + and not has_flashinfer()): + logger.warning_once( + "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer " + "is not available. This may result in degraded performance. " + "Please `pip install vllm[flashinfer]` for best results.") + # If FlashInfer is not available, try either Marlin or Triton + if current_platform.get_device_capability( + )[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer( + "2.8.0"): + logger.info_once("Using Marlin backend") + return Mxfp4Backend.MARLIN + else: + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON + elif current_platform.is_rocm() and has_triton_kernels(): + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON -def should_use_flashinfer_mxfp4(): - return (_should_use_flashinfer_mxfp4_mxfp8() - or _should_use_flashinfer_mxfp4_bf16()) + return Mxfp4Backend.NONE class Mxfp4Config(QuantizationConfig): @@ -113,29 +153,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): super().__init__(moe) self.topk_indices_dtype = None self.moe = moe - self.use_marlin = self._should_use_marlin() + self.mxfp4_backend = get_mxfp4_backend() self.max_capture_size = get_current_vllm_config( ).compilation_config.max_capture_size - if current_platform.is_device_capability(100) and not has_flashinfer(): - logger.warning_once( - "MXFP4 MoE is enabled on Blackwell but FlashInfer " - "is not available. This may result in degraded performance. 
" - "Please `pip install vllm[flashinfer]` for best results.") - - def _should_use_marlin(self): - if envs.VLLM_MXFP4_USE_MARLIN is not None: - return envs.VLLM_MXFP4_USE_MARLIN - if current_platform.is_cuda() and \ - not current_platform.is_device_capability(100): - if not current_platform.has_device_capability(90): - # marlin kernel has better performance on ampere - return True - if not has_triton_kernels(): - return True - if not is_torch_equal_or_newer("2.8.0"): - return True - return False + assert self.mxfp4_backend != Mxfp4Backend.NONE, ( + "No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available." + "Please check your environment and try again.") + self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -156,7 +181,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): intermediate_size_per_partition_after_pad = \ intermediate_size_per_partition - if self.use_marlin: + if self.mxfp4_backend == Mxfp4Backend.MARLIN: # The moe marlin kernel requires that for each linear # n % 256 == 0 and k % 128 == 0. # In gate_up_proj: @@ -174,16 +199,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.hidden_size = hidden_size layer.intermediate_size_per_partition = \ intermediate_size_per_partition_after_pad - elif should_use_flashinfer_mxfp4(): + elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): # pad the intermediate size to be a multiple of 2 * mxfp4_block # for to hold non-uniform sharded tensor as well as swizzling # other padding to increase performance intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 256) hidden_size = round_up(hidden_size, 256) - elif current_platform.is_rocm(): + elif current_platform.is_rocm() or ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 128) + hidden_size = round_up(hidden_size, 128) else: intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 64) @@ -263,10 +292,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): set_weight_attrs(w2_bias, extra_weight_attrs) def process_weights_after_loading(self, layer): - if self.use_marlin: + if self.mxfp4_backend == Mxfp4Backend.MARLIN: prepare_moe_fp4_layer_for_marlin(layer) - elif should_use_flashinfer_mxfp4(): - from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): + from flashinfer.fp4_quantization import ( + nvfp4_block_scale_interleave) + from flashinfer.fused_moe.core import ( + _maybe_get_cached_w2_permute_indices) layer.gemm1_alpha = Parameter(torch.tensor( [1.702] * self.num_experts, dtype=torch.float32).cuda(), requires_grad=False) @@ -343,25 +376,63 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): gemm2_bias_shuffled = [] epilogue_tile_m = 128 # FIXME: this depends on the kernel internals for i in range(self.num_experts): - gemm1_weights_mxfp4_shuffled.append( - shuffle_matrix_a(w13_weight[i].view(torch.uint8), - epilogue_tile_m)) + # w13 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w13_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + 
gemm1_weights_mxfp4_shuffled.append(w13_weight[i].view( + torch.uint8)[permute_indices.to( + w13_weight.device)].contiguous()) + # w13 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) gemm1_scales_mxfp4_shuffled.append( - shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) - gemm1_bias_shuffled.append( - shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1), - epilogue_tile_m)) - - gemm2_weights_mxfp4_shuffled.append( - shuffle_matrix_a(w2_weight[i].view(torch.uint8), - epilogue_tile_m)) + nvfp4_block_scale_interleave(w13_weight_scale[i].view( + torch.uint8)[permute_sf_indices.to( + w13_weight_scale.device)].contiguous())) + # w13 bias shuffling + permute_bias_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w13_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm1_bias_shuffled.append(w13_bias[i].clone().reshape( + -1, + 1)[permute_bias_indices.to(w13_bias.device)].contiguous()) + # w2 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w2_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm2_weights_mxfp4_shuffled.append(w2_weight[i].view( + torch.uint8)[permute_indices.to( + w2_weight.device)].contiguous()) + # w2 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) gemm2_scales_mxfp4_shuffled.append( - shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) - gemm2_bias_shuffled.append( - shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), - epilogue_tile_m)) + nvfp4_block_scale_interleave(w2_weight_scale[i].view( + torch.uint8)[permute_sf_indices.to( + w2_weight_scale.device)].contiguous())) + # w2 bias shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm2_bias_shuffled.append(w2_bias[i].clone().reshape( + -1, 1)[permute_indices.to(w2_bias.device)].contiguous()) w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) w13_weight_scale = torch.stack( @@ -387,7 +458,116 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape( self.num_experts, -1), requires_grad=False) - else: + elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): + layer.gemm1_alpha = Parameter(torch.tensor( + [1.702] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_beta = Parameter(torch.tensor( + [1.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_clamp_limit = Parameter(torch.tensor( + [7.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + + sf_block_size = 32 # mxfp4 block size + + # Common shape assertions + assert (layer.w13_weight.dim() == 3 + and layer.w13_weight.shape[0] == self.num_experts + and layer.w13_weight.shape[1] == self.intermediate_size * 2 + and layer.w13_weight.shape[2] == self.hidden_size // 2) + assert (layer.w13_weight_scale.dim() == 3 + and layer.w13_weight_scale.shape[0] == self.num_experts + and layer.w13_weight_scale.shape[1] + == self.intermediate_size * 2 + and layer.w13_weight_scale.shape[2] + == self.hidden_size // 
sf_block_size) + assert (layer.w2_weight.dim() == 3 + and layer.w2_weight.shape[0] == self.num_experts + and layer.w2_weight.shape[1] == self.hidden_size and + layer.w2_weight.shape[2] == self.intermediate_size // 2) + assert (layer.w2_weight_scale.dim() == 3 + and layer.w2_weight_scale.shape[1] == self.hidden_size + and layer.w2_weight_scale.shape[2] + == self.intermediate_size // sf_block_size) + assert (layer.w13_bias.dim() == 2 + and layer.w13_bias.shape[0] == self.num_experts + and layer.w13_bias.shape[1] == self.intermediate_size * 2) + assert (layer.w2_bias.dim() == 2 + and layer.w2_bias.shape[0] == self.num_experts + and layer.w2_bias.shape[1] == self.hidden_size) + + # De-interleave and swap for w13 weight, bias, and scales + w13_w = layer.w13_weight.data + gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :] + deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1) + w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1) + w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) + + w13_b = layer.w13_bias.data.to(torch.float32) + gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2] + deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1) + b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1) + w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) + + w13_s = layer.w13_weight_scale.data + gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :] + deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1) + s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1) + w13_scale_swapped = torch.cat([s3, s1], dim=1) + + if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: + from flashinfer import block_scale_interleave + + orig_shape = w13_scale_swapped.shape + w13_scale_interleaved = block_scale_interleave( + w13_scale_swapped.view(torch.uint8)).reshape(orig_shape) + + w2_s = layer.w2_weight_scale.data + orig_shape = w2_s.shape + w2_scale_interleaved = block_scale_interleave( + w2_s.view(torch.uint8)).reshape(orig_shape) + + layer.w13_weight = Parameter(w13_weight_swapped, + requires_grad=False) + layer.w13_weight_scale = Parameter(w13_scale_interleaved, + requires_grad=False) + layer.w13_bias = Parameter(w13_bias_swapped, + requires_grad=False) + layer.w2_weight_scale = Parameter(w2_scale_interleaved, + requires_grad=False) + elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: + + def _interleave_mxfp4_cutlass_sm90(w): + w_shape = w.shape + w_interleaved = w.reshape(w_shape[0], w_shape[1], + (w_shape[2] // 4), 4) + w_interleaved = w_interleaved.permute(0, 2, 1, 3) + w_interleaved = w_interleaved.reshape( + w_shape[0], w_shape[2] // 4, w_shape[1] * 4) + return w_interleaved + + w31_scales = w13_scale_swapped.to(torch.uint8).view( + torch.uint8) + w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90( + w31_scales) + + w2_weight_scale = layer.w2_weight_scale.data + w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8) + w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90( + w2_scales) + + layer.w13_weight = torch.nn.Parameter(torch.cat([w3_w, w1_w], + dim=1), + requires_grad=False) + layer.w13_bias = torch.nn.Parameter(w13_bias_swapped, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w31_scales_interleaved, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_scales_interleaved, requires_grad=False) + elif self.mxfp4_backend == Mxfp4Backend.TRITON: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig w13_bias = layer.w13_bias.to(torch.float32) @@ -422,6 +602,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): 
layer.w13_weight = None layer.w2_weight = None torch.cuda.empty_cache() + else: + raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): # Number of tokens in the input tensor. @@ -458,7 +640,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): raise NotImplementedError( "Mxfp4 does not support batched experts format for EP") else: - if should_use_flashinfer_mxfp4(): + if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): # B200 code-path kwargs = { "gemm1_alpha": layer.gemm1_alpha, @@ -559,7 +742,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") - if self.use_marlin: + if self.mxfp4_backend == Mxfp4Backend.MARLIN: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -623,16 +806,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): logical_replica_count), ( "MXFP4 are not supported with this configuration.") - if should_use_flashinfer_mxfp4(): - from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe - if _should_use_flashinfer_mxfp4_bf16(): + if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): + from flashinfer import trtllm_fp4_block_scale_moe + if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16: assert x.dtype == torch.bfloat16 x_quant = x x_scale = None - else: + elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM: + from flashinfer import mxfp8_quantize x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 x_scale = x_scale.view(torch.float8_e4m3fn).reshape( *x.shape[:-1], -1) + trtllm_gen_output = trtllm_fp4_block_scale_moe( router_logits.to(torch.bfloat16), None, # routing_bias @@ -664,7 +850,86 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): tune_max_num_tokens=self.max_capture_size, )[0] return trtllm_gen_output - else: + elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): + from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + # Backend-specific preparation + if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: + + from flashinfer import mxfp8_quantize + + x_quant, x_scale = mxfp8_quantize(x, True, 32) + + fake_input_scale = torch.ones(self.num_experts, + device=x.device) + quant_scales = [ + layer.w13_weight_scale.contiguous().view(torch.int32), + fake_input_scale, + layer.w2_weight_scale.contiguous().view(torch.int32), + fake_input_scale, + ] + + fi_input = x_quant + extra_kwargs = dict( + use_mxfp8_act_scaling=True, + input_sf=x_scale, + fc1_expert_weights=layer.w13_weight.contiguous().view( + torch.long), + fc2_expert_weights=layer.w2_weight.contiguous().view( + torch.long), + ) + elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: + assert x.dtype == torch.bfloat16 + + quant_scales = [ + layer.w13_weight_scale, + layer.w2_weight_scale, + ] + + fi_input = x + extra_kwargs = dict( + use_w4_group_scaling=True, + 
fc1_expert_weights=layer.w13_weight, + fc2_expert_weights=layer.w2_weight, + ) + + output = torch.empty_like(x, dtype=torch.bfloat16) + _ = flashinfer_cutlass_fused_moe( + input=fi_input, + token_selected_experts=topk_ids.to(torch.int).contiguous(), + token_final_scales=topk_weights, + output_dtype=torch.bfloat16, + output=output, + quant_scales=quant_scales, + fc1_expert_biases=layer.w13_bias, + fc2_expert_biases=layer.w2_bias, + swiglu_alpha=layer.gemm1_alpha, + swiglu_beta=layer.gemm1_beta, + swiglu_limit=layer.gemm1_clamp_limit, + tp_size=self.moe.tp_size, + tp_rank=self.moe.tp_rank, + ep_size=self.moe.ep_size, + ep_rank=self.moe.ep_rank, + tune_max_num_tokens=self.max_capture_size, + **extra_kwargs, + ) + + return output + elif self.mxfp4_backend == Mxfp4Backend.TRITON: from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 triton_kernel_moe_forward) return triton_kernel_moe_forward( @@ -682,3 +947,5 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): w2_precision=self.w2_precision_config, apply_router_weight_on_input=apply_router_weight_on_input, ) + else: + raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index 63b2ab6bab063..2efb605f203fd 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -144,34 +144,36 @@ def torchao_quantize_param_data(param: torch.Tensor, """Quantize a Tensor with torchao quantization specified by torchao_config Args: - `param`: weight parameter of the linear module - `torchao_config`: type of quantization and their arguments we want to - use to quantize the Tensor + param: weight parameter of the linear module + torchao_config: type of quantization and their arguments we want to + use to quantize the Tensor """ from torchao.core.config import AOBaseConfig from torchao.quantization import quantize_ assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}" - """ - Avoid real weight allocation for faster load, since we will + """ + Avoid real weight allocation for faster load, since we will end up setting it to param. """ with torch.device("meta"): - dummy_linear = torch.nn.Linear(param.shape[1], - param.shape[0], - bias=False) + # linear can't be top level module since quantize_ is inplace + # while some of our configs need to do module swap, and only non-top + # level modules support module swap + dummy_linear = torch.nn.Sequential( + torch.nn.Linear(param.shape[1], param.shape[0], bias=False)) - dummy_linear.weight = param + dummy_linear[0].weight = param quantize_(dummy_linear, torchao_config) - return dummy_linear.weight + return dummy_linear[0].weight class TorchAOLinearMethod(LinearMethodBase): """Linear method for torchao. Args: - torchao_config: The torchao quantization config, a string - that encodes the type of quantization and all relevant arguments. + quant_config: The torchao quantization config, a string that encodes + the type of quantization and all relevant arguments. 
""" def __init__(self, quant_config: TorchAOConfig): diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 7b324dce3c367..e3e9635132d68 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -40,11 +40,14 @@ def cutlass_scaled_mm( block_size: list[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - return ops.cutlass_scaled_mm(A, - B.T, - out_dtype=output_dtype, - scale_a=As, - scale_b=Bs.T) + return ops.cutlass_scaled_mm( + A, + B.T, + out_dtype=output_dtype, + scale_a=As, + # SM90 block FP8 requires row-major scale_b, which we do ahead of time + scale_b=Bs if block_size is not None + and current_platform.is_device_capability(90) else Bs.T) def rocm_aiter_gemm_w8a8_blockscale_impl( @@ -152,35 +155,32 @@ def apply_w8a8_block_fp8_linear( output += bias return output.to(dtype=output_dtype).view(*output_shape) - if current_platform.is_cuda(): - if current_platform.has_device_capability(100): - - use_cutlass = cutlass_block_fp8_supported and ( - cdiv(weight.shape[0], 128) == weight_scale.shape[0] - and cdiv(weight.shape[1], 128) == weight_scale.shape[1]) - else: - # TODO: update this after switching to public sm90 block scale gemm - # as it also supports weight.shape % 128 != 0 - use_cutlass = cutlass_block_fp8_supported and ( - weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) - else: - use_cutlass = False - w8a8_blockscale_func = dispatch_w8a8_blockscale_func( - use_cutlass, use_aiter_and_is_supported) - if use_cutlass: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=use_cutlass) + cutlass_block_fp8_supported, use_aiter_and_is_supported) + if cutlass_block_fp8_supported: + num_pad = 0 + if current_platform.is_device_capability(90): + # pad first dimension to be divisible by 4 due to + # cutlass blockwise gemm limitation for hopper + num_pad = 4 - (input_2d.shape[0] % 4) + if num_pad > 0: + input_2d = torch.nn.functional.pad(input_2d, + (0, 0, 0, num_pad), + "constant", 0) + q_input, x_scale = per_token_group_quant_fp8(input_2d, + block_size[1], + column_major_scales=True) output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, block_size, input.dtype) - + if num_pad > 0: + output = output[:-num_pad] else: if use_aiter_and_is_supported: q_input, x_scale = aiter_per1x128_quant( input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) else: q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=use_cutlass) + input_2d, block_size[1], column_major_scales=False) output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, block_size, input.dtype) diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py index 6840cabbf1ae3..62e458ec3c93e 100644 --- a/vllm/model_executor/layers/quantization/utils/int8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py @@ -423,7 +423,7 @@ def w8a8_block_int8_matmul( Bs: The per-block quantization scale for `B`. block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. - output_dytpe: The dtype of the returned tensor. + output_dtype: The dtype of the returned tensor. Returns: torch.Tensor: The result of matmul. 
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 02057b476c6e2..317ad079b392d 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -201,7 +201,7 @@ def marlin_make_workspace(output_size_per_partition: int, def marlin_make_workspace_new(device: torch.device, max_blocks_per_sm: int = 1) -> torch.Tensor: # In the new marlin kernel, we use the num of threadblocks as workspace - # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + # size. The num of threadblocks is sms_count * max_blocks_per_sm. sms = torch.cuda.get_device_properties(device).multi_processor_count return torch.zeros(sms * max_blocks_per_sm, dtype=torch.int, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 8f6b7f83d47f8..e89a5e643b0e5 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -171,10 +171,12 @@ def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, bias=bias) -def rocm_per_tensor_w8a8_scaled_mm_impl( - qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, - scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor) -> torch.Tensor: +def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: @@ -190,10 +192,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl( return output -def rocm_per_tensor_w8a8_scaled_mm_fake( - qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, - scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor) -> torch.Tensor: +def rocm_per_tensor_w8a8_scaled_mm_fake(qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor) -> torch.Tensor: return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype) @@ -203,11 +207,10 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list) -> torch.Tensor: output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( - qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d) - return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + qinput, weight, out_dtype, scale_a, scale_b, bias) + return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) direct_register_custom_op( @@ -224,7 +227,6 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list) -> torch.Tensor: output = torch._scaled_mm(qinput, weight, @@ -237,7 +239,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, if type(output) is tuple and len(output) == 2: output = output[0] - return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + return torch.narrow(output, 0, 0, 
qinput.shape[0]).view(*output_shape) def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, @@ -245,7 +247,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list, + output_shape: list, **kwargs) -> torch.Tensor: # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # when using it. @@ -265,7 +267,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_b=scale_b.t(), bias=bias) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = torch.narrow(output, 0, 0, qinput.shape[0]) output = output.view(*output_shape) return output @@ -275,7 +277,6 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list, **kwargs) -> torch.Tensor: # Use unfused DQ due to limitations with scaled_mm @@ -305,8 +306,8 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, if type(output) is tuple and len(output) == 2: output = output[0] # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0]) + output = torch.narrow(output, 0, 0, qinput.shape[0]) + x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0]) # DQ # C = sw * sx * (X * W) + bias @@ -430,7 +431,6 @@ class Fp8LinearOp: scale_a=x_scale, scale_b=weight_scale, bias=bias, - input_2d=input_2d, output_shape=output_shape) diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index be25e90abf821..db50eb08db3ff 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -62,11 +62,8 @@ class RotaryEmbedding(CustomOp): positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """A PyTorch-native implementation of forward().""" - if offsets is not None: - positions = positions + offsets positions = positions.flatten() num_tokens = positions.shape[0] cos_sin = self.cos_sin_cache.index_select(0, positions) @@ -96,7 +93,6 @@ class RotaryEmbedding(CustomOp): positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: from vllm import _custom_ops as ops @@ -107,16 +103,10 @@ class RotaryEmbedding(CustomOp): self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, - self.is_neox_style, self.rotary_dim, - offsets) - else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. 
+ ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) return query, key def forward_xpu( @@ -124,29 +114,21 @@ class RotaryEmbedding(CustomOp): positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: from vllm._ipex_ops import ipex_ops as ops self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. if key is None: # XPU kernel doesn't support key=None so fall back to native impl # TODO(sarckk): add support for optional key in # ipex.llm.functional.rotary_embedding_batched - return self.forward_native(positions, query, key, offsets) + return self.forward_native(positions, query, key) else: - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - self.rotary_dim, offsets) - else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) return query, key def extra_repr(self) -> str: diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index cd888b733426b..7ac2e4bb6c34f 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -88,7 +88,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): cache = torch.cat((cos, sin), dim=-1) return cache - def forward( + def forward_native( self, positions: torch.Tensor, query: torch.Tensor, @@ -129,3 +129,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): query = query_rot key = key_rot return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key, offsets) diff --git a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py index 3d8da0fa9d8f5..27e41dd0fa97e 100644 --- a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py @@ -111,7 +111,7 @@ class DualChunkRotaryEmbedding(CustomOp): device=self.device) return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache - def forward( + def forward_native( self, positions: torch.Tensor, query: torch.Tensor, @@ -161,6 +161,15 @@ class DualChunkRotaryEmbedding(CustomOp): dim=-1) return query, key + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.forward_native(positions, query, key, offsets) + def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass): cos, sin = cos_sin.chunk(2, dim=-1) if self.is_neox_style: diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py 
b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py index 05322e56f2620..4960c20f4060a 100644 --- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -12,7 +12,7 @@ from .mrope import MRotaryEmbedding class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): """3D rotary positional embedding. 3D is t:time h:height w:width""" - def forward( + def forward_native( # type: ignore[override] self, positions: torch.Tensor, query: torch.Tensor, @@ -70,3 +70,11 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key + + def forward_cuda( # type: ignore[override] + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key) \ No newline at end of file diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py index 415a85ab698bc..37ead43e22bc4 100644 --- a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py @@ -53,7 +53,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) return cache - def forward( + def forward_native( # type: ignore[override] self, query: torch.Tensor, key: Optional[torch.Tensor] = None, @@ -72,3 +72,10 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) return query_out.type_as(query), key_out.type_as(key) + + def forward_cuda( # type: ignore[override] + self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(query, key) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 0ab4bc5375daf..69849fdac0277 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -8,7 +8,6 @@ import numpy as np import torch from transformers import PretrainedConfig -from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from .base import RotaryEmbedding @@ -136,8 +135,8 @@ def triton_mrope( """Qwen2VL mrope kernel. Args: - query: [num_tokens, num_heads * head_size] - key: [num_tokens, num_kv_heads * head_size] + q: [num_tokens, num_heads * head_size] + k: [num_tokens, num_kv_heads * head_size] cos: [3, num_tokens, head_size //2 ] (T/H/W positions with multimodal inputs) sin: [3, num_tokens, head_size //2 ] @@ -202,28 +201,6 @@ class MRotaryEmbedding(RotaryEmbedding): if self.mrope_section: assert sum(self.mrope_section) == rotary_dim // 2 - self.use_triton = current_platform.is_cuda_alike() - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """MRope forward. 
- - Args: - positions: - [num_tokens,] (text only) or - [3, num_tokens] (T/H/W positions with multimodal inputs) - query: [num_tokens, num_heads * head_size] - key: [num_tokens, num_kv_heads * head_size] - """ - if self.use_triton: - return self.forward_cuda(positions, query, key) - else: - return self.forward_native(positions, query, key) - def forward_native( self, positions: torch.Tensor, @@ -323,6 +300,24 @@ class MRotaryEmbedding(RotaryEmbedding): key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key, offsets) + + def forward_cpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key, offsets) + @classmethod def get_input_positions( cls, diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index c92a7978195bc..aa64d4e09ae18 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -258,7 +258,7 @@ class VocabParallelEmbedding(CustomOp): if params_dtype is None: params_dtype = torch.get_default_dtype() - # Divide the weight matrix along the vocaburaly dimension. + # Divide the weight matrix along the vocabulary dimension. self.num_added_embeddings = self.num_embeddings - self.org_vocab_size self.num_embeddings_per_partition = divide(self.num_embeddings_padded, self.tp_size) @@ -399,7 +399,7 @@ class VocabParallelEmbedding(CustomOp): param[:loaded_weight.shape[0]].data.copy_(loaded_weight) param[loaded_weight.shape[0]:].data.fill_(0) - def forward(self, input_): + def forward_native(self, input_): if self.tp_size > 1: # Build the mask. masked_input, input_mask = get_masked_input_and_mask( @@ -420,6 +420,9 @@ class VocabParallelEmbedding(CustomOp): output = tensor_model_parallel_all_reduce(output_parallel) return output + def forward_cuda(self, input_): + return self.forward_native(input_) + def extra_repr(self) -> str: s = f"num_embeddings={self.num_embeddings_per_partition}" s += f", embedding_dim={self.embedding_dim}" @@ -429,6 +432,7 @@ class VocabParallelEmbedding(CustomOp): return s +@CustomOp.register("parallel_lm_head") class ParallelLMHead(VocabParallelEmbedding): """Parallelized LM head. diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 2dada794a8f3e..138a2ff30b622 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -5,7 +5,8 @@ from typing import Literal, Optional from torch import nn -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.bitsandbytes_loader import ( @@ -67,7 +68,7 @@ def register_model_loader(load_format: str): load_format (str): The model loader format name. 
Examples: - >>> from vllm.config import LoadConfig + >>> from vllm.config.load import LoadConfig >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader >>> diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 4cf6c7988960d..ab538a3c95620 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -5,7 +5,8 @@ from abc import ABC, abstractmethod import torch import torch.nn as nn -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index c8dd1ec0ec3c6..4edf193b54ac5 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -16,7 +16,8 @@ from packaging import version from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) # yapf: enable @@ -325,7 +326,8 @@ class BitsAndBytesModelLoader(BaseModelLoader): global_tp_size = get_tensor_model_parallel_world_size() global_tp_rank = get_tensor_model_parallel_rank() - + check_match = (lambda weight_name, module_name: weight_name. + removesuffix(".weight") == module_name) for ( org_weight_name, mapped_weight_name, @@ -346,12 +348,12 @@ class BitsAndBytesModelLoader(BaseModelLoader): ) and mapped_weight_name.endswith(".weight"): # Without sharding if any( - mapped_weight_name.startswith(module) + check_match(mapped_weight_name, module) for module in self.unsharded_weights_modules): weight_sub_tensor = weight_tensor # Shard by column elif any( - mapped_weight_name.startswith(module) + check_match(mapped_weight_name, module) for module in self.column_sharded_weights_modules): total_size = weight_tensor.size(-1) start_index = total_size // tp_size * tp_rank @@ -361,14 +363,14 @@ class BitsAndBytesModelLoader(BaseModelLoader): # Weights have fused on disk. In this case, we assume that the # weight and module use same name. 
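The `check_match` lambda introduced above replaces prefix matching with exact comparison after stripping the `.weight` suffix. Prefix matching is subtly wrong for module-name lists because one name can be a prefix of another. A small demonstration of the failure mode the hunk guards against (module names here are made up):

```python
check_match = (lambda weight_name, module_name:
               weight_name.removesuffix(".weight") == module_name)

modules = ["model.layers.1", "lm_head"]

for w in ["model.layers.1.weight", "model.layers.10.weight",
          "lm_head.weight", "lm_head_extra.weight"]:
    by_prefix = any(w.startswith(m) for m in modules)
    by_exact = any(check_match(w, m) for m in modules)
    print(w, "prefix:", by_prefix, "exact:", by_exact)

# model.layers.10.weight and lm_head_extra.weight match by prefix
# (false positives), but not by suffix-stripped equality.
```

Requires Python 3.9+ for `str.removesuffix`, which vLLM already assumes.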
elif any( - mapped_weight_name.startswith(module) + check_match(mapped_weight_name, module) for module in self.maybe_fused_weights_modules): # special case for fused weights # get the size of each shard weight tensor total_shard_sizes = next( (sizes for module, sizes in self.maybe_fused_weights_modules.items() - if mapped_weight_name.startswith(module))) + if check_match(mapped_weight_name, module))) total_size = weight_tensor.size(0) assert total_size == sum(total_shard_sizes) # get the start/end index of each shard weight tensor diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 1e5aa9e571edb..d1bdec21fd974 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -7,19 +7,20 @@ import time from collections.abc import Generator, Iterable from typing import Optional, cast -import huggingface_hub import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm import envs -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, - filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator, + filter_files_not_needed_for_inference, maybe_download_from_modelscope, + multi_thread_pt_weights_iterator, + multi_thread_safetensors_weights_iterator, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) from vllm.platforms import current_platform @@ -29,6 +30,9 @@ logger = init_logger(__name__) class DefaultModelLoader(BaseModelLoader): """Model loader that can load different file types from disk.""" + # default number of thread when enable multithread weight loading + DEFAULT_NUM_THREADS = 8 + @dataclasses.dataclass class Source: """A source for weights.""" @@ -53,38 +57,15 @@ class DefaultModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") - def _maybe_download_from_modelscope( - self, model: str, revision: Optional[str]) -> Optional[str]: - """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + extra_config = load_config.model_loader_extra_config + allowed_keys = {"enable_multithread_load", "num_threads"} + unexpected_keys = set(extra_config.keys()) - allowed_keys - Returns the path to the downloaded model, or None if the model is not - downloaded from ModelScope.""" - if envs.VLLM_USE_MODELSCOPE: - # download model from ModelScope hub, - # lazy import so that modelscope is not required for normal use. - # pylint: disable=C. - from modelscope.hub.snapshot_download import snapshot_download - - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model, self.load_config.download_dir): - if not os.path.exists(model): - model_path = snapshot_download( - model_id=model, - cache_dir=self.load_config.download_dir, - local_files_only=huggingface_hub.constants. 
- HF_HUB_OFFLINE, - revision=revision, - ignore_file_pattern=self.load_config.ignore_patterns, - ) - else: - model_path = model - return model_path - return None + if unexpected_keys: + raise ValueError(f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{unexpected_keys}") def _prepare_weights( self, @@ -96,7 +77,7 @@ class DefaultModelLoader(BaseModelLoader): """Prepare weights for the model. If the model is not local, it will be downloaded.""" - model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path = (maybe_download_from_modelscope( model_name_or_path, revision) or model_name_or_path) is_local = os.path.isdir(model_name_or_path) @@ -175,6 +156,7 @@ class DefaultModelLoader(BaseModelLoader): self, source: "Source" ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" + extra_config = self.load_config.model_loader_extra_config hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt, source.allow_patterns_overrides) @@ -195,16 +177,35 @@ class DefaultModelLoader(BaseModelLoader): self.load_config.use_tqdm_on_load, ) else: - weights_iterator = safetensors_weights_iterator( + if extra_config.get("enable_multithread_load"): + weights_iterator = ( + multi_thread_safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + max_workers=extra_config.get( + "num_threads", self.DEFAULT_NUM_THREADS), + )) + else: + weights_iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.safetensors_load_strategy, + ) + else: + if extra_config.get("enable_multithread_load"): + weights_iterator = multi_thread_pt_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, + max_workers=extra_config.get("num_threads", + self.DEFAULT_NUM_THREADS), + ) + else: + weights_iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, ) - else: - weights_iterator = pt_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - self.load_config.pt_load_map_location, - ) if current_platform.is_tpu(): from vllm.platforms.tpu import USE_TPU_COMMONS diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py index f4a7da5744e04..5b8c6268f64ef 100644 --- a/vllm/model_executor/model_loader/dummy_loader.py +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch.nn as nn -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( initialize_dummy_weights) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 9877cb3b7c06e..aaee8f3f76353 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -9,7 +9,8 @@ import torch.nn as nn from huggingface_hub import hf_hub_download from transformers import AutoModelForCausalLM -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from 
vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py index 83e0f386c1082..dc941401a04e0 100644 --- a/vllm/model_executor/model_loader/runai_streamer_loader.py +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # ruff: noqa: SIM117 -import glob import os from collections.abc import Generator from typing import Optional @@ -10,13 +9,14 @@ import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, runai_safetensors_weights_iterator) -from vllm.transformers_utils.s3_utils import glob as s3_glob -from vllm.transformers_utils.utils import is_s3 +from vllm.transformers_utils.runai_utils import (is_runai_obj_uri, + list_safetensors) class RunaiModelStreamerLoader(BaseModelLoader): @@ -53,27 +53,22 @@ class RunaiModelStreamerLoader(BaseModelLoader): If the model is not local, it will be downloaded.""" - is_s3_path = is_s3(model_name_or_path) + is_object_storage_path = is_runai_obj_uri(model_name_or_path) is_local = os.path.isdir(model_name_or_path) safetensors_pattern = "*.safetensors" index_file = SAFE_WEIGHTS_INDEX_NAME - hf_folder = (model_name_or_path if - (is_local or is_s3_path) else download_weights_from_hf( + hf_folder = (model_name_or_path if (is_local or is_object_storage_path) + else download_weights_from_hf( model_name_or_path, self.load_config.download_dir, [safetensors_pattern], revision, ignore_patterns=self.load_config.ignore_patterns, )) - if is_s3_path: - hf_weights_files = s3_glob(path=hf_folder, - allow_pattern=[safetensors_pattern]) - else: - hf_weights_files = glob.glob( - os.path.join(hf_folder, safetensors_pattern)) + hf_weights_files = list_safetensors(path=hf_folder) - if not is_local and not is_s3_path: + if not is_local and not is_object_storage_path: download_safetensors_index_file_from_hf( model_name_or_path, index_file, self.load_config.download_dir, revision) diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py index 3edd4ec4007e8..a85ca065d1d27 100644 --- a/vllm/model_executor/model_loader/sharded_state_loader.py +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -10,7 +10,8 @@ from typing import Any, Optional import torch from torch import nn -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 3d491be3156b6..58296131fadb9 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py 
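The `DefaultModelLoader` changes earlier in this diff gate a new opt-in path: when `model_loader_extra_config` contains `enable_multithread_load`, shard files are read by a thread pool (`num_threads`, defaulting to `DEFAULT_NUM_THREADS = 8`) and yielded as they complete. The core shape is a `ThreadPoolExecutor` over whole files; a self-contained sketch of the same idea, with a caller-supplied file list assumed:

```python
import concurrent.futures
from collections.abc import Generator

import torch
from safetensors.torch import load_file


def multi_thread_weights_iterator(
    files: list[str],
    max_workers: int = 8,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield (name, tensor) pairs, loading whole shard files in parallel.

    Unlike the lazy safe_open() path, each completed file is fully
    materialized on CPU before its tensors are yielded, so peak memory
    is higher; the win is overlapping I/O across shards.
    """
    with concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers) as executor:
        futures = [executor.submit(load_file, f, device="cpu")
                   for f in files]
        # as_completed() yields futures in finish order, not submit
        # order; weight loading is name-keyed, so order is irrelevant.
        for future in concurrent.futures.as_completed(futures):
            yield from future.result().items()
```

Note the accompanying validation hunk rejects any extra-config key outside `{"enable_multithread_load", "num_threads"}`, so a typo in the config fails fast instead of being silently ignored.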
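The Run:ai streamer hunk above collapses the local-vs-S3 branching into two helpers, `is_runai_obj_uri` and `list_safetensors`, so `_prepare_weights` no longer needs parallel `glob.glob` and `s3_glob` code paths. A rough sketch of what such a helper has to do — the URI schemes checked here are an assumption for illustration; the real helpers live in `vllm.transformers_utils.runai_utils`:

```python
import glob
import os


def is_obj_uri(path: str) -> bool:
    # Object-storage locations rather than local directories
    # (assumed scheme list, for illustration only).
    return path.startswith(("s3://", "gs://"))


def list_safetensors_sketch(path: str) -> list[str]:
    """Return all *.safetensors files under `path`, local or remote."""
    if is_obj_uri(path):
        # A real implementation pages through the object store's
        # list API here; we only sketch the branch.
        raise NotImplementedError("object-store listing goes here")
    return glob.glob(os.path.join(path, "*.safetensors"))
```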
@@ -171,51 +171,52 @@ class TensorizerConfig(MutableMapping): _is_sharded: bool = field(init=False, default=False) _fields: ClassVar[tuple[str, ...]] _keys: ClassVar[frozenset[str]] - """ - Args for the TensorizerConfig class. These are used to configure the - behavior of model serialization and deserialization using Tensorizer. + """Configuration class for Tensorizer settings. - Args: - tensorizer_uri: Path to serialized model tensors. Can be a local file - path or a S3 URI. This is a required field unless lora_dir is - provided and the config is meant to be used for the - `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or - `lora_dir` is passed to this object's initializer, this is a required - argument. - tensorizer_dir: Path to a directory containing serialized model tensors, - and all other potential model artifacts to load the model, such as - configs and tokenizer files. Can be passed instead of `tensorizer_uri` - where the `model.tensors` file will be assumed to be in this - directory. - vllm_tensorized: If True, indicates that the serialized model is a - vLLM model. This is used to determine the behavior of the - TensorDeserializer when loading tensors from a serialized model. - It is far faster to deserialize a vLLM model as it utilizes - tensorizer's optimized GPU loading. Note that this is now - deprecated, as serialized vLLM models are now automatically - inferred as vLLM models. - verify_hash: If True, the hashes of each tensor will be verified against - the hashes stored in the metadata. A `HashMismatchError` will be - raised if any of the hashes do not match. - num_readers: Controls how many threads are allowed to read concurrently - from the source file. Default is `None`, which will dynamically set - the number of readers based on the number of available - resources and model size. This greatly increases performance. - encryption_keyfile: File path to a binary file containing a - binary key to use for decryption. `None` (the default) means - no decryption. See the example script in - examples/others/tensorize_vllm_model.py. - s3_access_key_id: The access key for the S3 bucket. Can also be set via - the S3_ACCESS_KEY_ID environment variable. - s3_secret_access_key: The secret access key for the S3 bucket. Can also - be set via the S3_SECRET_ACCESS_KEY environment variable. - s3_endpoint: The endpoint for the S3 bucket. Can also be set via the - S3_ENDPOINT_URL environment variable. - lora_dir: Path to a directory containing LoRA adapter artifacts for - serialization or deserialization. When serializing LoRA adapters - this is the only necessary parameter to pass to this object's - initializer. - """ + These settings configure the behavior of model serialization and + deserialization using Tensorizer. + + Attributes: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. This is a required field unless lora_dir is + provided and the config is meant to be used for the + `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or + `lora_dir` is passed to this object's initializer, this is + a required argument. + tensorizer_dir: Path to a directory containing serialized model tensors, + and all other potential model artifacts to load the model, such as + configs and tokenizer files. Can be passed instead of + `tensorizer_uri` where the `model.tensors` file will be assumed + to be in this directory. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. 
This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. + verify_hash: If True, the hashes of each tensor will be verified + against the hashes stored in the metadata. A `HashMismatchError` + will be raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is `None`, which will dynamically set + the number of readers based on the number of available + resources and model size. This greatly increases performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/others/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. + lora_dir: Path to a directory containing LoRA adapter artifacts for + serialization or deserialization. When serializing LoRA adapters + this is the only necessary parameter to pass to this object's + initializer. + """ def __post_init__(self): # check if the configuration is for a sharded vLLM model diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index fa01758ab4cee..65ea49c642944 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -8,7 +8,8 @@ from typing import Union import torch from torch import nn -from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.tensorizer import ( diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index f57ebdb1abcbc..0c2441a6db44d 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -19,10 +19,11 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.models.adapters import (as_embedding_model, - as_reward_model, - as_seq_cls_model) -from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.model_executor.models.adapters import ( + as_embedding_model, as_reward_model, as_seq_cls_model, + try_create_mm_pooling_model_cls) +from vllm.model_executor.models.interfaces import (SupportsQuant, + supports_multimodal) from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -169,22 +170,6 @@ def get_model_architecture( model_config: ModelConfig) -> tuple[type[nn.Module], str]: architectures = getattr(model_config.hf_config, "architectures", []) - # Special handling for quantized Mixtral. 
- # FIXME(woosuk): This is a temporary hack. - mixtral_supported = [ - "fp8", - "compressed-tensors", - "gptq_marlin", - "awq_marlin", - "quark", - "bitsandbytes", - ] - - if (model_config.quantization is not None - and model_config.quantization not in mixtral_supported - and "MixtralForCausalLM" in architectures): - architectures = ["QuantMixtralForCausalLM"] - model_cls, arch = model_config.registry.resolve_model_cls( architectures, model_config=model_config, @@ -199,6 +184,15 @@ def get_model_architecture( "performance may not be optimal.", arch) convert_type = model_config.convert_type + if convert_type != "none" and supports_multimodal(model_cls): + logger.debug_once("Detected conversion of Multi Modal model.") + converted = try_create_mm_pooling_model_cls(model_cls) + if converted is not None: + logger.debug_once("Creating wrapper class to forward pooler.") + return converted, arch + else: + logger.debug_once("Attempting direct conversion.") + if convert_type == "none": pass elif convert_type == "embed": diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index f87eeaa4563ff..f2c66763d0816 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utilities for downloading and initializing model weights.""" +import concurrent.futures import fnmatch import glob import hashlib @@ -18,10 +19,12 @@ import huggingface_hub.constants import numpy as np import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download -from safetensors.torch import load_file, safe_open, save_file +from safetensors.torch import load, load_file, safe_open, save_file from tqdm.auto import tqdm -from vllm.config import LoadConfig, ModelConfig +from vllm import envs +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig, @@ -95,6 +98,41 @@ def get_lock(model_name_or_path: Union[str, Path], return lock +def maybe_download_from_modelscope( + model: str, + revision: Optional[str] = None, + download_dir: Optional[str] = None, + ignore_patterns: Optional[Union[str, list[str]]] = None, + allow_patterns: Optional[Union[list[str], + str]] = None) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if envs.VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. 
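`maybe_download_from_modelscope` is hoisted out of `DefaultModelLoader` into `weight_utils` so that both the loader and `get_quant_config` can share it; the file lock ensures concurrent workers starting together do not all download the same weights. A generic sketch of that lock-guarded, cache-aware pattern — `filelock` is the library behind vLLM's `get_lock` helper, and `locked_download`/`fetch` are illustrative names:

```python
import os

from filelock import FileLock


def locked_download(model: str, cache_dir: str, fetch) -> str:
    """Materialize `model` into `cache_dir` exactly once across processes.

    `fetch` is any callable that downloads the files and returns the
    local path (snapshot_download in the hunk above). The lock makes
    tensor-parallel workers serialize on the download instead of racing.
    """
    lock_path = os.path.join(cache_dir, model.replace("/", "-") + ".lock")
    with FileLock(lock_path):
        if os.path.exists(model):
            return model  # already a local path, nothing to fetch
        return fetch(model, cache_dir)
```

Sharing the helper also lets `get_quant_config` pass `allow_patterns=["*.json"]`, fetching only the config files it needs when probing quantization settings rather than the full checkpoint.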
+ with get_lock(model, download_dir): + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=ignore_patterns, + allow_patterns=allow_patterns, + ) + else: + model_path = model + return model_path + return None + + def _shared_pointers(tensors): ptrs = defaultdict(list) for k, v in tensors.items(): @@ -169,7 +207,13 @@ def get_quant_config(model_config: ModelConfig, # Inflight BNB quantization if model_config.quantization == "bitsandbytes": return quant_cls.from_config({}) - is_local = os.path.isdir(model_config.model) + model_name_or_path = maybe_download_from_modelscope( + model_config.model, + revision=model_config.revision, + download_dir=load_config.download_dir, + allow_patterns=["*.json"], + ) or model_config.model + is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. with get_lock(model_config.model, load_config.download_dir): @@ -182,7 +226,7 @@ def get_quant_config(model_config: ModelConfig, tqdm_class=DisabledTqdm, ) else: - hf_folder = model_config.model + hf_folder = model_name_or_path possible_config_filenames = quant_cls.get_config_filenames() @@ -475,18 +519,58 @@ def np_cache_weights_iterator( def safetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, + safetensors_load_strategy: str = "lazy", ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" + loading_desc = "Loading safetensors checkpoint shards" + if safetensors_load_strategy == "eager": + loading_desc += " (eager)" + for st_file in tqdm( hf_weights_files, - desc="Loading safetensors checkpoint shards", + desc=loading_desc, disable=not enable_tqdm(use_tqdm_on_load), bar_format=_BAR_FORMAT, ): - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param + if safetensors_load_strategy == "eager": + with open(st_file, "rb") as f: + state_dict = load(f.read()) + yield from state_dict.items() + else: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def multi_thread_safetensors_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, + max_workers: int = 4, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Multi-Thread iterate over the weights in the model safetensor files.""" + + def _load_file(st_file: str): + result = load_file(st_file, device="cpu") + return result + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers) as executor: + futures = [ + executor.submit(_load_file, st_file) + for st_file in hf_weights_files + ] + futures_iter = tqdm( + concurrent.futures.as_completed(futures), + total=len(hf_weights_files), + desc="Multi-thread loading shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ) + + for future in futures_iter: + state_dict = future.result() + yield from state_dict.items() def runai_safetensors_weights_iterator( @@ -569,6 +653,39 @@ def pt_weights_iterator( del state +def multi_thread_pt_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, + pt_load_map_location: Union[str, dict[str, str]] = "cpu", + max_workers: int = 4, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Multi-Thread iterate over the weights in the model bin/pt files.""" + + def 
_load_file(bin_file: str): + return torch.load(bin_file, + map_location=pt_load_map_location, + weights_only=True) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers) as executor: + futures = [ + executor.submit(_load_file, bin_file) + for bin_file in hf_weights_files + ] + futures_iter = tqdm( + concurrent.futures.as_completed(futures), + total=len(hf_weights_files), + desc="Multi-thread loading pt checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ) + + for future in futures_iter: + state = future.result() + yield from state.items() + del state + + def get_gguf_extra_tensor_names( gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]: reader = gguf.GGUFReader(gguf_file) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index bb96bc559200c..c4328a176a5de 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast +import inspect from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast import torch import torch.nn as nn +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.config import VerifyAndUpdateConfig @@ -62,7 +65,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: linear = nn.Linear(layer_config.get("in_features", 768), layer_config.get("out_features", 768), bias=layer_config.get("bias", True), - dtype=torch.float32) + dtype=model_config.head_dtype) if not _load_dense_weights(linear, folder, model_config): continue @@ -70,7 +73,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: layers.append(linear) if act_name := layer_config.get("activation_function"): layers.append(get_act_fn(act_name)) - return nn.Sequential(*layers).to(dtype=torch.float32) + return nn.Sequential(*layers).to(dtype=model_config.head_dtype) except Exception: logger.exception("ST projector loading failed") @@ -105,15 +108,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str, if weight_key in state_dict: weight_loader = getattr(linear.weight, "weight_loader", default_weight_loader) - weight_loader(linear.weight, - state_dict[weight_key].to(torch.float32)) + weight_loader(linear.weight, state_dict[weight_key]) bias_key = weight_key.replace("weight", "bias") if linear.bias is not None and bias_key in state_dict: bias_loader = getattr(linear.bias, "weight_loader", default_weight_loader) - bias_loader(linear.bias, - state_dict[bias_key].to(torch.float32)) + bias_loader(linear.bias, state_dict[bias_key]) return True except Exception: logger.exception("Failed to load %s", filename) @@ -131,6 +132,41 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: return model_name + pooling_suffix +def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T: + + class CallVisitor(ast.NodeVisitor): + + def __init__(self): + self.calls = [] + + def visit_Call(self, node): + if isinstance(node.func, ast.Name): + self.calls.append(node.func.id) + self.generic_visit(node) + + visitor = CallVisitor() + visitor.visit(ast.parse(inspect.getsource(orig_cls))) + if "init_vllm_registered_model" not in visitor.calls: + return None + + class ModelForPooling(orig_cls, VllmModelForPooling): + + is_pooling_model = True 
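`try_create_mm_pooling_model_cls` decides whether a multimodal architecture can be wrapped for pooling by statically inspecting its source: if the class calls `init_vllm_registered_model`, its language tower goes through the registry and a pooler can be forwarded from it. The same `inspect` + `ast` trick in isolation, with hypothetical names:

```python
import ast
import inspect


def calls_function(cls: type, func_name: str) -> bool:
    """True if `cls`'s source contains a direct call to `func_name`."""

    class CallVisitor(ast.NodeVisitor):

        def __init__(self) -> None:
            self.found = False

        def visit_Call(self, node: ast.Call) -> None:
            # Only plain-name calls (`foo(...)`), not attribute calls
            # (`obj.foo(...)`) -- matching the hunk above.
            if isinstance(node.func, ast.Name) and node.func.id == func_name:
                self.found = True
            self.generic_visit(node)

    visitor = CallVisitor()
    visitor.visit(ast.parse(inspect.getsource(cls)))
    return visitor.found


class Example:

    def build(self):
        return sorted([3, 1, 2])


print(calls_function(Example, "sorted"))  # True
print(calls_function(Example, "len"))     # False
```

Source-level inspection is a heuristic — it misses aliased or indirect calls — which is why the hunk returns `None` (no wrapper, fall through to direct conversion) rather than asserting.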
+ + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + self.pooler = self.get_language_model().pooler + + return ModelForPooling # type: ignore + + def _create_pooling_model_cls(orig_cls: _T) -> _T: # Lazy import from .utils import AutoWeightsLoader, WeightsMapper @@ -257,7 +293,7 @@ def as_seq_cls_model(cls: _T) -> _T: from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.sequence import IntermediateTensors - from .utils import maybe_prefix + from .utils import get_model_hidden_size, maybe_prefix class ModelForSequenceClassification(_create_pooling_model_cls(cls), SupportsCrossEncoding): @@ -265,9 +301,10 @@ def as_seq_cls_model(cls: _T) -> _T: def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + hidden_size = get_model_hidden_size(config) self.score = ReplicatedLinear( - config.hidden_size, + hidden_size, config.num_labels, bias=False, params_dtype=torch.float32, @@ -400,6 +437,7 @@ def load_weights_using_from_2_way_softmax( from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config + tokens = getattr(model.config, "classifier_from_token", []) tokens = cast(list[int], tokens) assert len(tokens) == 2 @@ -407,9 +445,10 @@ def load_weights_using_from_2_way_softmax( if model.config.tie_word_embeddings: model.lm_head = model.model.embed_tokens else: + quant_config = model.vllm_config.quant_config model.lm_head = ParallelLMHead(model.config.vocab_size, model.config.hidden_size, - quant_config=model.quant_config) + quant_config=quant_config) loader = AutoWeightsLoader(model) loaded_weights = loader.load_weights(weights) @@ -453,9 +492,10 @@ def load_weights_no_post_processing(model, if model.config.tie_word_embeddings: model.lm_head = model.model.embed_tokens else: + quant_config = model.vllm_config.quant_config model.lm_head = ParallelLMHead(model.config.vocab_size, model.config.hidden_size, - quant_config=model.quant_config) + quant_config=quant_config) loader = AutoWeightsLoader(model) loaded_weights = loader.load_weights(weights) diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index 13ed4da0602ad..be82c2fd59644 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -342,7 +342,7 @@ class ArceeModel(nn.Module): class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): """Arcee Model for causal language modeling, integrated with vLLM runtime.""" - # Map fused module names to their sub-module components + # Map fused module names to their submodule components # (for quantization and LoRA) packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 1c7960fa3e0a5..db262447d7fa8 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -143,16 +143,8 @@ class AriaProjector(nn.Module): projects ViT's outputs into MoE's inputs. Args: - patch_to_query_dict (dict): Maps patch numbers to their corresponding - query numbers, - e.g., {1225: 128, 4900: 256}. This allows for different query sizes - based on image resolution. - embed_dim (int): Embedding dimension. - num_heads (int): Number of attention heads. - kv_dim (int): Dimension of key and value. 
- ff_dim (int): Hidden dimension of the feed-forward network. - output_dim (int): Output dimension. - norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm. + config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig) + containing projector configuration parameters. Outputs: A tensor with the shape of (batch_size, query_number, output_dim) @@ -282,8 +274,8 @@ class AriaTextMoELayer(nn.Module): Forward pass of the MoE Layer. Args: - hidden_states (torch.Tensor): Input tensor of shape (batch_size, - sequence_length, hidden_size). + hidden_states: Input tensor of shape + (batch_size, sequence_length, hidden_size). Returns: torch.Tensor: Output tensor after passing through the MoE layer. diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 32551d8102f32..242530817c642 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -29,7 +29,8 @@ from transformers import BartConfig from transformers.utils import logging from vllm.attention import Attention, AttentionType -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig +from vllm.config.lora import LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -400,8 +401,7 @@ class BartEncoderLayer(nn.Module): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: r""" Args: - hidden_states - torch.Tensor of *encoder* input embeddings. + hidden_states: torch.Tensor of *encoder* input embeddings. Returns: Encoder layer output torch.Tensor """ @@ -489,10 +489,8 @@ class BartDecoderLayer(nn.Module): ) -> torch.Tensor: r""" Args: - decoder_hidden_states - torch.Tensor of *decoder* input embeddings. - encoder_hidden_states - torch.Tensor of *encoder* input embeddings. + decoder_hidden_states: torch.Tensor of *decoder* input embeddings. + encoder_hidden_states: torch.Tensor of *encoder* input embeddings. Returns: Decoder layer output torch.Tensor """ @@ -583,12 +581,10 @@ class BartEncoder(nn.Module): ) -> torch.Tensor: r""" Args: - input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *encoder* input sequence tokens. + input_ids: Indices of *encoder* input sequence tokens in the + vocabulary. + Padding will be ignored by default should you provide it. + positions: Positions of *encoder* input sequence tokens. Returns: Decoder output torch.Tensor """ @@ -662,14 +658,11 @@ class BartDecoder(nn.Module): ) -> torch.Tensor: r""" Args: - decoder_input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - decoder_positions - Positions of *decoder* input sequence tokens. - encoder_hidden_states: - Tensor of encoder output embeddings + decoder_input_ids: Indices of *decoder* input sequence tokens + in the vocabulary. + Padding will be ignored by default should you provide it. + decoder_positions: Positions of *decoder* input sequence tokens. + encoder_hidden_states: Tensor of encoder output embeddings. 
Returns: Decoder output torch.Tensor """ @@ -731,16 +724,13 @@ class BartModel(nn.Module, SupportsQuant): encoder_positions: torch.Tensor) -> torch.Tensor: r""" Args: - input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *decoder* input sequence tokens. - encoder_input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - encoder_positions: - Positions of *encoder* input sequence tokens. + input_ids: Indices of *decoder* input sequence tokens + in the vocabulary. + Padding will be ignored by default should you provide it. + positions: Positions of *decoder* input sequence tokens. + encoder_input_ids: Indices of *encoder* input sequence tokens + in the vocabulary. + encoder_positions: Positions of *encoder* input sequence tokens. Returns: Model output torch.Tensor """ @@ -847,14 +837,10 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): ) -> torch.Tensor: r""" Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. - encoder_positions - torch.Tensor of *encoder* position indices + input_ids: torch.Tensor of *decoder* input token ids. + positions: torch.Tensor of *decoder* position indices. + encoder_input_ids: torch.Tensor of *encoder* input token ids. + encoder_positions: torch.Tensor of *encoder* position indices. Returns: Output torch.Tensor """ @@ -911,8 +897,7 @@ class MBartEncoderLayer(BartEncoderLayer): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: r""" Args: - hidden_states - torch.Tensor of *encoder* input embeddings. + hidden_states: torch.Tensor of *encoder* input embeddings. Returns: Encoder layer output torch.Tensor """ @@ -1034,12 +1019,10 @@ class MBartEncoder(nn.Module): ) -> torch.Tensor: r""" Args: - input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *encoder* input sequence tokens. + input_ids: Indices of *encoder* input sequence tokens in the + vocabulary. + Padding will be ignored by default should you provide it. + positions: Positions of *encoder* input sequence tokens. Returns: Decoder output torch.Tensor """ @@ -1115,14 +1098,11 @@ class MBartDecoder(nn.Module): ) -> torch.Tensor: r""" Args: - decoder_input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - decoder_positions - Positions of *decoder* input sequence tokens. - encoder_hidden_states: - Tensor of encoder output embeddings + decoder_input_ids: Indices of *decoder* input sequence tokens + in the vocabulary. + Padding will be ignored by default should you provide it. + decoder_positions: Positions of *decoder* input sequence tokens. + encoder_hidden_states: Tensor of encoder output embeddings. Returns: Decoder output torch.Tensor """ @@ -1184,16 +1164,13 @@ class MBartModel(nn.Module, SupportsQuant): encoder_positions: torch.Tensor) -> torch.Tensor: r""" Args: - input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *decoder* input sequence tokens. - encoder_input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - encoder_positions: - Positions of *encoder* input sequence tokens. 
+ input_ids: Indices of *decoder* input sequence tokens + in the vocabulary. + Padding will be ignored by default should you provide it. + positions: Positions of *decoder* input sequence tokens. + encoder_input_ids: Indices of *encoder* input sequence tokens + in the vocabulary. + encoder_positions: Positions of *encoder* input sequence tokens. Returns: Model output torch.Tensor """ diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 8f23439655ed7..c07e5364814ac 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, self.bert = BertPoolingModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "bert"), embedding_class=BertEmbedding) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.classifier = nn.Linear(config.hidden_size, + config.num_labels, + dtype=vllm_config.model_config.head_dtype) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 3be7e11d947d5..b758cbf28d893 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -637,14 +637,14 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): self.new = GteNewModel(vllm_config=vllm_config, prefix=prefix, add_pooling_layer=True) - self.classifier = RowParallelLinear(config.hidden_size, - config.num_labels, - input_is_parallel=False, - bias=True, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "classifier"), - return_bias=False) + self.classifier = ReplicatedLinear( + config.hidden_size, + config.num_labels, + bias=True, + quant_config=quant_config, + params_dtype=vllm_config.model_config.head_dtype, + prefix=maybe_prefix(prefix, "classifier"), + return_bias=False) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index ed98a3008c567..c1e7a7d498b11 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -678,7 +678,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each input image. 
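Several hunks in this region (the ST projector in adapters.py, `BertForSequenceClassification.classifier`, the `GteNewForSequenceClassification` head) stop hard-coding `torch.float32` and instead build the scoring head in `model_config.head_dtype`. The pattern is simply to create the final linear layer in the configured dtype and cast pooled activations into it before scoring; a toy version, with `head_dtype` and `hidden` as illustrative stand-ins:

```python
import torch
from torch import nn

head_dtype = torch.float32  # e.g. model_config.head_dtype in vLLM
hidden = torch.randn(2, 768, dtype=torch.bfloat16)  # backbone output

# Build the classification head directly in the head dtype...
classifier = nn.Linear(768, 3, dtype=head_dtype)

# ...and cast activations to match before scoring. Keeping the head in
# fp32 preserves logit precision even when the backbone runs in bf16.
logits = classifier(hidden.to(head_dtype))
print(logits.dtype)  # torch.float32
```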
Info: [Blip2ImageInputs][] diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 13ecda0122be6..f8ed92314c3d2 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -257,7 +257,7 @@ class BloomModel(nn.Module): config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.word_embeddings_layernorm(self.word_embeddings(input_ids)) + return self.word_embeddings(input_ids) def forward( self, @@ -271,6 +271,7 @@ class BloomModel(nn.Module): hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.word_embeddings_layernorm(hidden_states) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index f38e7fc202209..687af7a189cea 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -312,7 +312,8 @@ class MambaModelConfig(VerifyAndUpdateConfig): # TODO(tdoublep): remove as full cuda graph support is added FCG_NOT_SUPPORTED_MODELS = [ - "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM" + "Lfm2ForCausalLM", + "MiniMaxText01ForCausalLM", ] if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index 0c9c83cf61000..5e8447a7f48f9 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -37,8 +37,6 @@ class DeepseekV2Model(nn.Module): super().__init__() self.config = vllm_config. \ speculative_config.draft_model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.vocab_size = self.config.vocab_size @@ -51,11 +49,8 @@ class DeepseekV2Model(nn.Module): self.layers = nn.ModuleList([ DeepseekV2DecoderLayer( - self.config, + vllm_config, prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, ) for i in range(self.config.num_hidden_layers) ]) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 0ad001be71c19..8fbf16d206a86 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -43,23 +43,19 @@ class SharedHead(nn.Module): class DeepSeekMultiTokenPredictorLayer(nn.Module): - def __init__( - self, - config: PretrainedConfig, - prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.eh_proj = nn.Linear(config.hidden_size * 2, 
config.hidden_size, bias=False) self.shared_head = SharedHead(config=config, quant_config=quant_config) - self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config, - cache_config, quant_config) + self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix) def forward( self, @@ -95,13 +91,8 @@ class DeepSeekMultiTokenPredictor(nn.Module): # to map the exact layer index from weights self.layers = torch.nn.ModuleDict({ str(idx): - DeepSeekMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - ) + DeepSeekMultiTokenPredictorLayer(vllm_config, + f"{prefix}.layers.{idx}") for idx in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers) }) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index d65dcfebaeff8..e4a21febc5bde 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -32,12 +32,14 @@ import torch from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config +import vllm.envs as envs from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, ModelConfig, VllmConfig, - get_current_vllm_config) +from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -55,7 +57,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import cdiv, direct_register_custom_op from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, @@ -72,19 +76,27 @@ class DeepseekV2MLP(nn.Module): hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + is_sequence_parallel=False, prefix: str = "", ) -> None: super().__init__() + + # If is_sequence_parallel, the input and output tensors are sharded + # across the ranks within the tp_group. In this case the weights are + # replicated and no collective ops are needed. + # Otherwise we use standard TP with an allreduce at the end. self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config, + disable_tp=is_sequence_parallel, prefix=f"{prefix}.gate_up_proj") self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, quant_config=quant_config, reduce_results=reduce_results, + disable_tp=is_sequence_parallel, prefix=f"{prefix}.down_proj") if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" @@ -98,17 +110,58 @@ class DeepseekV2MLP(nn.Module): return x +# Chunk x along the num_tokens axis for sequence parallelism +# NOTE: This is wrapped in a torch custom op to work around the following issue: +# The output tensor can have a sequence length 0 at small input sequence lengths +# even though we explicitly pad to avoid this. +def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + # all_gather needs the sequence length to be divisible by tp_size + seq_len = x.size(0) + remainder = seq_len % tp_size + if remainder != 0: + pad_len = tp_size - remainder + x = nn.functional.pad(x, (0, 0, 0, pad_len)) + + chunk = x.shape[0] // tp_size + start = tp_rank * chunk + return torch.narrow(x, 0, start, chunk) + + +def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor: + tp_size = get_tensor_model_parallel_world_size() + seq_len = cdiv(x.size(0), tp_size) + shape = list(x.shape) + shape[0] = seq_len + out = torch.empty(shape, dtype=x.dtype, device=x.device) + return out + + +direct_register_custom_op( + op_name="sequence_parallel_chunk", + op_func=sequence_parallel_chunk, + mutates_args=[], + fake_impl=sequence_parallel_chunk_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + + class DeepseekV2MoE(nn.Module): def __init__( self, config: Union[DeepseekV2Config, DeepseekV3Config], + parallel_config: ParallelConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - enable_eplb: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.routed_scaling_factor = config.routed_scaling_factor self.ep_group = get_ep_group().device_group @@ -117,6 +170,21 @@ class DeepseekV2MoE(nn.Module): self.n_routed_experts: int = config.n_routed_experts self.n_shared_experts: int = config.n_shared_experts + # The all_reduce at the end of attention (during o_proj) means that + # inputs are replicated across each rank of the tensor parallel group. + # If using expert-parallelism with DeepEP All2All ops, replicated + # tokens results in useless duplicate computation and communication. + # + # In this case, ensure the input to the experts is sequence parallel + # to avoid the excess work. + # + # Not needed for pplx-kernels as it can handle duplicate input tokens. + self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND + in ("deepep_high_throughput", + "deepep_low_latency") + and parallel_config.enable_expert_parallel + and self.tp_size > 1) + if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " "Only silu is supported for now.") @@ -133,9 +201,8 @@ class DeepseekV2MoE(nn.Module): self.gate.e_score_correction_bias = None # Load balancing settings. 
- vllm_config = get_current_vllm_config() - eplb_config = vllm_config.parallel_config.eplb_config - self.enable_eplb = enable_eplb + eplb_config = parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts @@ -166,7 +233,9 @@ class DeepseekV2MoE(nn.Module): routed_scaling_factor=1.0, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) self.shared_experts = None else: intermediate_size = (config.moe_intermediate_size * @@ -177,6 +246,7 @@ class DeepseekV2MoE(nn.Module): intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + is_sequence_parallel=self.is_sequence_parallel, reduce_results=False, prefix=f"{prefix}.shared_experts", ) @@ -199,11 +269,22 @@ class DeepseekV2MoE(nn.Module): routed_scaling_factor=1.0, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + + # Chunk the hidden states so they aren't replicated across TP ranks. + # This avoids duplicate computation in self.experts. + # TODO: We can replace the all_reduce at the end of attn with a + # reduce_scatter instead of chunking here. + if self.is_sequence_parallel: + hidden_states = torch.ops.vllm.sequence_parallel_chunk( + hidden_states) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) @@ -228,7 +309,11 @@ class DeepseekV2MoE(nn.Module): assert shared_output is not None final_hidden_states += shared_output - if self.tp_size > 1: + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: final_hidden_states = ( self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states)) @@ -532,16 +617,15 @@ class DeepseekV2MLAAttention(nn.Module): class DeepseekV2DecoderLayer(nn.Module): - def __init__( - self, - config: Union[DeepseekV2Config, DeepseekV3Config], - prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - enable_eplb: bool = False, - ) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -578,9 +662,9 @@ class DeepseekV2DecoderLayer(nn.Module): and layer_idx % config.moe_layer_freq == 0): self.mlp = DeepseekV2MoE( config=config, + parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.mlp", - enable_eplb=enable_eplb, ) else: self.mlp = DeepseekV2MLP( @@ -650,10 +734,7 @@ class DeepseekV2Model(nn.Module): super().__init__() config = 
vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - enable_eplb = vllm_config.parallel_config.enable_eplb self.config = config self.vocab_size = config.vocab_size @@ -669,14 +750,7 @@ class DeepseekV2Model(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: DeepseekV2DecoderLayer( - config, - prefix, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - enable_eplb=enable_eplb, - ), + lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix), prefix=f"{prefix}.layers") if get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 5eab02b17151c..d7ae8206baca5 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -21,7 +21,8 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.transformers import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems, MultiModalUUIDDict, + NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -290,7 +291,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 2 vs > 2 # Since the processing cache assumes that the processor output is @@ -302,7 +303,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) return super()._cached_apply_hf_processor( @@ -310,7 +311,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py index c00db52371b68..23f4c6a4f93fc 100644 --- a/vllm/model_executor/models/donut.py +++ b/vllm/model_executor/models/donut.py @@ -79,10 +79,8 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only): ) -> torch.Tensor: r""" Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. + input_ids: torch.Tensor of *decoder* input token ids. + positions: torch.Tensor of *decoder* position indices. Returns: Output torch.Tensor """ @@ -351,14 +349,10 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, ) -> torch.Tensor: r""" Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. - encoder_positions - torch.Tensor of *encoder* position indices + input_ids: torch.Tensor of *decoder* input token ids. 
+ positions: torch.Tensor of *decoder* position indices. + encoder_input_ids: torch.Tensor of *encoder* input token ids. + encoder_positions: torch.Tensor of *encoder* position indices Returns: Output torch.Tensor """ diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index d880fc434e20f..3396c67f42b7b 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -34,6 +34,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from transformers import BatchFeature +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -66,8 +67,6 @@ from .vision import get_vit_attn_backend logger = init_logger(__name__) -_MAX_FRAMES_PER_VIDEO = 16 - # === Vision Transformer === # @@ -172,7 +171,16 @@ class Ernie4_5_VisionAttention(nn.Module): prefix=f"{prefix}.proj") # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) + + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -235,7 +243,10 @@ class Ernie4_5_VisionAttention(nn.Module): if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) @@ -459,7 +470,11 @@ class Ernie4_5_VisionTransformer(nn.Module): ), "vit's config.hidden must be equal to config.embed_dim" self.ln = nn.LayerNorm(hidden_size, eps=1e-6) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -839,6 +854,15 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + max_image_tokens = self.get_max_image_tokens() + max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts) + return {"image": max_image_tokens, "video": max_video_tokens} + def _get_vision_info( self, *, @@ -964,8 +988,7 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 2) @@ -1315,7 +1338,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) @@ -1423,7 +1446,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] 
= () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 780974c3b758e..6034505fa7d68 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -287,8 +287,13 @@ class Ernie4_5_VLMoeMoE(nn.Module): if self.has_shared_experts: shared_output = self.shared_experts(hidden_states) - if visual_token_mask is not None and visual_token_mask.any(): - # assert visual_token_mask.shape[0] != hidden_states.shape[0] + if visual_token_mask is not None and visual_token_mask.all(): + # only vision modal input + router_logits, _ = self.vision_experts_gate(hidden_states) + final_hidden_states = self.vision_experts( + hidden_states=hidden_states, router_logits=router_logits) + elif visual_token_mask is not None and visual_token_mask.any(): + # text and vision modals input visual_token_mask = visual_token_mask.repeat( 1, self.hidden_size).bool() text_token_mask = ~visual_token_mask @@ -310,7 +315,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): hidden_states=vision_hidden_states, router_logits=vision_router_logits).flatten() else: - # text modal input processing directly + # only text modal input text_router_logits, _ = self.text_experts_gate(hidden_states) final_hidden_states = self.text_experts( diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index d0881231fb1e7..5e05e0c60f41c 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -631,16 +631,14 @@ class Florence2LanguageModel(nn.Module): ) -> torch.Tensor: r""" Args: - input_ids - Indices of *decoder* input sequence tokens in the vocabulary. + input_ids: Indices of *decoder* input sequence tokens + in the vocabulary. Padding will be ignored by default should you provide it. - positions - Positions of *decoder* input sequence tokens. - encoder_input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - encoder_positions: - Positions of *encoder* input sequence tokens. + positions: Positions of *decoder* input sequence tokens. + encoder_input_ids: Indices of *encoder* input sequence tokens + in the vocabulary. + encoder_positions: Positions of *encoder* input sequence tokens. Returns: Model output torch.Tensor """ @@ -699,14 +697,10 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): ) -> torch.Tensor: r""" Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. - encoder_positions - torch.Tensor of *encoder* position indices + input_ids: torch.Tensor of *decoder* input token ids. + positions: torch.Tensor of *decoder* position indices. + encoder_input_ids: torch.Tensor of *encoder* input token ids. + encoder_positions: torch.Tensor of *encoder* position indices Returns: Output torch.Tensor """ @@ -1068,14 +1062,10 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, ) -> torch.Tensor: r""" Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. - encoder_positions - torch.Tensor of *encoder* position indices + input_ids: torch.Tensor of *decoder* input token ids. + positions: torch.Tensor of *decoder* position indices. 
+ encoder_input_ids: torch.Tensor of *encoder* input token ids. + encoder_positions: torch.Tensor of *encoder* position indices Returns: Output torch.Tensor """ diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index f3dc7dde46bdf..e652ba2f1c7fe 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -512,7 +512,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, architectures=["Gemma3ForCausalLM"], ) logit_scale = getattr(config, "logit_scale", 1.0) - self.language_model.logits_processor.scale *= logit_scale + + if hasattr(self.language_model, "logits_processor"): + # The logits processor can be unset if we're using + # automatic conversion to pooling model. + self.language_model.logits_processor.scale *= logit_scale self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 3074451e40a4d..663d4da7cec23 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -586,10 +586,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, # ruff: noqa # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the - # text to account for this. However, the audio preprocessing and encoder do not gurarantee they will + # text to account for this. However, the audio preprocessing and encoder do not guarantee they will # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad - # the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab. + # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab. # TODO precompute and cache padding audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index fd5fecac67d67..22386a5e819ab 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -44,6 +44,7 @@ from transformers.models.glm4v.video_processing_glm4v import ( Glm4vVideoProcessor) from transformers.video_utils import VideoMetadata +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import (get_tensor_model_parallel_world_size, parallel_state) @@ -260,7 +261,15 @@ class Glm4vVisionAttention(nn.Module): ) # Detect attention implementation. 
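The backend-detection change already seen in ernie45_vl.py above recurs in the GLM-4V hunk just below and again in Keye later in this diff: get_vit_attn_backend is now asked for a backend given the head size and dtype, and when it does not return vLLM's own FlashAttention but the upstream flash_attn package is usable, the code falls back to it and records which module to import from. A minimal sketch of that selection logic, using illustrative strings in place of the real _Backend enum:

def pick_vit_attn_backend(detected: str, upstream_fa_ok: bool) -> tuple[str, bool]:
    # detected: result of get_vit_attn_backend(head_size, dtype)
    # upstream_fa_ok: result of check_upstream_fa_availability(dtype)
    use_upstream_fa = False
    if detected != "FLASH_ATTN" and upstream_fa_ok:
        detected = "FLASH_ATTN"
        use_upstream_fa = True  # import flash_attn, not vllm.vllm_flash_attn
    return detected, use_upstream_fa

assert pick_vit_attn_backend("TORCH_SDPA", True) == ("FLASH_ATTN", True)
assert pick_vit_attn_backend("FLASH_ATTN", True) == ("FLASH_ATTN", False)
assert pick_vit_attn_backend("TORCH_SDPA", False) == ("TORCH_SDPA", False)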
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, @@ -272,23 +281,10 @@ class Glm4vVisionAttention(nn.Module): def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape - if self.tp_size > 1: - qkv = all_gather_interleave(qkv, self.qkv.hidden_size, - self.tp_size) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) - # 3 * [s, b, head * head_dim] - if self.tp_size > 1: - splitter = partial( - dist_utils.split_tensor_along_last_dim, - num_partitions=self.tp_size, - ) - q = splitter(q)[self.tp_rank] - k = splitter(k)[self.tp_rank] - v = splitter(v)[self.tp_rank] - # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] new_shape = ( seq_len, @@ -323,7 +319,10 @@ class Glm4vVisionAttention(nn.Module): if self.attn_backend == _Backend.FLASH_ATTN: # from vllm_flash_attn.flash_attn_interface import ( # flash_attn_varlen_func) - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -420,15 +419,16 @@ class Glm4vVisionBlock(nn.Module): max_seqlen: Optional[int] = None, # Only used for Flash Attention seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: - x = x + self.attn( + x_attn = self.attn( self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, max_seqlen=max_seqlen, seqlens=seqlens, ) + x_fused_norm, residual = self.norm2(x, residual=x_attn) + x = residual + self.mlp(x_fused_norm) - x = x + self.mlp(self.norm2(x)) return x @@ -728,7 +728,11 @@ class Glm4vVisionTransformer(nn.Module): self.post_layernorm = RMSNorm(vision_config.hidden_size, eps=vision_config.rms_norm_eps) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -1036,6 +1040,43 @@ class Glm4vProcessingInfo(BaseProcessingInfo): selected_timestamps.append(timestamps_list[idx]) return selected_timestamps + def _construct_video_placeholder( + self, + video_array: np.ndarray, + metadata: dict[str, Any], + grid_thw: torch.Tensor, + ) -> list[int]: + hf_processor = self.get_hf_processor() + tokenizer = self.get_tokenizer() + image_processor = hf_processor.image_processor + + hf_config = self.get_hf_config() + boi_token_id = hf_config.image_start_token_id + eoi_token_id = hf_config.image_end_token_id + bov_token_id = hf_config.video_start_token_id + eov_token_id = hf_config.video_end_token_id + merge_length = image_processor.merge_size**2 + + assert isinstance(grid_thw, torch.Tensor) + timestamps = self._get_video_second_idx(metadata, len(video_array)) + frames_idx_token = [ + tokenizer.encode(str(i), add_special_tokens=False) + for i in timestamps + ] + 
T, H, W = grid_thw + num_tokens_per_frame = int(H * W) // merge_length + placeholder = [] + placeholder.append(bov_token_id) + for frame_idx in frames_idx_token: + placeholder.append(boi_token_id) + placeholder.extend([hf_processor.video_token_id] * + num_tokens_per_frame) + placeholder.append(eoi_token_id) + placeholder.extend(frame_idx) + placeholder.append(eov_token_id) + + return placeholder + + class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): @@ -1131,17 +1172,10 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): for item in mm_data.pop("videos", []): video_array, metadata = item - # FIXME(Isotr0py): Activate the below logic after we can disable - # resampling from video loader backend. - # assert metadata["total_num_frames"] == len(video_array), ( - # f"Total frames {metadata['total_num_frames']} does not " - # f"match the length of video array {len(video_array)}.") + if metadata["video_backend"] == "opencv_dynamic": + mm_kwargs["do_sample_frames"] = False - # NOTE: Temporary workaround for resampled videos. - # this can cause a divergence with HF implementation if - # the input video is resampled in advance. - - if metadata["total_num_frames"] != len(video_array): + elif metadata["total_num_frames"] != len(video_array): logger.warning( "Total frames in metadata " "(%s) does not match the length of " @@ -1153,11 +1187,10 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): len(video_array), ) metadata["total_num_frames"] = len(video_array) - metadata = VideoMetadata(**metadata) video_mm_data = dict() video_mm_data["videos"] = [[video_array]] - video_mm_data["video_metadata"] = [[metadata]] + video_mm_data["video_metadata"] = [[VideoMetadata(**metadata)]] video_outputs = super()._call_hf_processor( prompt="<|begin_of_video|><|video|><|end_of_video|>", @@ -1165,11 +1198,23 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs, ) - input_ids = video_outputs.pop("input_ids") - input_ids[input_ids == processor.image_token_id] = ( - processor.video_token_id) - video_placeholder = processor.tokenizer.batch_decode( - input_ids)[0] + if "do_sample_frames" in mm_kwargs and not mm_kwargs[ + "do_sample_frames"]: + # Transformers v4.55 has an incorrect timestamps issue when + # frame sampling is skipped. We construct the placeholder + # manually to get placeholders with correct timestamps.
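Before the hunk continues, a toy illustration of the token layout that _construct_video_placeholder builds above; all ids here are invented for the example (merge_size == 2, so merge_length == 4):

BOV, EOV, BOI, EOI, VID = 90, 91, 92, 93, 99  # invented token ids
grid_h, grid_w, merge_length = 4, 4, 4
tokens_per_frame = (grid_h * grid_w) // merge_length  # 4 video tokens per frame
timestamp_tokens = [[11], [12]]  # encoded "1" and "2", also invented

placeholder = [BOV]
for ts in timestamp_tokens:  # one begin/tokens/end/timestamp block per frame
    placeholder += [BOI] + [VID] * tokens_per_frame + [EOI] + ts
placeholder.append(EOV)
assert placeholder == [90, 92, 99, 99, 99, 99, 93, 11,
                       92, 99, 99, 99, 99, 93, 12, 91]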
+ placeholder = self.info._construct_video_placeholder( + video_array, + metadata, + video_outputs["video_grid_thw"].squeeze(0), + ) + video_placeholder = processor.tokenizer.decode(placeholder) + else: + input_ids = video_outputs.pop("input_ids") + input_ids[input_ids == processor.image_token_id] = ( + processor.video_token_id) + video_placeholder = processor.tokenizer.batch_decode( + input_ids)[0] prompt = prompt.replace( "<|begin_of_video|><|video|><|end_of_video|>", video_placeholder, @@ -1215,14 +1260,6 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( **hf_processor_mm_kwargs) - tokenizer = self.info.get_tokenizer() - hf_config = self.info.get_hf_config() - - boi_token_id = hf_config.image_start_token_id - eoi_token_id = hf_config.image_end_token_id - - bov_token_id = hf_config.video_start_token_id - eov_token_id = hf_config.video_end_token_id merge_length = image_processor.merge_size**2 @@ -1240,21 +1277,8 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): assert isinstance(grid_thw, torch.Tensor) video, metadata = mm_items["video"][item_idx] - timestamps = self.info._get_video_second_idx(metadata, len(video)) - frames_idx_token = [ - tokenizer.encode(str(i), add_special_tokens=False) - for i in timestamps - ] - num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length - placeholder = [] - placeholder.append(bov_token_id) - for frame_idx in frames_idx_token: - placeholder.append(boi_token_id) - placeholder.extend([hf_processor.video_token_id] * - num_tokens_per_frame) - placeholder.append(eoi_token_id) - placeholder.extend(frame_idx) - placeholder.append(eov_token_id) + placeholder = self.info._construct_video_placeholder( + video, metadata, grid_thw) return PromptUpdateDetails.select_token_id( placeholder, embed_token_id=hf_processor.video_token_id, @@ -1355,7 +1379,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) @@ -1429,6 +1453,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, self, image_input: Glm4vImageInputs) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) @@ -1443,13 +1468,15 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist()) merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return image_embeds.split(sizes.tolist()) + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + return image_embeds.split(sizes) def _process_video_input( self, video_input: Glm4vVideoInputs) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"].type(self.visual.dtype) @@ -1466,8 +1493,9 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, grid_thw=grid_thw.tolist()) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return video_embeds.split(sizes.tolist()) + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -1572,17 +1600,10 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, **NOTE**: If mrope is enabled (default setting for GLM-4V opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. - second_per_grid_ts: Tensor `(num_videos)` of video time interval ( - in seconds) for each grid along the temporal dimension in the - 3D position IDs. `None` if no videos are passed. + intermediate_tensors: Optional intermediate tensors for pipeline + parallelism. + inputs_embeds: Optional pre-computed input embeddings. + **kwargs: Additional keyword arguments. 
""" if intermediate_tensors is not None: inputs_embeds = None diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 4446b5ab181c1..0f6521e44e6be 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module): config = vllm_config.model_config.hf_config self.transformer = GPT2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "gpt2")) - self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) + self.score = nn.Linear(config.n_embd, + config.num_labels, + bias=False, + dtype=vllm_config.model_config.head_dtype) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None @@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module): "encode": Pooler.for_encode(pooler_config), "classify": - Pooler.for_classify(pooler_config, classifier=None), + Pooler.for_classify(pooler_config, classifier=self.score), }) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): @@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module): position_ids=positions, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) - logits = self.score(hidden_states) - return logits + return hidden_states def _add_transformer_prefix( diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 306775af68065..b42df3ad86508 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -17,7 +17,7 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargsItems +from vllm.multimodal.inputs import MultiModalKwargsItems, MultiModalUUIDDict from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (MultiModalProcessingInfo, @@ -479,7 +479,7 @@ class H2OVLMultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 1 vs > 1 # Since the processing cache assumes that the processor output is @@ -491,7 +491,7 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) return super()._cached_apply_hf_processor( @@ -499,7 +499,7 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 0ca2e9e4bb688..76737a4428232 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -31,7 +31,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) 
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -139,37 +138,23 @@ class Idefics2VisionAttention(nn.Module): assert self.num_heads % tp_size == 0 self.num_heads_per_partition = self.num_heads // tp_size - if use_data_parallel: - self.q_size = self.num_heads * self.head_dim - self.qkv_proj = ReplicatedLinear( - self.embed_dim, - 3 * self.q_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.out_proj = ReplicatedLinear( - self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - else: - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.out_proj = RowParallelLinear( - self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel, + ) + # Use unified MultiHeadAttention with Flash Attention support self.attn = MultiHeadAttention(self.num_heads_per_partition, self.head_dim, self.scale) @@ -181,6 +166,8 @@ class Idefics2VisionAttention(nn.Module): hidden_states ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim query_states, key_states, value_states = qkv.chunk(3, dim=-1) + + # Use unified MultiHeadAttention implementation out = self.attn(query_states, key_states, value_states) attn_output, _ = self.out_proj(out) return attn_output @@ -198,23 +185,21 @@ class Idefics2VisionMLP(nn.Module): super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - cls_fc1 = (ReplicatedLinear - if use_data_parallel else ColumnParallelLinear) - self.fc1 = cls_fc1( + self.fc1 = ColumnParallelLinear( config.hidden_size, config.intermediate_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, ) - cls_fc2 = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.fc2 = cls_fc2( + self.fc2 = RowParallelLinear( config.intermediate_size, config.hidden_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -386,30 +371,6 @@ class Idefics2VisionTransformer(nn.Module): last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state - def _consolidate_qkv_weights( - self, weights: Iterable[tuple[str, torch.Tensor]] - ) -> Iterable[tuple[str, torch.Tensor]]: - qkv_idx_mappings = { - ".self_attn.q_proj": 0, - ".self_attn.k_proj": 1, - ".self_attn.v_proj": 2, - } - qkv_weights = {} - for name, loaded_weight in weights: - for weight_name, idx in qkv_idx_mappings.items(): - if weight_name not in name: - continue - new_name = name.replace(weight_name, ".self_attn.qkv_proj") - if new_name not in qkv_weights: - qkv_weights[new_name] = [None] * 3 - qkv_weights[new_name][idx] = loaded_weight - break - else: - yield name, loaded_weight - for key, weight in qkv_weights.items(): - qkv_weight = torch.cat(weight, dim=0) - yield key, qkv_weight - def load_weights(self, weights: Iterable[tuple[str, 
torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -422,9 +383,6 @@ class Idefics2VisionTransformer(nn.Module): loaded_params: set[str] = set() layer_count = len(self.encoder.layers) - if self.use_data_parallel: - weights = self._consolidate_qkv_weights(weights) - for name, loaded_weight in weights: # skip pooling header if name.startswith("head."): diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index d5b71b057831b..8f8e300c84d71 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -823,7 +823,7 @@ class SupportsEagle3(Protocol): Args: layers: Tuple of layer indices that should output auxiliary - hidden states. + hidden states. """ ... diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 58e8163e0b26e..8e9ab9649bd44 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -255,6 +255,10 @@ class InternSdpaAttention(nn.Module): self.proj = nn.Linear(self.dummy_dim, self.embed_dim) + # Use unified MultiHeadAttention with automatic backend selection + self.attn = MultiHeadAttention(self.num_heads, self.head_dim, + self.scale) + def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape qkv = self.qkv(x) @@ -268,12 +272,9 @@ class InternSdpaAttention(nn.Module): B_, N_, H_, D_ = q.shape q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_) k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) - x = x.transpose(1, 2).reshape(B, N, -1) + # Use unified MultiHeadAttention with automatic backend selection + x = self.attn(q, k, v) x = self.proj(x) return x diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 320e8d9d480c3..ce94328797ed6 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): delattr(self, attr) config = vllm_config.model_config.hf_config - self.v_head = RowParallelLinear( - config.hidden_size, - 1, - bias=False, - input_is_parallel=False, - prefix=maybe_prefix(prefix, "v_head"), - ) + self.head_dtype = vllm_config.model_config.head_dtype + + self.v_head = RowParallelLinear(config.hidden_size, + 1, + bias=False, + input_is_parallel=False, + params_dtype=self.head_dtype, + prefix=maybe_prefix(prefix, "v_head"), + return_bias=False) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None @@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) - logits, _ = self.v_head(hidden_states) + hidden_states = hidden_states.to(self.head_dtype) + logits = self.v_head(hidden_states) return logits diff --git a/vllm/model_executor/models/interns1_vit.py b/vllm/model_executor/models/interns1_vit.py index 300ed17ecaabc..eb6b685d03dc5 100644 --- a/vllm/model_executor/models/interns1_vit.py +++ b/vllm/model_executor/models/interns1_vit.py @@ -12,10 +12,10 @@ from typing import Optional import torch import torch.nn as nn -import torch.nn.functional as F from transformers import PretrainedConfig from transformers.utils import torch_int +from vllm.attention.layer import MultiHeadAttention from 
vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -206,6 +206,10 @@ class InternSdpaAttention(nn.Module): self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim) + # Use unified MultiHeadAttention with automatic backend selection + self.attn = MultiHeadAttention(self.num_heads, self.head_dim, + self.scale) + def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape @@ -213,20 +217,13 @@ class InternSdpaAttention(nn.Module): k = self.k_proj(x) v = self.v_proj(x) - q = q.view(B, N, self.num_heads, self.head_dim) - k = k.view(B, N, self.num_heads, self.head_dim) - v = v.view(B, N, self.num_heads, self.head_dim) - if self.qk_normalization: B_, N_, H_, D_ = q.shape q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_) k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) - x = x.transpose(1, 2).reshape(B, N, -1) + # Use unified MultiHeadAttention with automatic backend selection + x = self.attn(q, k, v) x = self.projection_layer(x) return x diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index b09ed7bbe72a3..9565628b198e2 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -7,6 +7,7 @@ # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- +import os from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Any, Literal, Optional, TypeVar, Union @@ -37,6 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import set_default_torch_num_threads from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -115,13 +117,26 @@ InternVLVideoInputs = Union[InternVLVideoPixelInputs, # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B def build_transform(input_size: int): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD - return T.Compose([ + transform = T.Compose([ T.Lambda(lambda img: convert_image_mode(img, 'RGB')), T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD) ]) + # Image transformation operations (which include tensor computations + # on the CPU) can occupy a substantial number of CPU cores, introducing + # overhead due to CPU contention. This issue becomes particularly + # noticeable when deploying multiple vLLM instances on a single machine. + # Therefore, it is necessary to limit the number of threads allocated to + # image transformation tasks. 
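The hunk resumes below with the thread cap itself; as a standalone sketch of the pattern, with limit_torch_threads standing in for vllm.utils.set_default_torch_num_threads (assumed here to temporarily set torch's intra-op thread count and restore it afterwards):

import os
from contextlib import contextmanager

import torch

@contextmanager
def limit_torch_threads(n: int):
    # Hypothetical stand-in for set_default_torch_num_threads.
    old = torch.get_num_threads()
    torch.set_num_threads(n)
    try:
        yield
    finally:
        torch.set_num_threads(old)

num_threads = int(os.environ.get("OMP_NUM_THREADS", "1"))
with limit_torch_threads(num_threads):
    # CPU-heavy image transforms run here with a bounded thread pool,
    # so co-located vLLM instances don't fight over cores.
    torch.randn(64, 64) @ torch.randn(64, 64)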
+ num_threads = int(os.environ.get("OMP_NUM_THREADS", "1")) + + def apply(img): + with set_default_torch_num_threads(num_threads): + return transform(img) + + return apply # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index aebd2cbe2e999..550fde17b6c53 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -613,7 +613,7 @@ class JambaForSequenceClassification(JambaForCausalLM): config.hidden_size, num_labels, bias=score_bias, - dtype=torch.float32, + dtype=vllm_config.model_config.head_dtype, ) pooler_config = vllm_config.model_config.pooler_config diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 140b0d1674728..f8c2a1e507a74 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -5,9 +5,9 @@ from typing import Optional import torch import torch.nn as nn -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature -from vllm.config import VllmConfig +from vllm.config import ModelConfig, VllmConfig from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -28,13 +28,17 @@ logger = init_logger(__name__) class JinaVLScorer(nn.Module): - def __init__(self, config: PretrainedConfig): + def __init__(self, model_config: "ModelConfig"): super().__init__() + config = model_config.hf_config + head_dtype = model_config.head_dtype self.dense = ColumnParallelLinear(config.hidden_size, config.hidden_size, + params_dtype=head_dtype, bias=True) self.out_proj = RowParallelLinear(config.hidden_size, config.num_labels, + params_dtype=head_dtype, bias=True) def forward(self, x, **kwargs): @@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "qwen2_vl")) - config = vllm_config.model_config.hf_config pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.score = JinaVLScorer(config) + self.score = JinaVLScorer(vllm_config.model_config) self.pooler = DispatchPooler({ "encode": Pooler.for_encode(pooler_config), diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 710b805acb3ea..afe33b4d4ad26 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput, BaseModelOutputWithPooling) from transformers.utils import torch_int +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -374,7 +375,16 @@ class KeyeSiglipAttention(nn.Module): ) # Detect attention implementation. 
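Keye's hunks below also pick up the _validate_and_reshape_mm_tensor cleanup seen earlier in ernie45_vl.py and glm4_1v.py (and later in midashenglm.py): torch.concat(list(x)) over a batched 3D tensor is replaced with a reshape. For a contiguous (B, N, D) tensor the two are equivalent, but the reshape skips materializing B slices and returns a view:

import torch

mm_input = torch.randn(3, 5, 16)  # batched 3D input: (B, N, D)
flat = mm_input.reshape(-1, mm_input.shape[-1])  # (B*N, D)
assert torch.equal(flat, torch.concat(list(mm_input)))
assert flat.data_ptr() == mm_input.data_ptr()  # a view here, not a copy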
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.head_dim, dtype=torch.get_default_dtype()) + + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now.") @@ -428,7 +438,10 @@ class KeyeSiglipAttention(nn.Module): ) if self.attn_backend == _Backend.FLASH_ATTN: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -1507,15 +1520,9 @@ class BaseKeyeModule(nn.Module): batch. **NOTE**: If mrope is enabled (default setting for Qwen2-VL opensource models), the shape will be `(3, seq_len)`, - otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. + otherwise it will be `(seq_len,)`. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. """ if intermediate_tensors is not None: inputs_embeds = None @@ -1598,12 +1605,12 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) elif is_list_of(mm_input, torch.Tensor): if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2 for p in mm_input): return mm_input - return torch.concat(list(mm_input)) + return torch.concat(mm_input) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[KeyeImageInputs]: diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 605c6d3eaf643..93a3bf5f98f7b 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -58,17 +58,18 @@ def split_thw(grid_thw: torch.Tensor) -> torch.Tensor: return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0) -def get_num_patches(grid_thw: torch.Tensor, num_frames: Union[list[int], - torch.Tensor]): +def get_num_patches(grid_thw: torch.Tensor, + num_frames: Union[list[int], torch.Tensor]) -> list[int]: """ Return num_patches per video. Args: - t: tensor with shape [N, ...] 
where each item is a list/tensor - cu_seqlens: list indicating the boundaries of groups + grid_thw: Tensor with shape [N, 3] containing temporal, height, width + dimensions + num_frames: List or tensor indicating the number of frames per video Returns: - list of ints representing the sum of products for each group + List of ints representing the number of patches for each video Examples: >>> # Suppose there are 2 videos with a total of 3 grids @@ -491,14 +492,14 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, if mm_input.ndim == expected_dim: return mm_input elif mm_input.ndim == expected_dim + 1: - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, *mm_input.shape[2:]) else: raise ValueError( f"{name} should be {expected_dim}D or " f"batched {expected_dim}D tensor." f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})") else: - return torch.concat(list(mm_input)) + return torch.concat(mm_input) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a22bde194f5de..f8ea2111fed57 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -171,7 +171,22 @@ class LlamaAttention(nn.Module): sliding_window = None if layer_types := getattr(config, "layer_types", None): - is_sliding = layer_types[layer_idx] == "sliding_attention" + # Fix for Eagle3 compatibility: + # for draft models, subtract target layer count + # to get draft-relative layer index starting from 0 + if hasattr(config, 'target_layer_count'): + # This is a draft model, + # adjust layer_idx to be relative to draft layers + effective_layer_idx = layer_idx - config.target_layer_count + else: + # This is a target model, use layer_idx directly + effective_layer_idx = layer_idx + assert effective_layer_idx < len(layer_types), \ + f"effective_layer_idx: {effective_layer_idx} \ + is out of bounds for layer_types: {layer_types}" + + is_sliding = layer_types[ + effective_layer_idx] == "sliding_attention" if is_sliding: sliding_window = config.sliding_window @@ -611,9 +626,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): loaded_weight: torch.Tensor, ) -> tuple[str, torch.Tensor]: - def permute(w: torch.Tensor, n_heads: int): + def permute(w: torch.Tensor, n_heads: int, attn_out: int): attn_in = self.config.head_dim * n_heads - attn_out = self.config.hidden_size return w.view(n_heads, attn_in // n_heads // 2, 2, attn_out).transpose(1, 2).reshape(attn_in, attn_out) @@ -622,12 +636,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): modules = name.split(".") # rotary embeds should be sliced + # If using quantized model in mistral format, + # quantization scales (qscale_weight) also need to be sliced if "wk" in modules and modules[-1] == "weight": loaded_weight = permute(loaded_weight, - self.config.num_key_value_heads) + self.config.num_key_value_heads, + self.config.hidden_size) + elif "wk" in modules and modules[ + -1] == "qscale_weight" and loaded_weight.numel() > 1: + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads, 1) elif "wq" in modules and modules[-1] == "weight": loaded_weight = permute(loaded_weight, - self.config.num_attention_heads) + self.config.num_attention_heads, + self.config.hidden_size) + elif "wq" in modules and modules[ + -1] == "qscale_weight" and loaded_weight.numel() > 1: + loaded_weight = permute(loaded_weight, + 
self.config.num_attention_heads, 1) num_modules = len(modules) for i in range(num_modules): diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 572930c39a846..bceb6cc42768e 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -199,6 +199,10 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): speculative_config.draft_model_config.hf_config target_layer_num = vllm_config.model_config.get_num_layers( vllm_config.parallel_config) + + # Store target layer count in draft config for + # proper layer_types indexing in draft models + self.config.target_layer_count = target_layer_num self.model = LlamaModel(vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 8a847a6180f3a..9591deea06ce9 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -24,7 +24,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems) + MultiModalInputs, MultiModalKwargsItems, + MultiModalUUIDDict) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -731,7 +732,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each input image. + positions: Position indices for the input tokens. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. Info: [LlavaImageInputs][] @@ -795,7 +798,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -810,7 +813,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): mm_data, hf_processor_mm_kwargs, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides) + mm_uuids=mm_uuids) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index a63c18493df5e..5e82f9799e0fe 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -535,8 +535,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each grid patch for each input image. - image_sizes: The original `(height, width)` for each input image. + positions: Position indices for the input tokens. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. 
Info: [LlavaNextImageInputs][] diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index bc340a9e2d8f8..46d54452a52d8 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -835,7 +835,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index 858d4e7e34cf1..e314ae357ecd4 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -669,7 +669,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, *mm_input.shape[2:]) else: return torch.concat(mm_input) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 04176c5589ed6..9b2d84e32151a 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1117,7 +1117,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): def _process_multimodal_inputs(self, modalities: dict): # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 08948960b275c..09479012a03ad 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -578,7 +578,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each input image. + positions: Position indices for the input tokens. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. Info: [Mistral3ImagePixelInputs][] diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py deleted file mode 100644 index 692267b4d7271..0000000000000 --- a/vllm/model_executor/models/mixtral_quant.py +++ /dev/null @@ -1,454 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
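Stepping back to the llama.py and llama_eagle3.py changes above before the mixtral_quant.py removal continues: an Eagle3 draft model's layers are numbered after the target model's, while its layer_types list only describes the draft layers, so indexing must be shifted by target_layer_count. In toy numbers (a hypothetical 32-layer target with a single-layer draft):

target_layer_count = 32  # layers in the target model (illustrative)
layer_types = ["full_attention"]  # draft config covers draft layers only
layer_idx = 32  # global index assigned to the first draft layer

effective_layer_idx = layer_idx - target_layer_count  # draft-relative: 0
assert 0 <= effective_layer_idx < len(layer_types)
assert layer_types[effective_layer_idx] == "full_attention"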
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only Mixtral model.""" -from collections.abc import Iterable -from itertools import islice -from typing import Optional, Union - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn -from transformers import MixtralConfig - -from vllm.attention import Attention -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - - -class MixtralMLP(nn.Module): - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) - self.w2 = ReplicatedLinear(self.ffn_dim, - self.hidden_dim, - bias=False, - quant_config=quant_config) - self.w3 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) - - # TODO: Use vllm's SiluAndMul - self.act_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralMoE(nn.Module): - - def __init__( - self, - config: MixtralConfig, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}.") - # Split experts equally between ranks - 
self.expert_indices = np.array_split(range(self.num_total_experts), - self.tp_size)[self.rank].tolist() - if not self.expert_indices: - raise ValueError( - f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config) - if idx in self.expert_indices else None - for idx in range(self.num_total_experts) - ]) - self.gate = ReplicatedLinear(config.hidden_size, - self.num_total_experts, - bias=False, - quant_config=None) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indices: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim) - - -class MixtralAttention(nn.Module): - - def __init__( - self, - config: MixtralConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - # MixtralConfig has an optional head_dim argument - self.head_dim = getattr(config, "head_dim", None) - if self.head_dim is None: - self.head_dim = self.hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - quant_config=quant_config, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position, - base=int(self.rope_theta), - is_neox_style=True, - ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output - - -class MixtralDecoderLayer(nn.Module): - - def __init__( - self, - config: MixtralConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 10000) - self.self_attn = MixtralAttention( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - max_position=config.max_position_embeddings, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - self.block_sparse_moe = MixtralMoE(config=config, - quant_config=quant_config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> torch.Tensor: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.block_sparse_moe(hidden_states) - return hidden_states, residual - - -class MixtralModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.start_layer, self.end_layer, self.layers = make_layers( - 
config.num_hidden_layers, - lambda prefix: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix - ), - prefix=f"{prefix}.layers") - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if name.endswith("scale"): - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if ("block_sparse_moe.experts." 
in name - and name not in params_dict): - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class MixtralForCausalLM(nn.Module, SupportsPP): - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = MixtralModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index f441287a4d089..048894085b360 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -35,6 +35,7 @@ from transformers.models.mllama.processing_mllama import ( import vllm.distributed.parallel_state as ps from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention.layer import MultiHeadAttention from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.selector import _Backend from vllm.config import VllmConfig @@ -57,7 +58,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalFieldConfig, - MultiModalKwargsItems) + MultiModalKwargsItems, MultiModalUUIDDict) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseProcessingInfo, @@ -184,13 +185,13 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalEncDecInputs: mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides) + mm_uuids=mm_uuids) image_token_id = 
self.info.get_hf_config().image_token_index
 
         # Check that the number of image tokens in the decoder prompt matches
@@ -517,6 +518,10 @@ class MllamaVisionSdpaAttention(nn.Module):
             prefix=f"{prefix}.o_proj",
         )
 
+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
+                                       1.0 / math.sqrt(self.head_dim))
+
     def forward(
         self,
         hidden_state: torch.Tensor,
@@ -524,21 +529,10 @@
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_state)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)
-        k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)
-        v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)
-        # TODO: remove padding in image encoder
-        attn_output = F.scaled_dot_product_attention(q,
-                                                     k,
-                                                     v,
-                                                     attn_mask=attention_mask,
-                                                     dropout_p=0.0)
+        # Use unified MultiHeadAttention with automatic backend selection
+        attn_output = self.attn(q, k, v)
 
-        attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(attn_output.shape[0],
                                           attn_output.shape[1], -1)
         output, _ = self.o_proj(attn_output)
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index ecbbb5f57bec8..2f0e8a2a5e575 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -106,22 +106,21 @@ class Llama4VisionMLP(nn.Module):
         use_data_parallel: bool = False,
     ):
         super().__init__()
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(
+        self.fc1 = ColumnParallelLinear(
             input_size=input_size,
             output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
+            disable_tp=use_data_parallel,
         )
-        cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
-        self.fc2 = cls_fc2(
+        self.fc2 = RowParallelLinear(
             input_size=intermediate_size,
             output_size=output_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.fc2",
+            disable_tp=use_data_parallel,
         )
         self.activation_fn = nn.GELU()
         self.output_activation = output_activation
@@ -388,11 +387,8 @@ class Llama4VisionEncoder(nn.Module):
     ) -> torch.Tensor:
         r"""
         Args:
-            inputs_embeds (`torch.FloatTensor` of shape
-                `(batch_size, sequence_length, hidden_size)`):
-                Optionally, instead of passing `input_ids` you can choose to
-                directly pass an embedded representation. This is useful if you
-                want more control over how to convert `input_ids` indices into
-                associated vectors than the model's internal embedding lookup
-                matrix.
+            hidden_states: Input tensor of shape
+                (batch_size, sequence_length, hidden_size).
+                Hidden states from the model embeddings, representing
+                the input tokens.
""" @@ -419,20 +417,15 @@ class Llama4UnfoldConvolution(nn.Module): kernel_size = (kernel_size, kernel_size) self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) - params = { - "input_size": - config.num_channels * kernel_size[0] * kernel_size[1], - "output_size": config.hidden_size, - "bias": False, - "quant_config": quant_config, - "prefix": f"{prefix}.linear", - } - if use_data_parallel: - cls = ReplicatedLinear - else: - cls = ColumnParallelLinear - params["gather_output"] = True - self.linear = cls(**params) + self.linear = ColumnParallelLinear( + input_size=config.num_channels * kernel_size[0] * kernel_size[1], + output_size=config.hidden_size, + bias=False, + gather_output=True, + quant_config=quant_config, + prefix=f"{prefix}.linear", + disable_tp=use_data_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.unfold(hidden_states) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 776287589808a..1d5da3139de92 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): self.config = config self.model = ModernBertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.classifier = nn.Linear(config.hidden_size, + config.num_labels, + dtype=vllm_config.model_config.head_dtype) self.pooling = ModernBertPooler(config) pooler_config = vllm_config.model_config.pooler_config diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index b2fc7be1af224..5d999a02b4e65 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -76,20 +76,22 @@ class MolmoImageInputs(TensorSchema): """ Dimensions: - bn: Batch size * number of images - - nc: Number of crops + - nc: Number of crops (dynamic) - np: Number of patches + - tp: Token sequence positions - pd: Patch dimension """ images: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "nc", "np", "pd")] + TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"})] + # Number of crops may vary per batch and image, so pass it as a list. image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]], - TensorShape("bn", "nc", "np")] + TensorShape("bn", "nc", "np", dynamic_dims={"nc"})] - feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "nc", "np")] + feat_is_patch: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "nc", "tp", dynamic_dims={"nc"})] # A boolean mask indicating which image features correspond to patch tokens. - num_crops: Annotated[torch.Tensor, TensorShape("bn")] diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py index 41a2c836b09f3..caa00763fc3d4 100644 --- a/vllm/model_executor/models/moonvit.py +++ b/vllm/model_executor/models/moonvit.py @@ -70,11 +70,15 @@ def multihead_attention( v: torch.Tensor, q_cu_seqlens: Optional[torch.Tensor] = None, k_cu_seqlens: Optional[torch.Tensor] = None, -): +) -> torch.Tensor: """Multi-head attention using flash attention 2. Args: - q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. 
+ k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim), or (tot_seqlens, num_heads, head_dim) if packing. q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q. The first element should be 0 and the last element should be q.shape[0]. @@ -123,8 +127,14 @@ def sdpa_attention( """SDPA attention. Args: - q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim), or (tot_seqlens, num_heads, head_dim) if packing. + k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + q_cu_seqlens: Optional cumulative sequence lengths of q. + k_cu_seqlens: Optional cumulative sequence lengths of k. """ seq_length = q.shape[0] attention_mask = torch.zeros([1, seq_length, seq_length], @@ -387,7 +397,7 @@ class MLP2(nn.Module): def __init__(self, dims: list[int], activation, - bias=True, + bias: bool = True, prefix: str = "", use_data_parallel: bool = False): super().__init__() diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py new file mode 100644 index 0000000000000..153f36dcf1f55 --- /dev/null +++ b/vllm/model_executor/models/motif.py @@ -0,0 +1,345 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py +# Copyright (c) Alibaba Cloud. +# LICENSE: https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/LICENSE +"""Inference-only Motif model compatible with HuggingFace weights.""" +import math +from typing import Any, Optional + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionType +from vllm.attention.selector import _Backend +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.models.llama import LlamaForCausalLM + +from .adapters import as_seq_cls_model +from .interfaces import SupportsV0Only +from .utils import extract_layer_index + + +class MotifMLP(nn.Module): + """MLP for the language component of the Motif model, which contains a + MergedColumnParallelLinear merging 2 outputs via PolyNorm activation.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str = "poly_norm", + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + 
prefix=f"{prefix}.down_proj", + ) + if hidden_act != "poly_norm": + raise NotImplementedError(f"Unsupported activation: {hidden_act}. " + "Only poly_norm is supported for now.") + self.act_fn = PolyNorm() + self.intermediate_size = intermediate_size + tp_size = get_tensor_model_parallel_world_size() + if hidden_act == "poly_norm" and tp_size > 1: + raise NotImplementedError( + "Tensor parallelism for poly_norm is not supported yet. " + "Support will be added in the future.") + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn( + x[..., :self.intermediate_size]) * x[..., self.intermediate_size:] + x, _ = self.down_proj(x) + return x + + +class MotifAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        # The HF config has an optional head_dim (introduced by Mistral-Nemo)
+        head_dim = getattr(config, "head_dim", None)
+        if head_dim is None:
+            head_dim = self.hidden_size // self.total_num_heads
+        self.head_dim = head_dim
+        # Phi models introduced a partial_rotary_factor parameter in the config
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
+                                             1)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        assert self.num_heads % 2 == 0, 'num_heads should be even'
+        assert self.num_kv_heads % 2 == 0, 'num_kv_heads should be even'
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.total_num_heads,
+            total_num_kv_heads=self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+
+        self.o_proj = RowParallelLinear(
+            input_size=self.total_num_heads * self.head_dim,
+            output_size=hidden_size,
+            bias=bias_o_proj,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self._init_rotary_emb(config,
+                              rope_scaling=rope_scaling,
+                              quant_config=quant_config)
+        sliding_window = None
+
+        self.lambda_init = self.lambda_init_fn(layer_idx)
+        self.lambda_q1 = nn.Parameter(
+            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
+                                                                    std=0.1))
+        self.lambda_k1 = nn.Parameter(
+            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
+                                                                    std=0.1))
+        self.lambda_q2 = nn.Parameter(
+            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
+                                                                    std=0.1))
+        self.lambda_k2 = nn.Parameter(
+            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
+                                                                    std=0.1))
+        self.subln = RMSNorm(2 * self.head_dim, eps=config.attn_rms_norm_eps)
+
+        params = {
+            'differential_flash_attention_config': {
+                'lambda_init': self.lambda_init,
+                'lambda_q1': self.lambda_q1,
+                'lambda_k1': self.lambda_k1,
+                'lambda_q2': self.lambda_q2,
+                'lambda_k2': self.lambda_k2,
+                'subln': self.subln,
+            }
+        }
+
+        diff_attn_err_msg = (
+            'Set VLLM_ATTENTION_BACKEND="DIFFERENTIAL_FLASH_ATTN" '
+            'to enable Differential Flash Attention.')
+        try:
+            self.attn = Attention(
+                self.num_heads,
+                self.head_dim,
+                self.scaling,
+                num_kv_heads=self.num_kv_heads,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                per_layer_sliding_window=sliding_window,
+                attn_type=attn_type,
+                prefix=f"{prefix}.attn",
+                **params,
+            )
+        except TypeError as e:
+            raise ValueError(diff_attn_err_msg) from e
+        assert (self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN
+                ), diff_attn_err_msg
+
+    def lambda_init_fn(self, depth):
+        return 0.8 - 0.6 * math.exp(-0.3 * (depth - 1))
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def _init_rotary_emb(self, config: PretrainedConfig,
+                         rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        is_neox_style = True
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and config.model_type == "llama":
+            is_neox_style = False
+
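+        # partial_rotary_factor < 1 applies RoPE to only a prefix of each
+        # head's dimensions (as in the Phi models), and GGUF llama
+        # checkpoints use GPT-J-style rotation, hence is_neox_style is
+        # flipped above.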
self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + partial_rotary_factor=self.partial_rotary_factor, + ) + + +class MotifDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "use_bias", False) + bias_o_proj = attention_bias + if hasattr(config, 'qkv_bias'): + attention_bias = config.qkv_bias + + # By default, Motif uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. parasail-ai/GritLM-7B-vllm) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = MotifAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = MotifMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "use_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +# Motif model uses differential attention +# Only supported in v0 (no chunked prefill support) +class MotifForCausalLM(LlamaForCausalLM, SupportsV0Only): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = MotifDecoderLayer): + + # Prefix caching and chunked prefill is not supported for this model. 
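+        # The model is marked SupportsV0Only, so reject unsupported engine
+        # features up front instead of failing later during inference.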
+ assert not vllm_config.cache_config.enable_prefix_caching, \ + "Motif currently does not support prefix caching" + assert not vllm_config.scheduler_config.chunked_prefill_enabled, \ + "Motif currently does not support chunked prefill" + + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) + + +MotifForSequenceClassification = as_seq_cls_model(MotifForCausalLM) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py new file mode 100644 index 0000000000000..21765a483b8e0 --- /dev/null +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -0,0 +1,1395 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# -------------------------------------------------------- +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/internvl.py +# under Apache-2.0 License +# LICENSE is in root directory. +# -------------------------------------------------------- + +import copy +import warnings +from abc import ABC, abstractmethod +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Any, Literal, Optional, TypedDict, TypeVar, Union + +import numpy.typing as npt +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from transformers import (AutoModel, BatchEncoding, BatchFeature, + PretrainedConfig, TensorType) + +from vllm.config import VllmConfig +from vllm.model_executor.layers.activation import ReLUSquaredActivation +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, + MultiModalEmbeddings, + SupportsMultiModal) +from vllm.model_executor.models.internvl import (calculate_internvl_targets, + get_internvl_target_ratios) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM +from vllm.model_executor.models.utils import (flatten_bn, + init_vllm_registered_model, + maybe_prefix, + merge_multimodal_embeddings) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, MultiModalKwargsItems, + NestedTensors) +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + ImageSize, MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +# Configure PIL to handle large images without warnings +# This prevents DecompressionBombWarning for legitimate large images +Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely +# Alternative: Set a specific higher limit +# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels + +IMG_START = "<img>" +IMG_END = "</img>" +IMG_CONTEXT = "<image>" + +# Profiling +MAX_FRAMES = 16 + + +class NanoNemotronVLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values_flat: torch.Tensor + """ + Shape: + `(batch_size * num_images * (1 + num_patches), 
num_channels, height, width)` + """ + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + +class NanoNemotronVLImageEmbeddinInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, list[torch.Tensor]] + """ + A tensor of shape `(num_images, total_image_feature_size, hidden_size)` + or a list of tensors of shape `(total_image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +NanoNemotronVLImageInputs = Union[NanoNemotronVLImagePixelInputs, + NanoNemotronVLImageEmbeddinInputs] + + +class NanoNemotronVLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - bvf: Batch size * number of videos * num_frames + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each video frame + - w: Width of each video frame + """ + type: Literal["pixel_values_videos"] + pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] + + +class NanoNemotronVLVideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - n: Number of videos + - f: Total video feature size + - h: Hidden size (must match the hidden size of language model backbone) + """ + type: Literal["video_embeds"] + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("n", "f", "h")] + + +NanoNemotronVLVideoInputs = Union[NanoNemotronVLVideoPixelInputs, + NanoNemotronVLVideoEmbeddingInputs] + + +def input_conditioner(x, norm_mean, norm_std): + y = (x - norm_mean) / norm_std + return y + + +def dynamic_preprocess(image, + *, + image_size=512, + max_num_tiles=12, + use_thumbnail=True, + idx=0): + orig_width, orig_height = image.size + + target_ratios = get_internvl_target_ratios(1, max_num_tiles) + + blocks, target_width, target_height = calculate_internvl_targets( + orig_width=orig_width, + orig_height=orig_height, + target_ratios=target_ratios, + image_size=image_size, + use_thumbnail=False) + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + processed_images = [ + img.convert("RGB") if img.mode != "RGB" else img + for img in processed_images + ] + processed_images = [ + T.Resize((image_size, image_size), + interpolation=T.InterpolationMode.BICUBIC)(img) + for img in processed_images + ] + processed_images = [T.ToTensor()(img) for img in processed_images] + return processed_images + + +def image_to_pixel_values( + image: Image.Image, + *, + input_size: int, + max_num: int, + use_thumbnail: bool, + idx: int, +) -> torch.Tensor: + images = dynamic_preprocess( + image, + image_size=input_size, + max_num_tiles=max_num, + use_thumbnail=use_thumbnail, + idx=idx, + ) + + pixel_values = torch.stack(images) + return pixel_values + + +def video_to_pixel_values( + video: npt.NDArray, + *, + input_size: int, + max_num_tiles: int = 1, + use_thumbnail: bool, +) -> torch.Tensor: + # Convert each frame to a single resized tile tensor consistent + # with 
image path + frames_tensors: list[torch.Tensor] = [] + for frame in video: + pil_frame = dynamic_preprocess( + Image.fromarray(frame, mode="RGB"), + image_size=input_size, + max_num_tiles=max_num_tiles, + use_thumbnail=use_thumbnail, + idx=0, + ) + # dynamic_preprocess returns tensors already; take the single tile + assert len(pil_frame) >= 1 + frames_tensors.append(pil_frame[0]) + + return torch.stack(frames_tensors) + + +class BaseNanoNemotronVLProcessor(ABC): + """ + This model doesn't define its own HF processor, + so we implement our own one here. + + The code to insert image tokens is based on: + https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252 + """ + + def __init__(self, config: PretrainedConfig, tokenizer: AnyTokenizer, + *args, **kwargs) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + image_size: int = config.force_image_size + patch_size: int = config.patch_size + + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.image_size = image_size + self.use_thumbnail: bool = config.use_thumbnail + self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1) + self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1) + + @property + @abstractmethod + def image_token_id(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + raise NotImplementedError + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + max_num_tiles: int, + ) -> int: + target_ratios = get_internvl_target_ratios(1, max_num_tiles) + + num_patches, _, _ = calculate_internvl_targets( + orig_width=image_width, + orig_height=image_height, + target_ratios=target_ratios, + image_size=self.image_size, + use_thumbnail=self.use_thumbnail, + ) + + return num_patches * self.num_image_token + + def _images_to_pixel_values_lst( + self, + images: list[Image.Image], + max_num_tiles: int, + ) -> list[torch.Tensor]: + return [ + image_to_pixel_values( + image, + input_size=self.image_size, + max_num=max_num_tiles, + use_thumbnail=self.use_thumbnail, + idx=idx, + ) for idx, image in enumerate(images) + ] + + def _preprocess_image( + self, + text: list[str], + images: list[Image.Image], + max_num_tiles: int, + ) -> tuple[list[str], dict[str, torch.Tensor]]: + if len(images) == 0: + image_inputs = {} + else: + pixel_values_lst = self._images_to_pixel_values_lst( + images, max_num_tiles) + image_inputs: dict[str, NestedTensors] = { + "pixel_values_flat": + input_conditioner(torch.cat(pixel_values_lst), self.norm_mean, + self.norm_std), + "image_num_patches": + torch.tensor([len(item) for item in pixel_values_lst]), + } + + for pixel_values in pixel_values_lst: + num_patches = pixel_values.shape[0] + feature_size = num_patches * self.num_image_token + image_repl = self.get_image_repl(feature_size, num_patches) + text = [t.replace('<image>', image_repl.full, 1) for t in text] + return text, image_inputs + + def _make_batch_input(self, + input_item: Optional[Union[Any, list[Any]]] = None): + if input_item is None: + input_item = [] + if not isinstance(input_item, list): + input_item = [input_item] + return input_item + + def __call__( + self, + text: Optional[Union[str, list[str]]] = None, + images: Optional[Union[Image.Image, list[Image.Image]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + max_num_tiles: Optional[int] = None, + ) 
-> Mapping[str, NestedTensors]: + # Use default if not provided + if max_num_tiles is None: + max_num_tiles = 12 + + text, images = [self._make_batch_input(x) for x in (text, images)] + + text, image_inputs = self._preprocess_image( + text=text, + images=images, + max_num_tiles=max_num_tiles, + ) + + text_inputs = self.tokenizer(text, add_special_tokens=False) + + return { + **BatchEncoding(text_inputs, tensor_type=return_tensors), + **image_inputs, + } + + +class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): + """ + HF Processor with extended video processing logic. + Code for video processing is adapted from video example: + https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + video_token: Optional[str] = None, + ) -> None: + super().__init__( + config=config, + tokenizer=tokenizer, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + # add extra video token for video processing + self.video_token = video_token + + @property + def supports_video(self) -> bool: + return self.video_token_id is not None + + @property + def video_token_id(self) -> Optional[int]: + if self.video_token is None: + return None + return self.tokenizer.get_vocab().get(self.video_token, None) + + @property + def image_token_id(self) -> int: + return self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT) + + def _videos_to_pixel_values_lst( + self, + videos: list[npt.NDArray], + max_num_tiles: int, + dynamic_image_size: Optional[bool] = None, + ) -> list[torch.Tensor]: + + return [ + video_to_pixel_values( + video, + input_size=self.image_size, + max_num_tiles=max_num_tiles, + use_thumbnail=self.use_thumbnail, + ) for video in videos + ] + + def _preprocess_video( + self, + text: list[str], + videos: list[npt.NDArray], + max_num_tiles: int, + dynamic_image_size: Optional[bool] = None, + ): + if len(videos) == 0 or not self.supports_video: + video_inputs = {} + else: + pixel_values_lst_video = self._videos_to_pixel_values_lst( + videos, + max_num_tiles=max_num_tiles, + dynamic_image_size=dynamic_image_size, + ) + + video_inputs: dict[str, NestedTensors] = { + "pixel_values_flat_video": + input_conditioner(torch.cat(pixel_values_lst_video), + self.norm_mean, self.norm_std), + "video_num_patches": + torch.tensor([len(item) for item in pixel_values_lst_video]), + } + + for pixel_values in pixel_values_lst_video: + num_patches = pixel_values.shape[0] + + video_repl = self.get_video_repl(self.num_image_token, + num_patches, self.video_token) + text = [t.replace('<video>', video_repl.full, 1) for t in text] + return text, video_inputs + + def __call__( + self, + text: Optional[Union[str, list[str]]] = None, + images: Optional[Union[Image.Image, list[Image.Image]]] = None, + videos: Optional[Union[npt.NDArray, list[npt.NDArray]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + max_num_tiles: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> Mapping[str, NestedTensors]: + # Use default if not provided + if max_num_tiles is None: + max_num_tiles = 12 + + text, images, videos = [ + self._make_batch_input(x) for x in (text, images, videos) + ] + + text, image_inputs = self._preprocess_image( + text=text, + images=images, + max_num_tiles=max_num_tiles, + ) + + text, 
video_inputs = self._preprocess_video( + text=text, + videos=videos, + max_num_tiles=max_num_tiles, + dynamic_image_size=dynamic_image_size, + ) + + text_inputs = self.tokenizer(text, add_special_tokens=False) + + return BatchFeature({ + **BatchEncoding(text_inputs, tensor_type=return_tensors), + **image_inputs, + **video_inputs, + }) + + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + repl_features = IMG_CONTEXT * feature_size + repl_full = IMG_START + repl_features + IMG_END + + return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) + + def get_video_repl( + self, + feature_size: int, + num_patches: Optional[int] = None, + video_context_token: str = IMG_CONTEXT, + ) -> PromptUpdateDetails[str]: + repl_features = video_context_token * self.num_image_token + repl_features_with_sep = IMG_START + repl_features + IMG_END + # num_patches is equal to num_frames + repl_full = ''.join([ + f'Frame{i+1}: {repl_features_with_sep}' for i in range(num_patches) + ]) + + return PromptUpdateDetails.select_text(repl_full, video_context_token) + + +class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo): + """Basic image-only ProcessingInfo for InternVL-style models.""" + + @abstractmethod + def get_hf_processor( + self, + **kwargs: object, + ) -> BaseNanoNemotronVLProcessor: + raise NotImplementedError + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + max_num_tiles: int, + processor: Optional[BaseNanoNemotronVLProcessor], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + return processor.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + max_num_tiles=max_num_tiles, + ) + + def get_image_size_with_most_features(self, + max_num_tiles: int) -> ImageSize: + processor = self.get_hf_processor() + + base_size = processor.image_size + target_ratios = get_internvl_target_ratios(1, max_num_tiles) + + largest_feature_size, largest_feature_pinpoint = 0, None + for wr, hr in target_ratios: + width, height = base_size * wr, base_size * hr + + feat_size = self.get_num_image_tokens( + image_width=width, + image_height=height, + max_num_tiles=max_num_tiles, + processor=processor, + ) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, + height=height) + + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") + + return largest_feature_pinpoint + + def get_max_image_tokens(self) -> int: + processor = self.get_hf_processor() + # Use default max_num_tiles for max tokens calculation + max_num_tiles = 12 + target_width, target_height = self.get_image_size_with_most_features( + max_num_tiles) + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + max_num_tiles=max_num_tiles, + processor=processor, + ) + + +_I = TypeVar("_I", bound=BaseNanoNemotronVLProcessingInfo) + + +class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): + """ ProcessingInfo extended for video processing""" + + @property + def supports_video(self): + return self.get_hf_processor().supports_video + + def get_supported_mm_limits(self): + video_limit = {"video": None} if self.supports_video else {} + return {**super().get_supported_mm_limits(), **video_limit} + + def get_video_token(self) -> 
Optional[str]: + return IMG_CONTEXT + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + processor = self.get_hf_processor() # we get the CustomProcessor here + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = (seq_len - + max_image_tokens) // processor.num_image_token + max_frames_per_video = max_total_frames // max(max_videos, 1) + + max_frames_per_video = min(max_frames_per_video, MAX_FRAMES) + return max(max_frames_per_video, 1) + + def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor: + return self.ctx.init_processor( + NanoNemotronVLProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + video_token=self.get_video_token(), + **kwargs, + ) + + +class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]): + """Basic image-only MultiModalProcessor for InternVL-style models.""" + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> Mapping[str, NestedTensors]: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + hf_processor = self.info.get_hf_processor(**mm_kwargs) + image_token_id = hf_processor.image_token_id + + # Since there may be extra tokens in the feature placeholders, + # we need to pass the image token ID to the model to select the + # tokens to merge from the vision encoder outputs + processed_outputs["image_token_id"] = torch.tensor(image_token_id) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: Mapping[str, NestedTensors], + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) + num_images = len(image_num_patches) + + return dict( + pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( + "image", image_num_patches), + image_num_patches=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + image_token_id=MultiModalFieldConfig.shared("image", num_images), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + if "image_num_patches" in out_mm_kwargs: + image_num_patches = out_mm_kwargs["image_num_patches"] + assert isinstance(image_num_patches, torch.Tensor) + image_num_patches = image_num_patches.tolist() + elif "image_embeds" in out_mm_kwargs: + # to compute num_patches (similar to Qwen2-VL) + image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + else: + image_num_patches = [] + + def get_replacement_custom(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + feature_size = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + # Extract max_num_tiles from kwargs, default to 12 + max_num_tiles = hf_processor_mm_kwargs.get("max_num_tiles", 12) + feature_size = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + max_num_tiles=max_num_tiles, + processor=hf_processor, + ) + + 
num_patches = None + local_image_num_patches = image_num_patches + if isinstance(local_image_num_patches, torch.Tensor): + local_image_num_patches = local_image_num_patches.tolist() + if isinstance( + local_image_num_patches, + (list, tuple)) and item_idx < len(local_image_num_patches): + num_patches = int(local_image_num_patches[item_idx]) + + return hf_processor.get_image_repl(feature_size, num_patches) + + return [ + PromptReplacement( + modality="image", + target="<image>", + replacement=get_replacement_custom, + ) + ] + + +class NanoNemotronVLMultiModalProcessor( + NanoNemotronBaseVLMultiModalProcessor[NanoNemotronVLProcessingInfo]): + """MultiModalProcessor extended for video support""" + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> Mapping[str, NestedTensors]: + processed_outputs = super()._call_hf_processor(prompt, mm_data, + mm_kwargs, tok_kwargs) + + hf_processor = self.info.get_hf_processor(**mm_kwargs) + if self.info.supports_video and ( + video_token_id := hf_processor.video_token_id) is not None: + processed_outputs["video_token_id"] = torch.tensor(video_token_id) + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: Mapping[str, NestedTensors], + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_fields = super()._get_mm_fields_config(hf_inputs, + hf_processor_mm_kwargs) + if self.info.supports_video: + video_num_patches = hf_inputs.get("video_num_patches", + torch.empty(0)) + num_videos = len(video_num_patches) + video_fields = dict( + pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_patches), + video_num_patches=MultiModalFieldConfig.batched("video"), + video_token_id=MultiModalFieldConfig.shared( + "video", num_videos)) + else: + video_fields = {} + + return image_fields | video_fields + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + prompt_repl = super()._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) + + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + out_mm_data = out_mm_kwargs.get_data() + if "video_num_patches" in out_mm_data: + video_num_patches = out_mm_data["video_num_patches"] + assert isinstance(video_num_patches, torch.Tensor) + video_num_patches = video_num_patches.tolist() + else: + video_num_patches = [] + + def get_video_replacement_internvl(item_idx: int): + feature_size = hf_processor.num_image_token + num_patches = video_num_patches[item_idx] + if num_patches is not None: + assert isinstance(num_patches, int) + + return hf_processor.get_video_repl( + feature_size, + num_patches, + video_context_token=hf_processor.video_token) + + if self.info.supports_video: + prompt_repl = [ + *prompt_repl, + PromptReplacement( + modality="video", + target="<video>", + replacement=get_video_replacement_internvl, + ) + ] + + return prompt_repl + + +class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): + """Basic image-only DummyInputsBuilder for InternVL-style models.""" + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + return "<image>" * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> 
MultiModalDataDict: + # Use default max_num_tiles for dummy data generation + max_num_tiles = 12 + target_width, target_height = ( + self.info.get_image_size_with_most_features(max_num_tiles)) + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class NanoNemotronVLDummyInputsBuilder( + NanoNemotronVLDummyInputsBuilder[NanoNemotronVLProcessingInfo]): + """DummyInputsBuilder extended for video support""" + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_videos = mm_counts.get("video", 0) + + return super().get_dummy_text(mm_counts) + "<video>" * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + dummy_image = super().get_dummy_mm_data(seq_len=seq_len, + mm_counts=mm_counts) + if self.info.supports_video: + config = self.info.get_hf_config() + image_size: int = config.force_image_size + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + num_videos = mm_counts.get("video", 0) + dummy_video = { + "video": + self._get_dummy_videos(width=image_size, + height=image_size, + num_frames=target_num_frames, + num_videos=num_videos) + } + else: + dummy_video = {} + return {**dummy_image, **dummy_video} + + +@MULTIMODAL_REGISTRY.register_processor( + NanoNemotronVLMultiModalProcessor, + info=NanoNemotronVLProcessingInfo, + dummy_inputs=NanoNemotronVLDummyInputsBuilder, +) +class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid, + SupportsMultiModal): + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<image>" + if modality.startswith("video"): + return "<video>" + return None + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + image_size = config.force_image_size + patch_size = config.patch_size + self.patch_size = patch_size + self.template = config.template + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.downsample_ratio = config.downsample_ratio + self.ps_version = config.ps_version + self.image_tag_type = config.image_tag_type + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.vision_model = AutoModel.from_config(config.vision_config, + trust_remote_code=True) + self.vision_model.model._initialize_weights = ( + self.vision_model.model._init_weights) + # Move input normalization to processor to mirror original HF + # implementation where normalization is done in fp32 + self.vision_model.radio_model.make_preprocessor_external() + self.vision_model = self.vision_model.to( + self.language_model.config.torch_dtype) + + self.drop_vision_class_token = True + + # Construct the vision projection. 
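+        # RMSNorm over the pixel-shuffled ViT features, then a two-layer MLP
+        # with squared-ReLU activation into the LLM hidden size; the
+        # int(1 / downsample_ratio)**2 factor matches the channel expansion
+        # performed by pixel_shuffle below.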
+ vit_hidden_size = config.vit_hidden_size + vision_projection_hidden_size = config.projector_hidden_size + llm_hidden_size = config.text_config.hidden_size + + self.mlp1 = nn.Sequential( + RMSNorm(hidden_size=vit_hidden_size * + int(1 / self.downsample_ratio)**2, + eps=1e-5), + nn.Linear( + vit_hidden_size * int(1 / self.downsample_ratio)**2, + vision_projection_hidden_size, + bias=False, + ), + ReLUSquaredActivation(), + nn.Linear(vision_projection_hidden_size, + llm_hidden_size, + bias=False), + ) + self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype) + + self.img_context_token_id = None + self.video_context_token_id = None + self.config = config + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view( + n, + w, + int(h * scale_factor), + int(c / scale_factor), + ) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> + # N, H * scale, W * scale, C // (scale ** 2) + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) + if self.ps_version == "v1": + warnings.warn( + "In ps_version 'v1', the height and width have not " + "been swapped back, which results in a transposed image.", + stacklevel=2, + ) + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + vit_embeds = self.vision_model(pixel_values).features + vit_embeds = vit_embeds.to(dtype=torch.bfloat16) + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[NanoNemotronVLImageInputs]: + pixel_values_flat = kwargs.pop("pixel_values_flat", None) + image_num_patches = kwargs.pop("image_num_patches", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values_flat is None and image_embeds is None: + return None + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return NanoNemotronVLImageEmbeddinInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + image_token_id = kwargs["image_token_id"] + assert isinstance(image_token_id, torch.Tensor) + self.img_context_token_id = image_token_id.flatten().unique().item() + + if pixel_values_flat is not None: + if not isinstance(pixel_values_flat, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat)}") + + if not isinstance(image_num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image_num_patches. 
" + f"Got type: {type(image_num_patches)}") + + pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) + image_num_patches = flatten_bn(image_num_patches, concat=True) + + return NanoNemotronVLImagePixelInputs( + type="pixel_values", + pixel_values_flat=pixel_values_flat, + num_patches=image_num_patches, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: NanoNemotronVLImageInputs) -> torch.Tensor: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + + image_embeds = self.extract_feature(image_input["pixel_values_flat"]) + num_patches = image_input["num_patches"] + + # Only one image in the current batch + if len(num_patches) == 1: + return (image_embeds.view(-1, + self.config.text_config.hidden_size), ) + + # NOTE: Image embeddings are split into separate tensors for each image + # by the size of each embedding. + feature_size = image_embeds.shape[1] + image_embeds = image_embeds.view(-1, + self.config.text_config.hidden_size) + image_feature_sizes = [ + num_patches * feature_size for num_patches in num_patches + ] + return image_embeds.split(image_feature_sizes) + + def _parse_and_validate_video_input( + self, + **kwargs: object) -> Optional[NanoNemotronVLVideoPixelInputs]: + pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None) + video_num_patches = kwargs.pop("video_num_patches", None) + video_embeds = kwargs.pop("video_embeds", None) + + if pixel_values_flat_video is None and video_embeds is None: + return None + + if video_embeds is not None: + return NanoNemotronVLVideoEmbeddingInputs( + type="video_embeds", + data=flatten_bn(video_embeds), + ) + + video_token_id = kwargs["video_token_id"] + assert isinstance(video_token_id, torch.Tensor) + self.video_context_token_id = video_token_id.flatten().unique().item() + + if pixel_values_flat_video is not None: + if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat_video)}") + + if not isinstance(video_num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image_num_patches. " + f"Got type: {type(video_num_patches)}") + + pixel_values_flat_video = flatten_bn(pixel_values_flat_video, + concat=True) + video_num_patches = flatten_bn(video_num_patches, concat=True) + expected_h = expected_w = self.config.force_image_size + resolve_bindings = {"h": expected_h, "w": expected_w} + + return NanoNemotronVLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_flat=pixel_values_flat_video, + num_patches=video_num_patches, + resolve_bindings=resolve_bindings, + ) + + raise AssertionError("This line should be unreachable.") + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+ for input_key in kwargs: + if input_key in ("pixel_values_flat", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_flat_video", + ) and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + # Validate the multimodal input keyword arguments + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if modalities is None: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_image_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if (multimodal_embeddings is not None + and len(multimodal_embeddings) != 0): + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] + assert len(context_token_ids) >= 1 + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + context_token_ids, + ) + + return inputs_embeds + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + return hidden_states + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="mlp1", + tower_model="vision_model", + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + + def is_vision_model_weights(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("vision_model") + + def is_adapter_weights(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("mlp1") + + # Get references to parameters for direct loading + vision_model_dict = dict(self.vision_model.named_parameters()) + vision_model_buffers = dict(self.vision_model.named_buffers()) + adapter_dict = dict(self.mlp1.named_parameters()) + + def llm_weights_generator(): + # Single pass over weights + for name, w in weights: + if is_vision_model_weights((name, w)): + # Load vision encoder weights directly + trimmed_name = ".".join(name.split(".")[1:]) + if "input_conditioner" in trimmed_name: + continue + if trimmed_name in vision_model_buffers: + param = vision_model_buffers[trimmed_name] + else: + param = vision_model_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_adapter_weights((name, w)): + # Load vision-language adapter weights directly + trimmed_name = ".".join(name.split(".")[1:]) + param = adapter_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + else: + # LLM weights: yield them to be loaded + # by language_model.load_weights + assert name.startswith("language_model") + trimmed_name = ".".join(name.split(".")[1:]) + yield (trimmed_name, w) + + # Now we call the language model load with the generator + self.language_model.load_weights(llm_weights_generator()) + + def print_architecture(self, + detailed: bool = True, + save_to_file: Optional[str] = None): + """ + Print model architecture with parameter names, shapes, and sizes. 
+ + Args: + detailed: If True, show detailed parameter breakdown + save_to_file: If provided, save output to this file path + """ + import sys + from io import StringIO + + # Capture output if saving to file + original_stdout = sys.stdout + if save_to_file: + sys.stdout = StringIO() + + try: + print("=" * 100) + print("NemotronH_Nano_VL Model Architecture") + print("=" * 100) + + total_params = 0 + param_groups = { + "language_model": [], + "vision_model": [], + "mlp1": [], + "other": [], + } + + for name, param in self.named_parameters(): + param_size = param.numel() + total_params += param_size + + # Group parameters by main component + if name.startswith("language_model"): + param_groups["language_model"].append( + (name, param.shape, param_size, param.dtype)) + elif name.startswith("vision_model"): + param_groups["vision_model"].append( + (name, param.shape, param_size, param.dtype)) + elif name.startswith("mlp1"): + param_groups["mlp1"].append( + (name, param.shape, param_size, param.dtype)) + else: + param_groups["other"].append( + (name, param.shape, param_size, param.dtype)) + + if detailed: + print(f"{name:<70} | Shape: {str(param.shape):<25} | " + f"Size: {param_size:>12,} | Dtype: {param.dtype}") + + print("=" * 100) + print("Summary by Component:") + print("-" * 60) + + for component, params in param_groups.items(): + if params: # Only show components that have parameters + component_total = sum(size for _, _, size, _ in params) + percentage = ((component_total / total_params) * + 100 if total_params > 0 else 0) + print(f"{component:<20} | Parameters: {len(params):>4} | " + f"Total Size: {component_total:>15,} | " + f"{percentage:>6.2f}%") + + print("-" * 60) + print(f"{'Total Parameters':<20} | {total_params:>15,}") + + # Estimate memory usage (assuming bfloat16 = 2 bytes per parameter) + memory_mb = total_params * 2 / (1024**2) + memory_gb = memory_mb / 1024 + print(f"{'Est. Memory (MB)':<20} | {memory_mb:>15.2f}") + print(f"{'Est. Memory (GB)':<20} | {memory_gb:>15.2f}") + print("=" * 100) + + # Save to file if requested + if save_to_file: + output = sys.stdout.getvalue() + sys.stdout = original_stdout + with open(save_to_file, "w") as f: + f.write(output) + print(f"Architecture saved to: {save_to_file}") + print(output) # Also print to console + + finally: + if save_to_file and sys.stdout != original_stdout: + sys.stdout = original_stdout + + def get_model_info(self): + """ + Get basic model information as a dictionary. + """ + total_params = sum(p.numel() for p in self.parameters()) + + component_info = {} + for name, param in self.named_parameters(): + component = name.split(".")[0] + if component not in component_info: + component_info[component] = {"params": 0, "size": 0} + component_info[component]["params"] += 1 + component_info[component]["size"] += param.numel() + + return { + "model_name": "NemotronH_Nano_VL", + "total_parameters": total_params, + "memory_estimate_mb": total_params * 2 / (1024**2), # bfloat16 + "components": component_info, + "config": { + "image_size": getattr(self.config, "force_image_size", None), + "patch_size": getattr(self.config, "patch_size", None), + "num_image_token": self.num_image_token, + "downsample_ratio": self.downsample_ratio, + }, + } + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.language_model.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return (self.language_model.mamba_cache. 
+ get_seqlen_agnostic_capture_inputs(batch_size)) + + @classmethod + def get_mamba_state_shape_from_config(cls, vllm_config: "VllmConfig"): + text_config = vllm_config.model_config.hf_config.text_config + temp_vllm_config = copy.deepcopy(vllm_config) + temp_vllm_config.model_config.hf_config = text_config + return NemotronHForCausalLM.get_mamba_state_shape_from_config( + temp_vllm_config) + + @classmethod + def get_mamba_state_dtype_from_config(cls, vllm_config: "VllmConfig"): + text_config = vllm_config.model_config.hf_config.text_config + temp_vllm_config = copy.deepcopy(vllm_config) + temp_vllm_config.model_config.hf_config = text_config + return NemotronHForCausalLM.get_mamba_state_dtype_from_config( + temp_vllm_config) diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 8a563288cb4d6..da8628df1fe57 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -44,15 +44,16 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.models.utils import ( - AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory, + make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronHConfig @@ -426,38 +427,36 @@ class NemotronHModel(nn.Module): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - attb_params_mapping = { - "q_proj": "q", - "k_proj": "k", - "v_proj": "v", - } + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "embeddings" in name: - name = name.replace("embeddings", "embed_tokens") + if "scale" in name: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue - if "A_log" in name: - name = name.replace("A_log", "A") - loaded_weight = loaded_weight.to(torch.float32) + # load stacked params + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue - if "D" in name: - loaded_weight = loaded_weight.to(torch.float32) - - if "dt_bias" in name: - loaded_weight = loaded_weight.to(torch.float32) - - # load attn params - if any(proj in name for proj in ["q_proj", "k_proj", "v_proj"]): - weight_name = next(proj - for proj in ["q_proj", "k_proj", "v_proj"] - if proj in name) - name = name.replace(weight_name, "qkv_proj") param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, - attb_params_mapping[weight_name]) + weight_loader(param, loaded_weight, shard_id) + break + # load other params else: param = params_dict[name] @@ -471,6 +470,14 @@ class NemotronHModel(nn.Module): class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={"backbone": "model"}, + orig_to_new_substr={ + "A_log": "A", + "embeddings": "embed_tokens" + }, + ) + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -622,10 +629,5 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - # update name in weights before passing to loader - updated_weights = [] - for name, loaded_weight in weights: - name = name.replace("backbone", "model") - updated_weights.append((name, loaded_weight)) loader = AutoWeightsLoader(self) - return loader.load_weights(updated_weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index a9c7d8044e10c..acda2027401d9 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -560,7 +560,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return [] # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image). + # tensor corresponding to a multimodal data item (image). multimodal_embeddings: tuple[torch.Tensor, ...] 
= () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index bccd1b87043a5..3e4c580a11211 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -52,10 +52,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.model_executor.models.utils import ( - AutoWeightsLoader, is_pp_missing_parameter, + AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Olmo3Config class Olmo2Attention(nn.Module): @@ -68,7 +69,7 @@ class Olmo2Attention(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, (Olmo2Config, Olmo3Config)) hidden_size = self.config.hidden_size self.tp_size = get_tensor_model_parallel_world_size() @@ -111,14 +112,14 @@ class Olmo2Attention(nn.Module): self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - # Rotary embeddings. - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, # type: ignore - ) self.scaling = self.head_dim**-0.5 + + layer_idx = extract_layer_index(prefix) + sliding_window = None + if ((layer_types := getattr(self.config, "layer_types", None)) + is not None and layer_types[layer_idx] == "sliding_attention"): + sliding_window = self.config.sliding_window + self.attn = Attention( self.num_heads, self.head_dim, @@ -126,7 +127,20 @@ class Olmo2Attention(nn.Module): num_kv_heads=self.num_kv_heads, cache_config=vllm_config.cache_config, quant_config=vllm_config.quant_config, - prefix=prefix, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) + + # Rotary embeddings. Rope scaling is only applied on full attention + # layers. + self.rope_scaling = (self.config.rope_scaling + if sliding_window is None else None) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, # type: ignore + rope_scaling=self.rope_scaling, ) # Attention output projection. @@ -176,7 +190,7 @@ class Olmo2MLP(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) hidden_size = config.hidden_size intermediate_size = config.intermediate_size @@ -221,7 +235,7 @@ class Olmo2DecoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) # Attention block. 
self.self_attn = Olmo2Attention(vllm_config=vllm_config, prefix=f"{prefix}.self_attn") @@ -261,7 +275,7 @@ class Olmo2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, (Olmo2Config, Olmo3Config)) self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, @@ -376,7 +390,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) self.config = config self.model = Olmo2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index b74a09ee92c33..d6eec77ebcee5 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -12,7 +12,8 @@ from vllm.logger import init_logger from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems) + MultiModalInputs, MultiModalKwargsItems, + MultiModalUUIDDict) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -203,13 +204,13 @@ class PaliGemmaMultiModalProcessor( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides) + mm_uuids=mm_uuids) prompt_token_ids = mm_inputs["prompt_token_ids"] tokenizer = self.info.get_tokenizer() diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index 6d973a964de04..25df9e9261d91 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -374,8 +374,8 @@ class Phi4MMAudioMeanVarianceNormLayer(nn.Module): Typically used as a very first layer in a model. Args: - input_size: int - layer input size. + config: [Phi4MultimodalAudioConfig](https://huggingface.co/docs/transformers/model_doc/phi4_multimodal#transformers.Phi4MultimodalAudioConfig) + object containing model parameters. """ def __init__(self, config: Phi4MultimodalAudioConfig): @@ -1350,7 +1350,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] 
= () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 352ae4064cc61..46963828186cc 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -1154,7 +1154,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py index b5e4d727bf210..a1c452053ddd2 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -7,7 +7,7 @@ #!/usr/bin/env python3 import abc import math -from typing import Literal, Optional +from typing import Any, Literal, Optional, Union import numpy as np import torch @@ -100,7 +100,7 @@ class ConformerEncoderLayer(nn.Module): activation function for glu used in the multihead attention, default "swish". activation_checkpointing: str, optional - a dictionarry of {"module","interval","offload"}, where + a dictionary of {"module","interval","offload"}, where "module": str accept ["transformer", "attention"] to select which module should do activation checkpointing. @@ -131,31 +131,31 @@ class ConformerEncoderLayer(nn.Module): def __init__( self, - d_model=512, - ext_pw_out_channel=0, - depthwise_seperable_out_channel=256, - depthwise_multiplier=1, - n_head=4, - d_ffn=2048, - ext_pw_kernel_size=1, - kernel_size=3, - dropout_rate=0.1, - causal=False, - batch_norm=False, - activation="relu", - chunk_se=0, - chunk_size=18, - conv_activation="relu", - conv_glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - attention_inner_dim=-1, - attention_glu_type="swish", - activation_checkpointing="", - export=False, - use_pt_scaled_dot_product_attention=False, + d_model: int = 512, + ext_pw_out_channel: int = 0, + depthwise_seperable_out_channel: int = 256, + depthwise_multiplier: int = 1, + n_head: int = 4, + d_ffn: int = 2048, + ext_pw_kernel_size: int = 1, + kernel_size: int = 3, + dropout_rate: float = 0.1, + causal: bool = False, + batch_norm: bool = False, + activation: str = "relu", + chunk_se: int = 0, + chunk_size: int = 18, + conv_activation: str = "relu", + conv_glu_type: str = "sigmoid", + bias_in_glu: bool = True, + linear_glu_in_convm: bool = False, + attention_inner_dim: int = -1, + attention_glu_type: str = "swish", + activation_checkpointing: str = "", + export: bool = False, + use_pt_scaled_dot_product_attention: bool = False, attn_group_sizes: int = 1, - ): + ) -> None: super().__init__() self.feed_forward_in = FeedForward( @@ -209,24 +209,21 @@ class ConformerEncoderLayer(nn.Module): def forward( self, - x, - pos_k, - pos_v, - mask, + x: torch.Tensor, + pos_k: torch.Tensor, + pos_v: torch.Tensor, + mask: torch.Tensor, relative_attention_bias: Optional[Tensor] = None, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ConformerEncoder forward. Args: - x: torch.Tensor - input feature of shape (batch, max_time_in, size) - pos_k: torch.Tensor - positional key embedding. 
- mask: torch.Tensor - mask for x (batch, max_time_in) - relative_attention_bias: Optional[torch.Tensor] - bias added to attention logits w.r.t. relative positions - (1, n_head, time1, time2) + x: input feature of shape (batch, max_time_in, size) + pos_k: positional key embedding. + pos_v: positional value embedding. + mask: mask for x (batch, max_time_in) + relative_attention_bias: bias added to attention logits w.r.t. + relative positions (1, n_head, time1, time2) """ x = x + 0.5 * self.feed_forward_in(x) norm_x = self.layer_norm_att(x) @@ -323,25 +320,25 @@ class TransformerEncoderBase(abc.ABC, nn.Module): def __init__( self, - input_size, - chunk_size, - left_chunk, - attention_dim=256, - attention_heads=4, - input_layer="nemo_conv", - cnn_out=-1, - cnn_layer_norm=False, - time_reduction=4, - dropout_rate=0.0, - padding_idx=-1, - relative_attention_bias_args=None, - positional_dropout_rate=0.0, - nemo_conv_settings=None, + input_size: int, + chunk_size: Union[int, list[int]], + left_chunk: Union[int, list[int]], + attention_dim: int = 256, + attention_heads: int = 4, + input_layer: str = "nemo_conv", + cnn_out: int = -1, + cnn_layer_norm: bool = False, + time_reduction: int = 4, + dropout_rate: float = 0.0, + padding_idx: int = -1, + relative_attention_bias_args: Optional[dict[str, Any]] = None, + positional_dropout_rate: float = 0.0, + nemo_conv_settings: Optional[dict[str, Any]] = None, conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", - attention_group_size=1, - encoder_embedding_config=None, - ): + attention_group_size: int = 1, + encoder_embedding_config: Optional[dict[str, Any]] = None, + ) -> None: super().__init__() self.input_size = input_size self.input_layer = input_layer @@ -399,7 +396,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module): self.encoder_embedding = MeanVarianceNormLayer( self.encoder_embedding_config["input_size"]) - def compute_lens_change(self, feature_lens): + def compute_lens_change( + self, + feature_lens: Union[int, + torch.Tensor]) -> Union[int, torch.Tensor]: """feature_lens: int return updated feature lens. 
@@ -433,10 +433,14 @@ class TransformerEncoderBase(abc.ABC, nn.Module): return ceil_func(feature_lens / self.time_reduction) @abc.abstractmethod - def forward(self): + def forward(self) -> Any: """Abstract forward method implementation.""" - def _chunk_size_selection(self, chunk_size=None, left_chunk=None): + def _chunk_size_selection( + self, + chunk_size: Optional[Union[int, list[int]]] = None, + left_chunk: Optional[Union[int, + list[int]]] = None) -> tuple[int, int]: """If chunk size is a list, we will randomly select a chunk size.""" if chunk_size is None: @@ -463,7 +467,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module): return chunk_size_train_eff, left_chunk_train_eff - def _get_embed_class(self, embed): + def _get_embed_class(self, embed: nn.Module) -> nn.Module: # pylint: disable=protected-access is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) @@ -474,13 +478,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module): embed_class = embed.module return embed_class - def _forward_embeddings_core(self, input_tensor, masks): + def _forward_embeddings_core( + self, input_tensor: torch.Tensor, + masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: embed_class = self._get_embed_class(self.embed) assert isinstance(embed_class, NemoConvSubsampling) input_tensor, masks = self.embed(input_tensor, masks) return input_tensor, masks - def _position_embedding(self, input_tensor): + def _position_embedding( + self, input_tensor: torch.Tensor + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: pos_k = None pos_v = None if self.relative_attention_bias_layer is None: @@ -488,7 +496,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module): input_tensor) # default to add abs sinusoid embedding return pos_k, pos_v - def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + def _streaming_mask(self, seq_len: int, batch_size: int, + chunk_size: Union[int, list[int]], + left_chunk: Union[int, list[int]]) -> torch.Tensor: chunk_size_train_eff, left_chunk_train_eff = \ self._chunk_size_selection(chunk_size, left_chunk) @@ -502,11 +512,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module): [batch_size, -1, -1])) return enc_streaming_mask - def forward_embeddings(self, - xs_pad, - masks, - chunk_size_nc=None, - left_chunk_nc=None): + def forward_embeddings( + self, + xs_pad: torch.Tensor, + masks: torch.Tensor, + chunk_size_nc: Optional[Union[int, list[int]]] = None, + left_chunk_nc: Optional[Union[int, list[int]]] = None + ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor], torch.Tensor, torch.Tensor], + tuple[torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor]]: """Forwarding the inputs through the top embedding layers Args: @@ -569,7 +585,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module): return input_tensor, pos_k, pos_v, hs_mask, masks return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc - def get_offset(self): + def get_offset(self) -> int: """Returns offset used when retaining inputs for decoding. This is essentially, how many additional frames have to be added to @@ -605,8 +621,6 @@ class ConformerEncoder(TransformerEncoderBase): Some examples for the 2 cases: left_chunk = 6 left_chunk = [12, 9, 6, 3] - left_chunk: int - number of chunks used for masking in streaming mode. 
num_lang: int This parameter is used to store the number of languages in the lang_dict, only used for multiseed/multilingual models. @@ -751,46 +765,46 @@ class ConformerEncoder(TransformerEncoderBase): def __init__( # pylint: disable-all self, - input_size, - chunk_size, - left_chunk, - num_lang=None, - attention_dim=256, - attention_heads=4, - linear_units=2048, - num_blocks=6, - dropout_rate=0.1, - input_layer="nemo_conv", - causal=True, - batch_norm=False, - cnn_out=-1, - cnn_layer_norm=False, - ext_pw_out_channel=0, - ext_pw_kernel_size=1, - depthwise_seperable_out_channel=256, - depthwise_multiplier=1, - chunk_se=0, - kernel_size=3, - activation="relu", - conv_activation="relu", - conv_glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - attention_glu_type="swish", - export=False, - extra_layer_output_idx=-1, - extra_multi_layer_output_idxs=[], # noqa - activation_checkpointing="", - relative_attention_bias_args=None, - time_reduction=4, - use_pt_scaled_dot_product_attention=False, - nemo_conv_settings=None, + input_size: int, + chunk_size: Union[int, list[int]], + left_chunk: Union[int, list[int]], + num_lang: Optional[int] = None, + attention_dim: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + input_layer: str = "nemo_conv", + causal: bool = True, + batch_norm: bool = False, + cnn_out: int = -1, + cnn_layer_norm: bool = False, + ext_pw_out_channel: int = 0, + ext_pw_kernel_size: int = 1, + depthwise_seperable_out_channel: int = 256, + depthwise_multiplier: int = 1, + chunk_se: int = 0, + kernel_size: int = 3, + activation: str = "relu", + conv_activation: str = "relu", + conv_glu_type: str = "sigmoid", + bias_in_glu: bool = True, + linear_glu_in_convm: bool = False, + attention_glu_type: str = "swish", + export: bool = False, + extra_layer_output_idx: int = -1, + extra_multi_layer_output_idxs: list[int] = [], # noqa + activation_checkpointing: str = "", + relative_attention_bias_args: Optional[dict[str, Any]] = None, + time_reduction: int = 4, + use_pt_scaled_dot_product_attention: bool = False, + nemo_conv_settings: Optional[dict[str, Any]] = None, conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", - replication_pad_for_subsample_embedding=False, - attention_group_size=1, - encoder_embedding_config=None, - ): + replication_pad_for_subsample_embedding: bool = False, + attention_group_size: int = 1, + encoder_embedding_config: Optional[dict[str, Any]] = None, + ) -> None: super().__init__( input_size, chunk_size, @@ -852,11 +866,13 @@ class ConformerEncoder(TransformerEncoderBase): # the device and the needed dtype: self.register_buffer("dev_type", torch.zeros(()), persistent=False) - def init_relative_attention_bias(self, input_tensor): + def init_relative_attention_bias( + self, input_tensor: torch.Tensor) -> Optional[torch.Tensor]: if self.relative_attention_bias_layer: return self.relative_attention_bias_layer(input_tensor) - def calculate_hs_mask(self, xs_pad, device, mask): + def calculate_hs_mask(self, xs_pad: torch.Tensor, device: torch.device, + mask: Optional[torch.Tensor]) -> torch.Tensor: max_audio_length = xs_pad.shape[1] batch_size = xs_pad.shape[0] enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, @@ -877,7 +893,8 @@ class ConformerEncoder(TransformerEncoderBase): return pad_mask @torch.jit.ignore - def forward(self, xs_pad, masks): + def forward(self, xs_pad: torch.Tensor, + masks: torch.Tensor) -> tuple[torch.Tensor, 
torch.Tensor]: """Conformer Forward function Args: @@ -997,7 +1014,12 @@ class WindowQformer(nn.Module): if normalize_before else None) self.window_size = window_size - def forward(self, audio_embed, mask, embed_len=None): + def forward( + self, + audio_embed: torch.Tensor, + mask: Optional[torch.Tensor], + embed_len: Optional[int] = None + ) -> tuple[torch.Tensor, Optional[int]]: """forward decoder""" # audio_embed: N x T x D => N x D x T @@ -1042,7 +1064,7 @@ class WindowQformer(nn.Module): class AudioEmbedding(nn.Module): """Image embedding.""" - def __init__(self, config: PretrainedConfig, **kwargs) -> None: + def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: super().__init__() self.config = config # n_embed or hidden_size for text LM @@ -1148,19 +1170,18 @@ class AudioEmbedding(nn.Module): self.input_embeds = None self.audio_embed_sizes = None - def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: + def set_audio_embeds(self, input_embeds: torch.Tensor) -> None: self.input_embeds = input_embeds - def set_audio_embed_sizes(self, - audio_embed_sizes: torch.LongTensor) -> None: + def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None: self.audio_embed_sizes = audio_embed_sizes def get_audio_features( self, - input_embeds: torch.FloatTensor, - audio_attention_mask: torch.Tensor = None, + input_embeds: torch.Tensor, + audio_attention_mask: Optional[torch.Tensor] = None, audio_projection_mode: str = "speech", - ) -> torch.FloatTensor: + ) -> torch.Tensor: """ arguments: input_embeds: audio features (B, T, D) B: num audios in a sequence @@ -1214,10 +1235,10 @@ class AudioEmbedding(nn.Module): def forward( self, - audio_features: torch.FloatTensor, - audio_attention_mask: torch.Tensor = None, + audio_features: torch.Tensor, + audio_attention_mask: Optional[torch.Tensor] = None, audio_projection_mode: str = "speech", - ) -> torch.FloatTensor: + ) -> torch.Tensor: """ arguments: audio_features: audio features (T, D) diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py index 59535503822d4..6fbfca619a42f 100644 --- a/vllm/model_executor/models/phi4mm_utils.py +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -16,13 +16,13 @@ from torch import Tensor, nn class BlockBase(nn.Module): """Block abstract module""" - def __init__(self, input_size, output_size): + def __init__(self, input_size: int, output_size: int) -> None: super().__init__() self.input_size = input_size self.output_size = output_size -def get_activation(name="relu"): +def get_activation(name: str = "relu") -> torch.nn.Module: """Select an activation function by name Args: @@ -43,15 +43,18 @@ def get_activation(name="relu"): return nn.Identity() -def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): +def adaptive_enc_mask(x_len: int, + chunk_start_idx: list[int], + left_window: int = 0, + right_window: int = 0) -> torch.Tensor: """ The function is very important for Transformer Transducer Streaming mode Args: - xs_len (int): sequence length - chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. + x_len: sequence length + chunk_start_idx: first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] - left_window (int): how many left chunks can be seen - right_window (int): how many right chunks can be seen. It is used for + left_window: how many left chunks can be seen + right_window: how many right chunks can be seen. It is used for chunk overlap model. 
Returns: mask (torch.Tensor): a mask tensor for streaming model @@ -172,13 +175,13 @@ class GLUPointWiseConv(nn.Module): def __init__( self, - input_dim, - output_dim, - kernel_size, - glu_type="sigmoid", - bias_in_glu=True, - causal=False, - ): + input_dim: int, + output_dim: int, + kernel_size: int, + glu_type: str = "sigmoid", + bias_in_glu: bool = True, + causal: bool = False, + ) -> None: super().__init__() self.glu_type = glu_type @@ -216,11 +219,10 @@ class GLUPointWiseConv(nn.Module): self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """ Args: - x: torch.Tensor - input tensor + x: input tensor """ # to be consistent with GLULinear, we assume the input always has the # #channel (#dim) in the last dimension of the tensor, so need to @@ -272,12 +274,12 @@ class DepthWiseSeperableConv1d(nn.Module): def __init__( self, - input_dim, - depthwise_seperable_out_channel, - kernel_size, - depthwise_multiplier, - padding=0, - ): + input_dim: int, + depthwise_seperable_out_channel: int, + kernel_size: int, + depthwise_multiplier: int, + padding: int = 0, + ) -> None: super().__init__() self.dw_conv = nn.Conv1d( @@ -301,12 +303,11 @@ class DepthWiseSeperableConv1d(nn.Module): self.pw_conv = nn.Identity() self.depthwise_seperable_out_channel = depthwise_seperable_out_channel - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """ Args: - x: torch.Tensor - input tensor + x: input tensor """ x = self.dw_conv(x) if self.depthwise_seperable_out_channel != 0: @@ -375,23 +376,23 @@ class ConvModule(nn.Module): def __init__( self, - input_dim, - ext_pw_out_channel, - depthwise_seperable_out_channel, - ext_pw_kernel_size, - kernel_size, - depthwise_multiplier, - dropout_rate, - causal=False, - batch_norm=False, - chunk_se=0, - chunk_size=18, - activation="relu", - glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - export=False, - ): + input_dim: int, + ext_pw_out_channel: int, + depthwise_seperable_out_channel: int, + ext_pw_kernel_size: int, + kernel_size: int, + depthwise_multiplier: int, + dropout_rate: float, + causal: bool = False, + batch_norm: bool = False, + chunk_se: int = 0, + chunk_size: int = 18, + activation: str = "relu", + glu_type: str = "sigmoid", + bias_in_glu: bool = True, + linear_glu_in_convm: bool = False, + export: bool = False, + ) -> None: super().__init__() self.layer_norm = nn.LayerNorm(input_dim) self.input_dim = input_dim @@ -437,7 +438,7 @@ class ConvModule(nn.Module): self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim) - def _add_ext_pw_layer(self): + def _add_ext_pw_layer(self) -> None: """ This function is an extension of __init__ function and dedicated to the convolution module creation @@ -497,12 +498,11 @@ class ConvModule(nn.Module): self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """ConvModule Forward. Args: - x: torch.Tensor - input tensor. + x: input tensor. 
""" x = self.layer_norm(x) @@ -567,21 +567,20 @@ class GLULinear(nn.Module): def __init__( self, - input_dim, - output_dim, - glu_type="sigmoid", - bias_in_glu=True, - ): + input_dim: int, + output_dim: int, + glu_type: str = "sigmoid", + bias_in_glu: bool = True, + ) -> None: super().__init__() self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) self.glu_act = GLU(-1, glu_type) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """GLULinear forward Args: - x: torch.Tensor - inpute tensor. + x: input tensor. """ x = self.linear(x) return self.glu_act(x) @@ -609,12 +608,12 @@ class FeedForward(nn.Module): def __init__( self, - d_model, - d_inner, - dropout_rate, - activation="sigmoid", - bias_in_glu=True, - ): + d_model: int, + d_inner: int, + dropout_rate: float, + activation: str = "sigmoid", + bias_in_glu: bool = True, + ) -> None: super().__init__() self.d_model = d_model self.d_inner = d_inner @@ -628,12 +627,11 @@ class FeedForward(nn.Module): nn.Dropout(dropout_rate), ) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """FeedForward forward function. Args: - x: torch.Tensor - input tensor. + x: input tensor. """ out = self.net(self.layer_norm(x)) @@ -642,14 +640,14 @@ class FeedForward(nn.Module): #### positional encoding starts here def _pre_hook( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, -): + state_dict: dict, + prefix: str, + local_metadata: dict, + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], +) -> None: """Perform pre-hook in load_state_dict for backward compatibility. Note: @@ -708,10 +706,10 @@ class T5RelativeAttentionLogitBias(nn.Module): """ def __init__(self, - num_heads, - num_buckets=-1, - max_distance=1000, - symmetric=False): + num_heads: int, + num_buckets: int = -1, + max_distance: int = 1000, + symmetric: bool = False) -> None: super().__init__() self.num_heads = num_heads self.num_buckets = num_buckets @@ -727,7 +725,7 @@ class T5RelativeAttentionLogitBias(nn.Module): self.num_buckets *= 2 self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: # instantiate bias compatible with shape of x maxpos = x.size(1) context_position = torch.arange(maxpos, @@ -760,7 +758,7 @@ class T5RelativeAttentionLogitBias(nn.Module): return t5_rel_att_bias - def _bucket_relative_position(self, relative_position): + def _bucket_relative_position(self, relative_position: Tensor) -> Tensor: # this is a placeholder (isn't tested, likely buggy) using HuggingFace # implem as a reference this also needs to be extended to support # asymmetric +/- ve positions @@ -810,7 +808,10 @@ class AbsolutePositionalEncoding(nn.Module): """ - def __init__(self, d_model, dropout_rate, max_len=5000): + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super().__init__() self.d_model = d_model @@ -820,11 +821,11 @@ class AbsolutePositionalEncoding(nn.Module): self.extend_pe(torch.tensor(0.0).expand(1, max_len)) self._register_load_state_dict_pre_hook(_pre_hook) - def extend_pe(self, x): + def extend_pe(self, x: torch.Tensor) -> None: """Reset the positional encodings. 
Args: - x: torch.Tensor + x: input tensor """ if self.pe is not None and self.pe.size(1) >= x.size(1): if self.pe.dtype != x.dtype or self.pe.device != x.device: @@ -840,15 +841,14 @@ class AbsolutePositionalEncoding(nn.Module): pe = pe.unsqueeze(0) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: """Add positional encoding. Args: - x: torch.Tensor - Input tensor. shape is (batch, time, ...) + x: Input tensor. shape is (batch, time, ...) Returns: - torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + Encoded tensor. Its shape is (batch, time, ...) """ self.extend_pe(x) @@ -868,7 +868,7 @@ class MeanVarianceNormLayer(nn.Module): layer input size. """ - def __init__(self, input_size): + def __init__(self, input_size: int) -> None: super().__init__() self.input_size = input_size self.global_mean = nn.Parameter(torch.zeros(input_size)) @@ -878,8 +878,7 @@ class MeanVarianceNormLayer(nn.Module): """MeanVarianceNormLayer Forward Args: - input_: torch.Tensor - input tensor. + input_: input tensor. """ return (input_ - self.global_mean) * self.global_invstd @@ -949,7 +948,10 @@ class CausalConv1D(nn.Conv1d): dtype=dtype, ) - def update_cache(self, x, cache=None): + def update_cache( + self, + x: Tensor, + cache: Optional[Tensor] = None) -> tuple[Tensor, Optional[Tensor]]: if cache is None: new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) next_cache = cache @@ -963,7 +965,11 @@ class CausalConv1D(nn.Conv1d): next_cache = next_cache[:, :, -cache.size(-1):] return new_x, next_cache - def forward(self, x, cache=None): + def forward( + self, + x: Tensor, + cache: Optional[Tensor] = None + ) -> Union[Tensor, tuple[Tensor, Optional[Tensor]]]: x, cache = self.update_cache(x, cache=cache) x = super().forward(x) if cache is None: @@ -1017,8 +1023,8 @@ class CausalConv2D(nn.Conv2d): def forward( self, - x, - ): + x: Tensor, + ) -> Tensor: x = F.pad( x, pad=(self._left_padding, self._right_padding, 0, 0), @@ -1062,16 +1068,16 @@ class NemoConvSubsampling(torch.nn.Module): """ def __init__( - self, - feat_in, - feat_out, - subsampling_factor=4, - subsampling="dw_striding", - conv_channels=256, - subsampling_conv_chunking_factor=1, - activation=nn.ReLU(), # noqa: B008 - is_causal=False, - ): + self, + feat_in: int, + feat_out: int, + subsampling_factor: int = 4, + subsampling: str = "dw_striding", + conv_channels: int = 256, + subsampling_conv_chunking_factor: int = 1, + activation: torch.nn.Module = nn.ReLU(), # noqa: B008 + is_causal: bool = False, + ) -> None: super().__init__() self._subsampling = subsampling self._conv_channels = conv_channels @@ -1328,28 +1334,25 @@ class NemoConvSubsampling(torch.nn.Module): self.conv = torch.nn.Sequential(*layers) - def get_sampling_frames(self): + def get_sampling_frames(self) -> list[int]: return [1, self.subsampling_factor] - def get_streaming_cache_size(self): + def get_streaming_cache_size(self) -> list[int]: return [0, self.subsampling_factor + 1] - def forward(self, x, mask): + def forward(self, x: Tensor, + mask: Optional[Tensor]) -> tuple[Tensor, Optional[Tensor]]: """ Forward method for NeMo subsampling. 
Args: - x[Batch, Time, Filters]: torch.Tensor - input tensor - x_mask: torch.Tensor - input mask + x: input tensor + mask: input mask Returns: - x: torch.Tensor - Resulting tensor from subsampling (B, T // + x: Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out) - pad_mask: torch.Tensor - tensor of padded hidden state sequences (B, 1, T // + pad_mask: tensor of padded hidden state sequences (B, 1, T // time_reduction_factor) """ x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2) @@ -1403,7 +1406,7 @@ class NemoConvSubsampling(torch.nn.Module): padding_length.size(0), -1) < padding_length.unsqueeze(1) return x, pad_mask.unsqueeze(1) - def reset_parameters(self): + def reset_parameters(self) -> None: # initialize weights if self._subsampling == "dw_striding": with torch.no_grad(): @@ -1433,7 +1436,7 @@ class NemoConvSubsampling(torch.nn.Module): torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) - def conv_split_by_batch(self, x): + def conv_split_by_batch(self, x: Tensor) -> tuple[Tensor, bool]: """Tries to split input by batch, run conv and concat results""" b, _, _, _ = x.size() if b == 1: # can't split if batch size is 1 @@ -1460,7 +1463,7 @@ class NemoConvSubsampling(torch.nn.Module): True, ) - def conv_split_by_channel(self, x): + def conv_split_by_channel(self, x: Tensor) -> Tensor: """For dw convs, tries to split input by time, run conv and concat results""" x = self.conv[0](x) # full conv2D @@ -1500,7 +1503,8 @@ class NemoConvSubsampling(torch.nn.Module): x = self.conv[i * 3 + 4](x) # activation return x - def channel_chunked_conv(self, conv, chunk_size, x): + def channel_chunked_conv(self, conv: torch.nn.Module, chunk_size: int, + x: Tensor) -> Tensor: """Performs channel chunked convolution""" ind = 0 @@ -1541,7 +1545,7 @@ class NemoConvSubsampling(torch.nn.Module): return torch.cat(out_chunks, 1) def change_subsampling_conv_chunking_factor( - self, subsampling_conv_chunking_factor: int): + self, subsampling_conv_chunking_factor: int) -> None: if (subsampling_conv_chunking_factor != -1 and subsampling_conv_chunking_factor != 1 and subsampling_conv_chunking_factor % 2 != 0): @@ -1552,12 +1556,12 @@ class NemoConvSubsampling(torch.nn.Module): self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor -def calc_length(lengths, - all_paddings, - kernel_size, - stride, - ceil_mode, - repeat_num=1): +def calc_length(lengths: Tensor, + all_paddings: int, + kernel_size: int, + stride: int, + ceil_mode: bool, + repeat_num: int = 1) -> Tensor: """Calculates the output length of a Tensor passed through a convolution or max pooling layer""" add_pad: float = all_paddings - kernel_size @@ -1573,11 +1577,11 @@ def calc_length(lengths, class AttModule(nn.Module): """Attention abstraction module""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.export_mode = False - def set_export(self, mode=True): + def set_export(self, mode: bool = True) -> None: """set the export mode""" self.export_mode = mode @@ -1591,14 +1595,10 @@ class AttModule(nn.Module): """AttModule forward Args: - x: torch.Tensor - input tensor. - memory: torch.Tensor, optional - memory tensor. - pos_emb: torch.Tensor, optional - positional encoder embedding. - att_mask: torch.Tensor, optional - attention mask tensor. + x: input tensor. + memory: memory tensor. + pos_emb: positional encoder embedding. + att_mask: attention mask tensor. 
""" return x, memory, pos_emb, att_mask @@ -1606,15 +1606,15 @@ class AttModule(nn.Module): class AttBlock(BlockBase, AttModule): """Attention Block module to support both Attention and Block module.""" - def memory_dims(self, max_len=False): + def memory_dims(self, max_len: bool = False) -> tuple[int, int]: """memory dimensions""" return (1, self.input_size) def masked_softmax( - scores, + scores: Tensor, mask: Optional[Tensor], -): +) -> Tensor: if mask is not None: mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) scores = scores.masked_fill(mask, -torch.inf) @@ -1636,10 +1636,6 @@ class MultiHeadedAttention(nn.Module): input size features. dropout_rate: float dropout rate. - use_LN: bool - apply layer norm or not - dropout_at_output: bool - whether to apply dropout at output attention_inner_dim: int, optional the attention dimension used in the class, it can be different from the input dimension n_feat. @@ -1666,16 +1662,16 @@ class MultiHeadedAttention(nn.Module): def __init__( self, - n_head, - n_feat, - dropout_rate, - attention_inner_dim=-1, - glu_type="swish", - bias_in_glu=True, - use_pt_scaled_dot_product_attention=False, - n_value=-1, + n_head: int, + n_feat: int, + dropout_rate: float, + attention_inner_dim: int = -1, + glu_type: str = "swish", + bias_in_glu: bool = True, + use_pt_scaled_dot_product_attention: bool = False, + n_value: int = -1, group_size: int = 1, - ): + ) -> None: super().__init__() if n_value == -1: n_value = n_feat @@ -1718,28 +1714,22 @@ class MultiHeadedAttention(nn.Module): query: Tensor, key: Tensor, value: Tensor, - pos_k: Tensor, - pos_v: Tensor, + pos_k: Optional[Tensor], + pos_v: Optional[Tensor], mask: Optional[Tensor], relative_attention_bias: Optional[Tensor] = None, - ): + ) -> Tensor: """Compute 'Scaled Dot Product Attention'. Args: - query: torch.Tensor - query tensor (batch, time1, size) - key: torch.Tensor - key tensor (batch, time2, size) - value: torch.Tensor - value tensor (batch, time1, size) - pos_k: torch.Tensor - key tensor used for relative positional embedding. - pos_v: torch.Tensor - value tensor used for relative positional embedding. - mask: torch.Tensor - mask tensor (batch, time1, time2) - relative_attention_bias: torch.Tensor - bias added to attention logits w.r.t. relative positions + query: query tensor (batch, time1, size) + key: key tensor (batch, time2, size) + value: value tensor (batch, time1, size) + pos_k: key tensor used for relative positional embedding. + pos_v: value tensor used for relative positional embedding. + mask: mask tensor (batch, time1, time2) + relative_attention_bias: bias added to attention logits w.r.t. + relative positions (1, n_head, time1, time2) """ n_batch = query.size(0) @@ -1832,20 +1822,20 @@ class MultiSequential(torch.nn.Sequential): """Multi-input multi-output torch.nn.Sequential""" @torch.jit.ignore - def forward(self, *args): + def forward(self, *args) -> tuple: """Forward method implementation.""" for m in self: args = m(*args) return args -def get_offset(input_layer: str, time_reduction: int): +def get_offset(input_layer: str, time_reduction: int) -> int: """Get an offset. We will use the offset for determining #frames of a subsampled feature. 
Args: - input_layer (str): Type of an input layer - time_reduction (int): time reduction factor for downsampling a feature + input_layer: Type of an input layer + time_reduction: time reduction factor for downsampling a feature Returns: int: offset """ @@ -1858,13 +1848,14 @@ def get_offset(input_layer: str, time_reduction: int): return 0 -def unfold_tensor(xs_pad, max_seq_len): +def unfold_tensor(xs_pad: Tensor, max_seq_len: int) -> Tensor: """ For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. Args: - xs_pad: N, T, D + xs_pad: input tensor with shape (N, T, D) + max_seq_len: maximum sequence length """ _, _, D = xs_pad.shape xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index e7f5799a80067..142d3251bc67a 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - NestedTensors) + MultiModalUUIDDict, NestedTensors) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -316,14 +316,14 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) # NOTE: The tokens are already inserted by the chat template diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index d05eb76cdf6fd..a7e71309b6074 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -41,6 +41,7 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs, @@ -66,7 +67,8 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -549,6 +551,8 @@ class 
Qwen2_5OmniConditionalGenerationMixin: raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): + if dim == 0: + return mm_input.reshape(-1, *mm_input.shape[2:]) return torch.concat(list(mm_input), dim=dim) else: return torch.concat(mm_input, dim=dim) @@ -726,7 +730,7 @@ class Qwen2_5OmniConditionalGenerationMixin: dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, ) class Qwen2_5OmniThinkerForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, Qwen2_5OmniConditionalGenerationMixin): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -734,6 +738,22 @@ class Qwen2_5OmniThinkerForConditionalGeneration( "thinker.model.": "language_model.model.", "thinker.": "", }) + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "attn.qkv": [ + "attn.q", + "attn.k", + "attn.v", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -828,7 +848,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( return [] # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary @@ -855,7 +875,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( if multimodal_embeddings is not None \ and len(multimodal_embeddings) != 0: - # TODO (ywang96): support overlapping modalitiy embeddings so that + # TODO (ywang96): support overlapping modality embeddings so that # `use_audio_in_video` will work on V1. 
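(Looking back at the `dim == 0` branch added above: a minimal standalone sketch of the reshape-vs-concat equivalence it relies on; the tensor shapes are illustrative, not taken from the patch.)

import torch

mm_input = torch.randn(4, 7, 80)  # (num_items, seq_len, feat), made-up sizes
flat_reshape = mm_input.reshape(-1, *mm_input.shape[2:])  # (28, 80)
flat_concat = torch.concat(list(mm_input), dim=0)         # (28, 80)
# Flattening the leading batch dim gives the same result as concatenating
# per-item slices, without materializing a Python list of tensor views.
assert torch.equal(flat_reshape, flat_concat)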
inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, [ @@ -956,3 +976,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration( mapper=self.hf_to_vllm_mapper) return loaded_weights + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="merger.", + tower_model=["visual.", "audio_tower."]) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index afef86fbaa027..fc028aa2287a2 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -49,7 +50,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) # yapf: enable from vllm.model_executor.layers.quantization import QuantizationConfig @@ -64,6 +64,7 @@ from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope +from vllm.utils import is_pin_memory_available from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -299,7 +300,16 @@ class Qwen2_5_VisionAttention(nn.Module): disable_tp=use_data_parallel) # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -360,7 +370,10 @@ class Qwen2_5_VisionAttention(nn.Module): if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) @@ -510,32 +523,32 @@ class Qwen2_5_VisionPatchMerger(nn.Module): norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) - cls_fc1 = (ReplicatedLinear - if use_data_parallel else ColumnParallelLinear) - cls_fc2 = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.mlp = nn.ModuleList([ - cls_fc1(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), + self.mlp = nn.Sequential( + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + return_bias=False, + disable_tp=use_data_parallel, + ), nn.GELU(), - cls_fc2(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), - ]) + RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + return_bias=False, + disable_tp=use_data_parallel, + ), + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.ln_q(x) x = x.view(-1, self.hidden_size) - - mlp_fc1, mlp_act, mlp_fc2 = self.mlp - x_parallel, _ = mlp_fc1(x) - x_parallel = mlp_act(x_parallel) - out, _ = mlp_fc2(x_parallel) + out = self.mlp(x) return out @@ -629,7 +642,12 @@ class Qwen2_5_VisionTransformer(nn.Module): prefix=f"{prefix}.merger", use_data_parallel=use_data_parallel, ) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -717,6 +735,15 @@ class Qwen2_5_VisionTransformer(nn.Module): seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens + @staticmethod + def invert_permutation(perm: torch.Tensor) -> torch.Tensor: + # building the inverse permutation in O(n) time + inv = torch.empty_like(perm, pin_memory=is_pin_memory_available()) + inv[perm] = torch.arange(perm.numel(), + device=perm.device, + dtype=perm.dtype) + return inv + def forward( self, x: torch.Tensor, @@ -760,6 +787,8 @@ class Qwen2_5_VisionTransformer(nn.Module): rotary_pos_emb = torch.cat(rotary_pos_emb) window_index = torch.cat(window_index) + # compute reverse indices + reverse_indices = self.invert_permutation(window_index) cu_window_seqlens = torch.cat(cu_window_seqlens) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) cu_seqlens = torch.cat(cu_seqlens) @@ -780,6 +809,8 @@ class Qwen2_5_VisionTransformer(nn.Module): non_blocking=True) window_index = window_index.to(device=hidden_states.device, non_blocking=True) + reverse_indices = reverse_indices.to(device=hidden_states.device, + non_blocking=True) hidden_states = hidden_states.reshape( seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) @@ -813,7 +844,6 @@ class Qwen2_5_VisionTransformer(nn.Module): # adapter hidden_states = self.merger(hidden_states) - reverse_indices = torch.argsort(window_index) hidden_states = hidden_states[reverse_indices, :] return hidden_states @@ -959,7 +989,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) @@ -1183,21 +1213,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, input_ids: Flattened (concatenated) input_ids corresponding to a batch. positions: Flattened (concatenated) position ids corresponding to a - batch. - **NOTE**: If mrope is enabled (default setting for Qwen2.5-VL - opensource models), the shape will be `(3, seq_len)`, + batch. **NOTE**: If mrope is enabled (default setting for + Qwen2.5-VL opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. - second_per_grid_ts: Tensor `(num_videos)` of video time interval ( - in seconds) for each grid along the temporal dimension in the - 3D position IDs. `None` if no videos are passed. """ if intermediate_tensors is not None: diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 54ec7b8627488..c797b71b5d2e1 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -342,7 +342,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, *mm_input.shape[2:]) else: return torch.concat(mm_input) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 421b43563bade..2bd9d2b52628a 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): self.quant_config = quant_config self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + self.head_dtype = vllm_config.model_config.head_dtype self.score = nn.Sequential( ColumnParallelLinear(config.hidden_size, config.hidden_size, quant_config=quant_config, + params_dtype=self.head_dtype, return_bias=False), nn.ReLU(), RowParallelLinear(config.hidden_size, config.num_labels, + params_dtype=self.head_dtype, quant_config=quant_config, return_bias=False), ) @@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) + hidden_states = hidden_states.to(self.head_dtype) logits = self.score(hidden_states) return logits diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index f00b214b1ef18..d08181c5fd53b 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( Qwen2VLVideoProcessor) +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed 
import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils @@ -314,7 +315,16 @@ class Qwen2VisionAttention(nn.Module): prefix=f"{prefix}.proj") # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -374,7 +384,10 @@ class Qwen2VisionAttention(nn.Module): if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -628,7 +641,12 @@ class Qwen2VisionTransformer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.merger", ) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -1149,7 +1167,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) @@ -1218,6 +1236,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"] @@ -1227,15 +1246,17 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() - return image_embeds.split(sizes.tolist()) + return image_embeds.split(sizes) def _process_video_input( self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"] @@ -1245,9 +1266,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, # Split concatenated embeddings for each video item. 
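(Worked example of the split sizes computed below; the grids and hidden size are illustrative, not from the patch. With `spatial_merge_size = 2`, a grid of `(t, h, w) = (2, 14, 14)` contributes `2 * 14 * 14 // 4 = 98` merged embeddings.)

import torch

grid_thw_list = [[2, 14, 14], [1, 28, 28]]    # made-up video grids
merge_size = 2
sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
         (merge_size * merge_size)).tolist()  # [98, 196]
video_embeds = torch.randn(sum(sizes), 1024)  # assumed hidden size
per_item = video_embeds.split(sizes)          # shapes (98, 1024) and (196, 1024)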
merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() - return video_embeds.split(sizes.tolist()) + return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} @@ -1350,15 +1372,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, batch. **NOTE**: If mrope is enabled (default setting for Qwen2-VL opensource models), the shape will be `(3, seq_len)`, - otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. + otherwise it will be `(seq_len,)`. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. """ if intermediate_tensors is not None: diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index a7e0a00350e6b..85429b3a01f92 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -170,8 +170,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): return quant_config def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # NOTE: hidden_states can have either 1D or 2D shape. - orig_shape = hidden_states.shape + assert hidden_states.dim( + ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs" + is_input_1d = hidden_states.dim() == 1 hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) @@ -180,7 +181,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - return final_hidden_states.view(orig_shape) + # return to 1d if input is 1d + return final_hidden_states.squeeze(0) if is_input_1d else \ + final_hidden_states class Qwen3MoeAttention(nn.Module): diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py new file mode 100644 index 0000000000000..86e26da5b9b86 --- /dev/null +++ b/vllm/model_executor/models/qwen3_next.py @@ -0,0 +1,1298 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3Next model.""" +from collections.abc import Iterable +from itertools import islice +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN + +from vllm import envs +from vllm.attention import Attention, AttentionBackend, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, + VllmConfig, get_current_vllm_config) +from vllm.distributed import (divide, get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.fla.ops import ( + RMSNormGated, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) +from 
vllm.model_executor.layers.fused_moe import FusedMoE +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.layers.layernorm import ( + GemmaRMSNorm as Qwen3NextRMSNorm) +# yapf: enable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + mamba_v2_sharded_weight_loader) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Qwen3NextConfig +from vllm.triton_utils import tl, triton +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + +from .interfaces import (HasInnerState, IsHybrid, MixtureOfExperts, + SupportsLoRA, SupportsPP) +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +class Qwen3NextSparseMoeBlock(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + # Load balancing settings. 
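(A small worked example of the expert-placement arithmetic that follows; every size here is made up. With 64 logical experts, 8 redundant EPLB replicas, and an expert-parallel group of 4, each rank owns 18 physical experts.)

n_logical_experts, n_redundant_experts = 64, 8
ep_size, ep_rank = 4, 1
n_physical_experts = n_logical_experts + n_redundant_experts  # 72
n_local_physical_experts = n_physical_experts // ep_size      # 18 per rank
physical_expert_start = ep_rank * n_local_physical_experts    # 18
physical_expert_end = physical_expert_start + n_local_physical_experts  # 36; rank 1 owns [18, 36)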
+ vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + self.experts = FusedMoE(num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=f"{prefix}.gate") + + if config.shared_expert_intermediate_size > 0: + self.shared_expert = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), + ) + else: + self.shared_expert = None + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, + 1, + bias=False) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid gate quantization. + # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4 + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
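(Shape round-trip sketch for the forward path below, which flattens any leading dims into a token axis before routing and restores the caller's shape afterwards; sizes are illustrative.)

import torch

hidden_states = torch.randn(3, 5, 2048)                  # (batch, seq, hidden), made up
orig_shape = hidden_states.shape
flat = hidden_states.view(-1, hidden_states.shape[-1])   # (15, 2048)
restored = flat.view(orig_shape)                         # (3, 5, 2048) again
assert restored.shape == orig_shape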
+ orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + shared_output = None + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + if self.shared_expert_gate is not None: + shared_output = F.sigmoid( + self.shared_expert_gate(hidden_states)) * shared_output + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): + + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend + return GDNAttentionBackend + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.gated_delta_net_state_shape( + self.tp_size, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + self.conv_kernel_size, + self.num_spec, + use_v1=True) + + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = extract_layer_index(prefix) + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + self.prefix = prefix + + self.config = config + self.model_config = model_config + self.cache_config = cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = (self.speculative_config.num_speculative_tokens + if self.speculative_config else 0) + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # projection of the input hidden states + self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + self.projection_size_ba = self.num_v_heads * 2 + self.in_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + output_sizes=[self.projection_size_qkvz, self.projection_size_ba], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", 
+ ) + + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + query_key_settings, + query_key_settings, + value_settings, + ], self.tp_size, self.tp_rank) + }) + + # selective projection used to make dt, B and C input dependent + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), ) + self.A_log = nn.Parameter( + torch.empty( + divide(self.num_v_heads, self.tp_size), + dtype=torch.float32, + )) + + set_weight_attrs(self.A_log, + {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + device=torch.cuda.current_device(), + dtype=config.torch_dtype, + ) + + self.out_proj = RowParallelLinear(self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj") + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def fix_query_key_value_ordering( + self, + mixed_qkvz, + mixed_ba, + ): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. + """ + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + (self.head_k_dim + self.head_k_dim + + (self.head_v_dim + self.head_v_dim) * self.num_v_heads // + self.num_k_heads), + ) + new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + 2 * self.num_v_heads // self.num_k_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [ + self.num_v_heads // self.num_k_heads, + self.num_v_heads // self.num_k_heads + ] + + # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] + # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], + # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] + (query, key, value, z) = torch.split(mixed_qkvz, + split_arg_list_qkvz, + dim=2) + (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) + + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), -1, self.head_v_dim) + z = z.reshape(z.size(0), -1, self.head_v_dim) + b = b.reshape(b.size(0), self.num_v_heads // self.tp_size) + a = a.reshape(a.size(0), self.num_v_heads // self.tp_size) + + return query, key, value, z, b, a + + def rearrange_mixed_qkv(self, mixed_qkv): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim // self.tp_size, + self.key_dim // self.tp_size, + self.value_dim // self.tp_size, + ], + dim=-1, + ) + query, key = map( + lambda x: rearrange(x, 'l (h d) -> 1 l h d', d=self.head_k_dim), + (query, key)) + value = rearrange(value, 'l (h d) -> 1 l h d', d=self.head_v_dim) + return query, key, value + + def forward( + self, +
hidden_states: torch.Tensor, + output: torch.Tensor, + cache_params: Optional[MambaCacheParams] = None, + ): + return torch.ops.vllm.gdn_attention( + hidden_states, + output, + self.prefix, + ) + + def _forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_masks = attn_metadata.spec_token_masks + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + num_actual_tokens = (attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens + + attn_metadata.num_spec_decode_tokens) + num_accepted_tokens = attn_metadata.num_accepted_tokens + + # 1. Set up dimensions for reshapes later + projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) + if spec_token_masks is not None: + spec_token_masks = spec_token_masks[:num_actual_tokens] + projected_states_qkvz, projected_states_ba = torch.split( + projected_states, + [ + self.projection_size_qkvz // self.tp_size, + self.projection_size_ba // self.tp_size + ], + dim=-1, + ) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba) + query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), + (query, key, value)) + mixed_qkv = torch.cat((query, key, value), dim=-1) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv[spec_token_masks] + mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 2.1: process the multi-query part + if spec_sequence_masks is not None: + mixed_qkv_spec = mixed_qkv_spec.view( + attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1)) + mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l') + mixed_qkv_spec = causal_conv1d_update( + mixed_qkv_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=spec_state_indices_tensor[:, 0] + [:attn_metadata.num_spec_decodes], + num_accepted_tokens=num_accepted_tokens, + validate_data=False, + ) + mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d') + + # 2.2: process the remaining part + if attn_metadata.num_prefills > 0: + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[:attn_metadata + .num_decodes], + validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv( + mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec) + + beta = b.sigmoid() + # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + g = fused_gdn_gating(self.A_log, a, self.dt_bias) + g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g[:, spec_token_masks] + beta_spec = beta[:, spec_token_masks] + g_non_spec = g[:, ~spec_token_masks] + beta_non_spec = beta[:, ~spec_token_masks] + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 3. Recurrent attention + + # 3.1: process the multi-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[:attn_metadata. + num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 3.2: process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[ + non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...]
= 0 + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + # Write the final recurrent state back into the SSM cache + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( + ssm_state.dtype) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[:attn_metadata. + num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + # Merge core attention output + if (spec_sequence_masks is not None + and core_attn_out_non_spec is not None): + core_attn_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + core_attn_out[:, spec_token_masks] = core_attn_out_spec + core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + elif spec_sequence_masks is not None: + core_attn_out = core_attn_out_spec + else: + core_attn_out = core_attn_out_non_spec + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)') + + output[:num_actual_tokens], _ = self.out_proj(core_attn_out) + + +class Qwen3NextAttention(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is at least TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs.
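(A compact sketch of the partition-or-replicate rule used here; the helper name and sizes are made up for illustration.)

def local_kv_heads(total_num_kv_heads: int, tp_size: int) -> int:
    # Partition when there are at least as many KV heads as TP ranks,
    # otherwise replicate so every rank still holds one full head.
    if total_num_kv_heads >= tp_size:
        assert total_num_kv_heads % tp_size == 0
    else:
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

assert local_kv_heads(16, 4) == 4   # partitioned: 4 heads per rank
assert local_kv_heads(2, 4) == 1    # replicated: each rank keeps one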
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None) + self.attn_output_gate = getattr(config, "attn_output_gate", True) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads * (1 + self.attn_output_gate), + self.total_num_kv_heads, + bias=getattr(config, "qkv_bias", False), + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + rope_scaling=config.rope_scaling, + partial_rotary_factor=config.partial_rotary_factor, + dual_chunk_attention_config=self.dual_chunk_attention_config, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": + self.dual_chunk_attention_config, + } if self.dual_chunk_attention_config else {}, + ) + + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + output: torch.Tensor, + hidden_states: torch.Tensor, + ): + qkv, _ = self.qkv_proj(hidden_states) + + if self.attn_output_gate: + q_gate, k, v = qkv.split( + [self.q_size * 2, self.kv_size, self.kv_size], dim=-1) + orig_shape = q_gate.shape[:-1] + q_gate = q_gate.view(*orig_shape, self.num_heads, -1) + q, gate = torch.chunk(q_gate, 2, dim=-1) + q = q.reshape(*orig_shape, -1) + gate = gate.reshape(*orig_shape, -1) + else: + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( + -1, self.num_heads * self.head_dim) + k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( + -1, self.num_kv_heads * self.head_dim) + + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v) + + if self.attn_output_gate: + gate = torch.sigmoid(gate) + attn_output = attn_output * gate + + output[:], _ = self.o_proj(attn_output) + + +class Qwen3NextDecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + layer_type: str, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.config = config + + self.layer_type = layer_type + self.layer_idx = extract_layer_index(prefix) + + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3NextGatedDeltaNet( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f'{prefix}.linear_attn') + elif self.layer_type == 
"full_attention": + self.self_attn = Qwen3NextAttention( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f'{prefix}.self_attn', + ) + else: + raise ValueError(f"Invalid layer_type {self.layer_type}") + + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (self.layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (self.layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3NextSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) + else: + self.mlp = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + + self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + self.layer_scale = getattr(config, "layer_scale", False) + if self.layer_scale: + self.attn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + self.config.hidden_size, + dtype=config.torch_dtype, + ), ) + self.ffn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + self.config.hidden_size, + dtype=config.torch_dtype, + ), ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + positions: torch.Tensor = None, + **kwargs: object, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + self_attention_output = torch.empty_like(hidden_states) + if self.layer_type == "linear_attention": + self.linear_attn( + hidden_states=hidden_states, + output=self_attention_output, + ) + elif self.layer_type == "full_attention": + self.self_attn( + hidden_states=hidden_states, + output=self_attention_output, + positions=positions, + ) + else: + raise ValueError("Invalid layer_type") + hidden_states = self_attention_output + + if self.layer_scale: + if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype)[0] + 1) + else: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype) + 1) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + if self.layer_scale: + if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1) + else: + assert len(hidden_states.shape) == len( + self.ffn_layer_scale.shape + ), f'shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}' # noqa: E501 + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype) + 1) + + return hidden_states, residual + + +@support_torch_compile +class Qwen3NextModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: Qwen3NextConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + lora_config = vllm_config.lora_config + speculative_config = vllm_config.speculative_config + enable_eplb = parallel_config.enable_eplb + eplb_config = 
parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + return Qwen3NextDecoderLayer( + config, + layer_type=config.layer_types[extract_layer_index(prefix)], + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=prefix, + enable_eplb=enable_eplb, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.norm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("in_proj", "in_proj_qkvz", 0), + ("in_proj", "in_proj_ba", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if name.startswith("mtp."): + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
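(Standalone sketch of the checkpoint-key rewrite this loop performs just before the bias check below; the key is a made-up example, and `shard_id` is consumed by the real shard-aware weight loader.)

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]
name = "model.layers.0.self_attn.k_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        # Per-projection checkpoint keys map onto the fused parameter.
        name = name.replace(weight_name, param_name)
        break
assert name == "model.layers.0.self_attn.qkv_proj.weight"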
+ if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # name = apply_attn_prefix(name, params_dict) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + MixtureOfExperts, IsHybrid): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["gate_proj", "up_proj"], + "in_proj": ["in_proj_qkvz", "in_proj_ba"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Qwen3Next currently does not support prefix caching" + assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1" + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = Qwen3NextModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # Set MoE hyperparameters + self.expert_weights = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3NextDecoderLayer) + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No Qwen3Next layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + 
self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.model.layers: + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + return hidden_states + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + tp_size = parallel_config.tensor_parallel_size + num_spec = (vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config else 0) + return MambaStateShapeCalculator.gated_delta_net_state_shape( + tp_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + num_spec, + use_v1=True) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["mtp."], + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + +def gdn_attention( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + 
forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, output=output) + + +def gdn_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="gdn_attention", + op_func=gdn_attention, + mutates_args=["output"], + fake_impl=gdn_attention_fake, + dispatch_key=current_platform.dispatch_key, +) + + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +@triton.jit +def fused_gdn_gating_kernel( + g, + A_log, + a, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + # If the model is loaded in fp16, without the .float() here, A might be -inf + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, + (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + batch, num_heads = a.shape + seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + g = torch.empty_like(a, dtype=torch.float32) + fused_gdn_gating_kernel[grid](g, + A_log, + a, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + 8, + num_warps=1) + return g diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py new file mode 100644 index 0000000000000..e7aff377e9aec --- /dev/null +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -0,0 +1,285 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3Next MTP model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer, + Qwen3NextRMSNorm) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Qwen3NextConfig + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, maybe_prefix) + +logger = init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +@support_torch_compile 
+class Qwen3NextMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + config: Qwen3NextConfig = model_config.hf_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1) + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.fc = ColumnParallelLinear(self.config.hidden_size * 2, + self.config.hidden_size, + gather_output=True, + bias=False, + return_bias=False) + + self.layers = torch.nn.ModuleList( + Qwen3NextDecoderLayer( + config, + layer_type="full_attention", + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}', + ) for idx in range(self.num_mtp_layers)) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.norm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + assert hidden_states.shape[-1] == inputs_embeds.shape[-1] + inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds) + hidden_states = self.pre_fc_norm_hidden(hidden_states) + hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1) + hidden_states = self.fc(hidden_states) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + current_step_idx = (spec_step_idx % self.num_mtp_layers) + hidden_states, residual = self.layers[current_step_idx]( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = 
FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                if "mlp.experts" in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    # Skip loading extra bias for GPTQ models.
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+@support_torch_compile
+class Qwen3NextMTP(nn.Module, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        config = vllm_config.model_config.hf_config
+        self.vllm_config = vllm_config
+        cache_config = vllm_config.cache_config
+        assert not cache_config.enable_prefix_caching, \
+            "Qwen3NextMTP currently does not support prefix caching"
+
+        self.quant_config = vllm_config.quant_config
+
+        super().__init__()
+        self.config = config
+        self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config,
+                                                  prefix=maybe_prefix(
+                                                      prefix, "model"))
+        self.unpadded_vocab_size = config.vocab_size
+        self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
+                                      config.hidden_size,
+                                      org_num_embeddings=config.vocab_size,
+                                      padding_size=DEFAULT_VOCAB_PADDING_SIZE)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs: object,
+    ):
+        hidden_states = self.model(input_ids, positions, hidden_states,
+                                   intermediate_tensors, inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: 
SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + shared_weight_names = ["embed_tokens", "lm_head"] + + def remap_weight_names(weights): + for name, weight in weights: + if name.startswith("mtp."): + name = name.replace("mtp.", "model.") + elif not any(key in name for key in shared_weight_names): + continue + yield name, weight + + loader = AutoWeightsLoader(self) + return loader.load_weights(remap_weight_names(weights)) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c522fcab7f333..85759df369850 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = { "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"), + "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), @@ -110,7 +111,7 @@ _TEXT_GENERATION_MODELS = { "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + "MotifForCausalLM": ("motif", "MotifForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), @@ -119,6 +120,7 @@ _TEXT_GENERATION_MODELS = { "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), + "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), @@ -224,6 +226,7 @@ _MULTIMODAL_MODELS = { "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501 "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), + "NemotronH_Nano_VL": ("nano_nemotron_vl", "NemotronH_Nano_VL"), "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501 "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501 "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), @@ -278,13 +281,13 @@ _SPECULATIVE_DECODING_MODELS = { "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), - # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 - # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), + 
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"), # Temporarily disabled. # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 2bfa51162910b..ba405be416876 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -8,7 +8,7 @@ import torch from torch import nn from transformers import RobertaConfig -from vllm.config import VllmConfig +from vllm.config import ModelConfig, VllmConfig from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, DispatchPooler, Pooler) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module): class RobertaClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" - def __init__(self, config: RobertaConfig): + def __init__(self, model_config: "ModelConfig"): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + config = model_config.hf_config + head_dtype = model_config.head_dtype + self.dense = nn.Linear(config.hidden_size, + config.hidden_size, + dtype=head_dtype) + self.out_proj = nn.Linear(config.hidden_size, + config.num_labels, + dtype=head_dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: # CLSPool has already been applied in `pooling` @@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): self.roberta = BertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "bert"), embedding_class=RobertaEmbedding) - self.classifier = RobertaClassificationHead(config) + self.classifier = RobertaClassificationHead(vllm_config.model_config) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index c6244fb3b3e6a..7d90d3a7ef128 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -13,6 +13,7 @@ from torch.nn import functional as F from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import QuantizationConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -236,7 +237,15 @@ class Siglip2Attention(nn.Module): self.use_rope = config.use_rope # Detect attention implementation. 
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.head_dim, dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
                 _Backend.ROCM_AITER_FA
@@ -280,7 +289,10 @@ class Siglip2Attention(nn.Module):
             if self.attn_backend == _Backend.ROCM_AITER_FA:
                 from aiter import flash_attn_varlen_func
             else:
-                from flash_attn import flash_attn_varlen_func
+                if self.use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func
             attn_output = flash_attn_varlen_func(
                 queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
                 max_seqlen).reshape(seq_length, -1)
@@ -378,12 +390,9 @@ class Siglip2EncoderLayer(nn.Module):
                 position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
         """
         Args:
-            hidden_states (`torch.FloatTensor`):
-                Input to the layer of shape `(batch, seq_len, embed_dim)`.
-            output_attentions (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the attentions tensors of all
-                attention layers. See `attentions` under
-                returned tensors for more detail.
+            hidden_states: Input tensor of shape (batch, seq_len, embed_dim).
+            cu_seqlens: Cumulative sequence lengths tensor.
+            position_embeddings: Position embeddings tensor.
         """
         residual = hidden_states

@@ -522,19 +531,9 @@ class Siglip2Encoder(nn.Module):
     ) -> torch.Tensor:
         r"""
         Args:
-            inputs_embeds (`torch.FloatTensor` of shape
-                `(batch_size, sequence_length, hidden_size)`):
-                Optionally, instead of passing `input_ids` you can choose to
-                directly pass an embedded representation. This is useful if
-                you want more control over how to convert `input_ids` indices
-                into associated vectors than the model's internal embedding
-                lookup matrix.
-            grid_thws (`torch.LongTensor`):
-                grid shape (num_patches, 3)
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See
-                `hidden_states` under returned tensors for more detail.
-            return_dict (`bool`, *optional*):
+            inputs_embeds: Input tensor of shape
+                (batch_size, sequence_length, hidden_size).
+                Embedded representation of the input tokens.
+            grid_thws: Grid tensor of shape (num_patches, 3)
+                containing grid dimensions.
-                Whether or not to return a [`~utils.ModelOutput`] instead of
-                a plain tuple.
""" diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 17299b64978e3..2ba5f94ea3b88 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -16,6 +16,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from transformers import BatchFeature, PretrainedConfig, TensorType +from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -682,9 +683,9 @@ class Step3VisionAttention(nn.Module): prefix=f"{prefix}.out_proj", disable_tp=use_data_parallel) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() + # Use unified MultiHeadAttention with automatic backend selection + self.attn = MultiHeadAttention(self.num_heads, self.head_dim, + self.scale) def forward( self, @@ -696,19 +697,9 @@ class Step3VisionAttention(nn.Module): # get query proj qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - q = q.view(bsz, tgt_len, self.num_heads, self.head_dim) - k = k.view(bsz, tgt_len, self.num_heads, self.head_dim) - v = v.view(bsz, tgt_len, self.num_heads, self.head_dim) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - attn_output = F.scaled_dot_product_attention(q, - k, - v, - scale=self.scale, - is_causal=False) - attn_output = attn_output.transpose(1, 2).reshape( - bsz, tgt_len, self.num_heads * self.head_dim) + + # Use unified MultiHeadAttention with automatic backend selection + attn_output = self.attn(q, k, v) attn_output, _ = self.out_proj(attn_output) diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index 739396a4932cb..b9dfa8e9b6f51 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -36,7 +36,7 @@ from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargsItems, - PlaceholderRange) + MultiModalUUIDDict, PlaceholderRange) from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -164,7 +164,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: if "image" in mm_data: image_data = mm_data["image"] @@ -174,9 +174,10 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor): mm_items = self._to_mm_items(mm_data) tokenization_kwargs = tokenization_kwargs or {} - mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else - self._hash_mm_items(mm_items, hf_processor_mm_kwargs, - tokenization_kwargs)) + mm_hashes = self._hash_mm_items(mm_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_uuids=mm_uuids) mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} mm_processed_data = BatchFeature(image_data) diff --git a/vllm/model_executor/models/transformers.py 
b/vllm/model_executor/models/transformers.py index 5ad0482330ecd..a386f47e1929f 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -44,7 +44,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, PlaceholderRange) + MultiModalInputs, MultiModalUUIDDict, + PlaceholderRange) from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo) @@ -347,7 +348,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. @@ -415,9 +416,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): num_image_patches), ) # Use overrides if provided; fallback to data-dependent hashing. - mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else - self._hash_mm_items(mm_items, hf_processor_mm_kwargs, - tokenization_kwargs)) + mm_hashes = (mm_uuids if mm_uuids is not None else self._hash_mm_items( + mm_items, hf_processor_mm_kwargs, tokenization_kwargs)) return MultiModalInputs( type="multimodal", diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index c883065805279..ad911ebedf895 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -276,7 +276,7 @@ class UltravoxProjector(nn.Module): else: self.act = get_act_fn(config.projector_act) - dim_out = config.text_hidden_size + dim_out = config.text_config.hidden_size self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False) # Ultravox v0.4.1 and below use layer_norm after the second linear layer @@ -418,7 +418,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config + config: UltravoxConfig = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multi_modal_config = multimodal_config @@ -438,7 +438,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): self.multi_modal_projector = UltravoxProjector(config) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, - hf_config=config.text_config, + hf_config=config.wrapped_model_config, prefix=maybe_prefix(prefix, "language_model"), ) if config.text_model_id is not None: @@ -597,10 +597,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): with the `input_ids`. Args: - audio_features: A batch of audio input chunks [B, N, 80, M]. - audio_lens: Length of audio frames for each audio chunk [B]. - audio_token_len: Length of audio tokens for each audio chunk [B']. - Note: batch dim is different from batch dim in audio chunks. + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Position indices for the input tokens. 
+ intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. """ diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 28cfefac30ddb..e716ec582baab 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -761,3 +761,10 @@ def fast_topk(values: torch.Tensor, topk: int, else: # Use topk for efficiency with larger k values return torch.topk(values, topk, dim=dim) + + +def get_model_hidden_size(hf_config: PretrainedConfig) -> int: + if hasattr(hf_config, "hidden_size"): + return hf_config.hidden_size + text_config = hf_config.get_text_config() + return text_config.hidden_size diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index de30509b1ccb4..81f86db7e1875 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -7,7 +7,6 @@ from typing import Final, Generic, Optional, Protocol, TypeVar, Union import torch from transformers import PretrainedConfig -from vllm.attention.selector import get_env_variable_attn_backend from vllm.logger import init_logger from vllm.platforms import _Backend, current_platform @@ -68,17 +67,18 @@ def get_vision_encoder_info( raise NotImplementedError(msg) -def get_vit_attn_backend(support_fa: bool = False) -> _Backend: +def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend: """ Get the available attention backend for Vision Transformer. """ - # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn. + # Lazy import to avoid circular dependency + from vllm.attention.selector import get_env_variable_attn_backend selected_backend: Optional[_Backend] = get_env_variable_attn_backend() if selected_backend is not None: return selected_backend - return current_platform.get_vit_attn_backend(support_fa) + return current_platform.get_vit_attn_backend(head_size, dtype) def resolve_visual_encoder_outputs( diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index f3731b389cfe0..16a97389cd21b 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -23,15 +23,18 @@ from transformers.tokenization_utils_base import TextInput from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.inputs.data import PromptType from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import SupportsPP +from vllm.model_executor.models.module_mapping import MultiModelKeys # yapf: disable from vllm.model_executor.models.whisper import WhisperEncoder # yapf: enable from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems, MultiModalUUIDDict, + NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -43,8 +46,8 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsTranscription) +from .interfaces import 
(MultiModalEmbeddings, SupportsLoRA,
+                         SupportsMultiModal, SupportsTranscription)
 from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -290,14 +293,14 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
-            mm_hash_overrides=mm_hash_overrides,
+            mm_uuids=mm_uuids,
         )

         # NOTE: The tokens are already inserted by the chat template
@@ -312,13 +315,25 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
                                         info=VoxtralProcessingInfo,
                                         dummy_inputs=VoxtralDummyInputsBuilder)
 class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                      SupportsPP, SupportsTranscription):
+                                      SupportsPP, SupportsLoRA,
+                                      SupportsTranscription):
     supported_languages = ISO639_1_SUPPORTED_LANGS

+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)

+        # Update the quant config so that ignored-module and target-module
+        # names match the vLLM model names
+        if hasattr(vllm_config, "quant_config"):
+            vllm_config.quant_config = self.maybe_update_quant_config(
+                vllm_config.quant_config)
+
         config = vllm_config.model_config.hf_config
         self.config = config
         self.downsample_factor = self.config.audio_config.downsample_factor
@@ -340,6 +355,14 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

+    def get_mm_mapping(self) -> MultiModelKeys:
+        """Get module prefix for multimodal models to filter LoRA modules."""
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="audio_language_adapter",
+            tower_model=["whisper_encoder"],
+        )
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -542,6 +565,72 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,

         return loaded_weights

+    def maybe_update_quant_config(
+            self, quant_config: QuantizationConfig) -> QuantizationConfig:
+        """
+        Update the quant config so that ignored-module and target-module
+        names match the vLLM model names.
+        Right now this is specific to the compressed-tensors format and
+        the mistral load_format.
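+
+        For example, the checkpoint name "layers.0.attention.wo" is
+        remapped to "language_model.model.layers.0.self_attn.out_proj".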
+ """ + remapping_rules = [ + (r"output", r"language_model.lm_head"), + (r"layers\.(\d+)\.attention\.wo", + r"language_model.model.layers.\1.self_attn.out_proj"), + (r"layers\.(\d+)\.attention\.w(.*)", + r"language_model.model.layers.\1.self_attn.\2_proj"), + (r"layers\.(\d+)\.feed_forward\.w1", + r"language_model.model.layers.\1.mlp.gate_proj"), + (r"layers\.(\d+)\.feed_forward\.w2", + r"language_model.model.layers.\1.mlp.down_proj"), + (r"layers\.(\d+)\.feed_forward\.w3", + r"language_model.model.layers.\1.mlp.up_proj"), + (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)", + r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj" + ), + (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo", + r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj" + ), + (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)", + r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"), + (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0", + r"whisper_encoder.whisper_encoder.conv1"), + (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1", + r"whisper_encoder.whisper_encoder.conv2"), + (r"mm_whisper_embeddings\.audio_language_projection\.0", + r"audio_language_adapter.w_in"), + (r"mm_whisper_embeddings\.audio_language_projection\.2", + r"audio_language_adapter.w_out"), + ] + + # Update ignore list + if hasattr(quant_config, "ignore"): + mistral_ignore = [] + for name in quant_config.ignore: + mistral_name = name + for pattern, repl in remapping_rules: + if re.fullmatch(pattern, name): + mistral_name = re.sub(pattern, repl, name) + mistral_ignore.append(mistral_name) + quant_config.ignore = mistral_ignore + + # Update target list + if hasattr(quant_config, "config_groups"): + config_groups = quant_config.config_groups + for group_name in config_groups: + if "targets" in config_groups[group_name]: + targets = [] + for name in config_groups[group_name]["targets"]: + mistral_name = name + for pattern, repl in remapping_rules: + if re.fullmatch(pattern, name): + mistral_name = re.sub(pattern, repl, name) + targets.append(mistral_name) + config_groups[group_name]["targets"] = targets + quant_config.config_groups = config_groups + + return quant_config + class AudioLanguageAdapter(nn.Module): @@ -584,7 +673,6 @@ class VoxtralEncoderModel(nn.Module): self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config, prefix=maybe_prefix( prefix, "whisper_encoder"), - is_standalone_encoder=True, init_in_fp32=True) mel_filters = mel_filter_bank( num_frequency_bins=1 + self.config.window_size // 2, diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 97e8cd6e76957..41ae7b129782d 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -15,6 +15,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids from vllm.attention import Attention, AttentionType from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.cross_attention import CrossAttention from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig) from vllm.distributed import get_tensor_model_parallel_world_size @@ -43,7 +44,7 @@ from vllm.transformers_utils.processor import cached_get_processor from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsTranscription, SupportsV0Only) + 
SupportsTranscription) from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, make_layers) @@ -124,6 +125,34 @@ class WhisperAudioInputs(TensorSchema): TensorShape("b", "nmb", "t")] +class WhisperEncoderAttention(MultiHeadAttention): + """Multi-headed attention for Whisper encoder with 2D tensor support.""" + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + """ + Input shape: batch_size x seq_len x hidden_size + or seq_len x hidden_size + """ + is_2d = query.dim() == 2 + if is_2d: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + # Call the parent forward method + out = super().forward(query, key, value) + + if is_2d: + out = out.squeeze(0) + + return out + + class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int): @@ -144,7 +173,6 @@ class WhisperAttention(nn.Module): cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - standalone_encoder: bool = False, ): super().__init__() self.embed_dim = embed_dim @@ -180,14 +208,25 @@ class WhisperAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.out_proj", ) - if standalone_encoder: - self.attn = MultiHeadAttention( + if attn_type == AttentionType.ENCODER: + self.attn = WhisperEncoderAttention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, ) - else: + elif self.attn_type == AttentionType.ENCODER_DECODER: + self.attn = CrossAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=self.attn_type, + ) + else: # AttentionType.DECODER (regular decoder self-attention) self.attn = Attention( self.num_heads, self.head_dim, @@ -332,11 +371,7 @@ class WhisperMLP(nn.Module): class WhisperEncoderLayer(nn.Module): - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - is_standalone_encoder: bool = False): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -350,7 +385,6 @@ class WhisperEncoderLayer(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", - standalone_encoder=is_standalone_encoder, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.mlp = WhisperMLP( @@ -446,12 +480,10 @@ class WhisperEncoder(nn.Module): *, vllm_config: VllmConfig, prefix: str = "", - is_standalone_encoder: bool = False, init_in_fp32: bool = False): super().__init__() config = vllm_config.model_config.hf_config embed_dim = config.d_model - self.is_standalone_encoder = is_standalone_encoder self.num_mel_bins = config.num_mel_bins self.max_source_positions = config.max_source_positions self.embed_scale = (math.sqrt(embed_dim) @@ -469,9 +501,7 @@ class WhisperEncoder(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.encoder_layers, lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, - prefix=f"{prefix}.layers", - is_standalone_encoder= - is_standalone_encoder), + prefix=f"{prefix}.layers"), prefix=f"{prefix}.layers", ) self.layer_norm = nn.LayerNorm(config.d_model) @@ -752,7 +782,7 @@ class WhisperMultiModalProcessor( info=WhisperProcessingInfo, dummy_inputs=WhisperDummyInputsBuilder) class WhisperForConditionalGeneration(nn.Module, 
SupportsTranscription, - SupportsMultiModal, SupportsV0Only): + SupportsMultiModal): packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", @@ -880,19 +910,17 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: - # TODO: This method does not obey the interface for SupportsMultiModal. - # Refactor this once encoder/decoder support is implemented in V1. + # Required as part of SupportsMultiModal interface. audio_input = self._parse_and_validate_audio_input(**kwargs) - return self.model.get_encoder_outputs(audio_input["input_features"]) + return [self.model.get_encoder_outputs(audio_input["input_features"])] def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[NestedTensors] = None, ) -> torch.Tensor: - # TODO: This method just returns the decoder sequence embeddings since - # Whisper does not have encoder text tokens. Refactor this once - # encoder/decoder support is implemented in V1. + # This method just returns the decoder sequence embeddings since + # Whisper does not have encoder text tokens. return self.model.decoder.get_input_embeddings(input_ids) def _parse_and_validate_audio_input( diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index ed65944c109bd..86335d48c1454 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -9,7 +9,7 @@ model alternates between state space model layers and attention-based layers. """ from collections.abc import Iterable from itertools import cycle -from typing import Optional, Union +from typing import Any, Optional, Union import torch from torch import nn @@ -528,8 +528,6 @@ class Zamba2MambaDecoderLayer(nn.Module): hidden_states: Input tensor [batch_size, seq_len, hidden_size] mamba_cache_params: Parameters for Mamba's state caches (one for conv, one for ssm) - sequence_idx: Index tensor for identifying sequences in batch - Required for proper chunked processing in prefill transformer_hidden_states: Optional output from transformer path Added to input if provided (used in hybrid architecture) positions: Optional position IDs (unused in Mamba) @@ -591,8 +589,6 @@ class Zamba2HybridLayer(nn.Module): Args: shared_transformer: Transformer decoder layer for attention pathway - linear: Linear projection for transformer output before Mamba - mamba: Mamba decoder layer for state space pathway """ super().__init__() self.block_idx = block_idx @@ -630,8 +626,6 @@ class Zamba2HybridLayer(nn.Module): positions: Position IDs for positional embeddings mamba_cache_params: Parameters for Mamba's state caches (one for conv, one for ssm) - sequence_idx: Indices for identifying sequences in batch, - required for proper chunked processing in prefill Returns: Output tensor combining transformer and Mamba representations @@ -915,8 +909,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): prefix: Optional prefix for parameter names Raises: - AssertionError: If prefix caching is enabled (not supported by - Mamba) + AssertionError: If prefix caching is enabled + (not supported by Mamba) """ config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -971,7 +965,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, - **kwargs) -> torch.Tensor: + **kwargs: Any) -> torch.Tensor: """Forward pass through 
the model.

         Args:
@@ -1012,9 +1006,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):

         return hidden_states

-    def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str,
-                                                                 torch.Tensor],
-                                       **kwargs) -> dict[str, torch.Tensor]:
+    def copy_inputs_before_cuda_graphs(
+            self, input_buffers: dict[str, torch.Tensor],
+            **kwargs: Any) -> dict[str, torch.Tensor]:
         """Copy inputs before CUDA graph capture.

         Args:
diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index 41ed0b09c5a2a..65436786f82ac 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -52,10 +52,11 @@ def set_weight_attrs(
 def _make_synced_weight_loader(original_weight_loader):

     def _synced_weight_loader(param, *args, **kwargs):
-        original_weight_loader(param, *args, **kwargs)
+        out = original_weight_loader(param, *args, **kwargs)
-        # torch._sync doesn't support, is not needed for CPU tensors.
+        # torch._sync doesn't support, and is not needed for, CPU tensors.
         if param.device != torch.device("cpu"):
             torch._sync(param)
+        return out

     return _synced_weight_loader
diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py
index 74599fa44c88c..a25ef86a989db 100644
--- a/vllm/model_executor/warmup/deep_gemm_warmup.py
+++ b/vllm/model_executor/warmup/deep_gemm_warmup.py
@@ -10,6 +10,7 @@ import torch
 from tqdm import tqdm

 import vllm.envs as envs
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
     compute_aligned_M, deep_gemm_block_shape)
@@ -131,11 +132,9 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
 GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()


-def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
-                                                    w2: torch.Tensor,
-                                                    w1_scale: torch.Tensor,
-                                                    w2_scale: torch.Tensor,
-                                                    num_topk: int):
+def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+        w1: torch.Tensor, w2: torch.Tensor, w1_scale: torch.Tensor,
+        w2_scale: torch.Tensor, num_topk: int, max_tokens: int):

     if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
             and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE):
         return
@@ -147,9 +146,13 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
     num_experts = w1.size(0)
     device = w1.device

+    # Assumes all ranks have the same max_num_batched_tokens
+    max_tokens_across_dp = get_dp_group().world_size * max_tokens
+    max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
+
     # This is the maximum GroupedGemm M size that we expect to run
     # the grouped_gemm with.
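+    # The warmup M is therefore bounded by both the DP-aggregated token
+    # count and VLLM_FUSED_MOE_CHUNK_SIZE.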
- MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE, + MAX_M = compute_aligned_M(max_tokens, num_topk, num_experts, block_m, @@ -201,7 +204,8 @@ def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int): _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens) -def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): +def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module, + max_tokens: int): dg_modules = [ m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m) @@ -211,9 +215,9 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): w13, w13_scale, w2, w2_scale, num_topk = ( _extract_data_from_fused_moe_module(dgm)) _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( - w13, w2, w13_scale, w2_scale, num_topk) + w13, w2, w13_scale, w2_scale, num_topk, max_tokens) def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int): deepgemm_fp8_gemm_nt_warmup(model, max_tokens) - deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model) + deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 761172e4d3616..89ce20308f447 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING import torch import vllm.envs as envs +from vllm.logger import init_logger from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup from vllm.platforms import current_platform from vllm.utils.deep_gemm import is_deep_gemm_supported @@ -19,6 +20,8 @@ if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_worker import Worker +logger = init_logger(__name__) + def kernel_warmup(worker: "Worker"): # Deep GEMM warmup @@ -30,10 +33,33 @@ def kernel_warmup(worker: "Worker"): max_tokens = worker.scheduler_config.max_num_batched_tokens deep_gemm_warmup(model, max_tokens) - # FlashInfer autotune for Blackwell (SM 10.0) GPUs - if has_flashinfer() and current_platform.is_device_capability(100): + # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs + if has_flashinfer() and current_platform.has_device_capability(90): flashinfer_autotune(worker.model_runner) + # FlashInfer attention warmup + # Only warmup if the model has FlashInfer attention groups + # and is not a pooling model + def _is_flashinfer_backend(backend): + try: + return backend.get_name() == "FLASHINFER_VLLM_V1" + except NotImplementedError: + return False + + if not worker.model_runner.is_pooling_model and all( + _is_flashinfer_backend(group.backend) + for groups in worker.model_runner.attn_groups for group in groups): + logger.info("Warming up FlashInfer attention.") + # Warmup with mixed batch containing both prefill and decode tokens + # This is to warm up both prefill and decode attention kernels + worker.model_runner._dummy_run( + num_tokens=16, + skip_eplb=True, + is_profile=True, + force_attention=True, + create_mixed_batch=True, + ) + def flashinfer_autotune(runner: "GPUModelRunner") -> None: """ diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 35b743ed21d92..31ae450f4c2ff 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -3,19 +3,24 @@ import sys from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union +from 
multiprocessing.synchronize import Lock as LockType +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, cast import torch from typing_extensions import TypeAlias, override +from vllm.distributed.device_communicators.shm_object_storage import ( + MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer) +from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME from vllm.logger import init_logger -from vllm.utils import GiB_bytes, LRUCache +from vllm.utils import GiB_bytes, LRUCache, MiB_bytes from vllm.utils.jsontree import (json_count_leaves, json_map_leaves, json_reduce_leaves) -from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem, - MultiModalKwargs, MultiModalKwargsItem, - MultiModalKwargsItems, NestedTensors) +from .inputs import (MultiModalBatchedField, MultiModalFeatureSpec, + MultiModalFieldElem, MultiModalKwargs, + MultiModalKwargsItem, MultiModalKwargsItems, + NestedTensors) if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -389,6 +394,106 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): self._cache.clear() +class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is enabled. + + How to update each item: + + - If the item is already in the cache, clear the input to avoid + unnecessary IPC. + + - If the item is not in the cache, store the data in shared memory. + """ + + def __init__(self, vllm_config: "VllmConfig") -> None: + super().__init__() + + self.world_size = vllm_config.parallel_config.world_size + mm_config = vllm_config.model_config.get_multimodal_config() + + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes), + name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, + create=True, # sender is the writer + ) + self._shm_cache = SingleWriterShmObjectStorage( + max_object_size=mm_config.mm_shm_cache_max_object_size_mb * + MiB_bytes, + n_readers=self.world_size, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + ) + # cache (prompt_updates, modality) for P0 only + self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], + str]] = {} + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return self._shm_cache.is_cached(mm_hash) + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + + if self._shm_cache.is_cached(mm_hash): + address, monotonic_id = self._shm_cache.get_cached(mm_hash) + prompt_updates, modality = self._p0_cache[mm_hash] + return self.address_as_item(address, monotonic_id, + modality), prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + try: + address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0]) + # Try to remove dangling items if p0 cache is too large. + if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index): + self.remove_dangling_items() + self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality + address_item = self.address_as_item(address, monotonic_id, + mm_item[0].modality) + return address_item, mm_item[1] + except (ValueError, MemoryError) as e: + # put may fail if the object is too large or + # the cache is full. + # In this case we log the error and keep the original mm_input. 
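+            # The receiver side handles this transparently: items without
+            # an "address" field are used as-is
+            # (see ShmObjectStoreReceiverCache.get_and_update_item).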
+ logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, + e) + return mm_item + + @override + def clear_cache(self) -> None: + self._shm_cache.clear() + self._p0_cache.clear() + + def remove_dangling_items(self) -> None: + """Remove items that are no longer in the shared memory cache.""" + cached_hashes = self._shm_cache.key_index.keys() + dangling_hashes = set(self._p0_cache.keys()) - cached_hashes + for mm_hash in dangling_hashes: + del self._p0_cache[mm_hash] + + def address_as_item(self, address: int, monotonic_id: int, + modality: str) -> MultiModalKwargsItem: + addr_elem = MultiModalFieldElem( + modality=modality, + key="address", + data=address, + field=MultiModalBatchedField(), + ) + id_elem = MultiModalFieldElem( + modality=modality, + key="monotonic_id", + data=monotonic_id, + field=MultiModalBatchedField(), + ) + mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem]) + return mm_item + + def _enable_processor_cache( model_config: "ModelConfig", mm_registry: "MultiModalRegistry", @@ -408,6 +513,17 @@ def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool: return supports_ipc_cache +def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool: + """Whether the shared memory based cache should be enabled.""" + + if not _enable_ipc_cache(vllm_config): + return False + + mm_config = vllm_config.model_config.get_multimodal_config() + + return mm_config.mm_processor_cache_type == "shm" + + def processor_cache_from_config( vllm_config: "VllmConfig", mm_registry: "MultiModalRegistry", @@ -421,7 +537,9 @@ def processor_cache_from_config( if not _enable_ipc_cache(vllm_config): return MultiModalProcessorOnlyCache(model_config) - return MultiModalProcessorSenderCache(model_config) + if not _enable_mm_input_shm_cache(vllm_config): + return MultiModalProcessorSenderCache(model_config) + return ShmObjectStoreSenderCache(vllm_config) def processor_only_cache_from_config( @@ -491,11 +609,68 @@ class MultiModalReceiverCache(BaseMultiModalReceiverCache): self._cache.clear() -def receiver_cache_from_config( +class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache): + """ + The cache which is used on P1 Worker Process when IPC caching is enabled. + + How to update each item: + + - If the item has an address, replace the input with the cached item. + - If not, return the input. 
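+
+    Address items carry the (address, monotonic_id) pair written by the
+    P0 sender and are resolved against the shared-memory ring buffer;
+    non-address items are returned unchanged.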
+ """ + + def __init__( + self, + vllm_config: "VllmConfig", + shared_worker_lock: LockType, + ) -> None: + super().__init__() + + self.world_size = vllm_config.parallel_config.world_size + mm_config = vllm_config.model_config.get_multimodal_config() + + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes), + name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, + create=False, # Server is a reader + ) + self._shm_cache = SingleWriterShmObjectStorage( + max_object_size=mm_config.mm_shm_cache_max_object_size_mb * + MiB_bytes, + n_readers=self.world_size, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + reader_lock=shared_worker_lock, + ) + + @override + def get_and_update_item( + self, + mm_item: Optional[MultiModalKwargsItem], + mm_hash: str, + ) -> MultiModalKwargsItem: + assert mm_item is not None, f"Expected an address item for {mm_hash=}" + if "address" in mm_item: + address = cast(int, mm_item["address"].data) + monotonic_id = cast(int, mm_item["monotonic_id"].data) + return self._shm_cache.get(address, monotonic_id) + + return mm_item + + @override + def clear_cache(self) -> None: + self._shm_cache.clear() + + +def engine_receiver_cache_from_config( vllm_config: "VllmConfig", mm_registry: "MultiModalRegistry", ) -> Optional[BaseMultiModalReceiverCache]: - """Return a `BaseMultiModalReceiverCache`, if enabled.""" + """ + This is used in the engine process. + Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and + mm_processor_cache_type=="lru". + """ model_config = vllm_config.model_config if not _enable_processor_cache(model_config, mm_registry): @@ -504,4 +679,31 @@ def receiver_cache_from_config( if not _enable_ipc_cache(vllm_config): return None - return MultiModalReceiverCache(model_config) + if not _enable_mm_input_shm_cache(vllm_config): + return MultiModalReceiverCache(model_config) + + return None + + +def worker_receiver_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", + shared_worker_lock: LockType, +) -> Optional[BaseMultiModalReceiverCache]: + """ + This is used in the worker process. + Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and + mm_processor_cache_type=="shm". + """ + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return None + + if not _enable_mm_input_shm_cache(vllm_config): + return None + + return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index f8ea3835f049d..240e34e139cfe 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -85,9 +85,10 @@ which are treated as audio embeddings; these are directly passed to the model without HF processing. """ -ModalityData: TypeAlias = Union[_T, list[_T]] +ModalityData: TypeAlias = Union[_T, list[Optional[_T]], None] """ -Either a single data item, or a list of data items. +Either a single data item, or a list of data items. Can only be None if UUID +is provided. The number of data items allowed per modality is restricted by `--limit-mm-per-prompt`. 
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 88bb99529f200..493dd3560a516 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -36,7 +36,7 @@ class ModalityDataItems(ABC, Generic[_T, _I]): def __init__(self, data: _T, modality: str) -> None: super().__init__() - self.data = data + self.data: _T = data self.modality = modality def __repr__(self) -> str: @@ -177,7 +177,9 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): - def __init__(self, data: Sequence[HfAudioItem]) -> None: + def __init__(self, data: Optional[Sequence[HfAudioItem]]) -> None: + if data is None: + data = [None] super().__init__(data, "audio") def get_audio_length(self, item_idx: int) -> int: @@ -198,7 +200,9 @@ class ImageSize(NamedTuple): class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): - def __init__(self, data: Sequence[HfImageItem]) -> None: + def __init__(self, data: Optional[Sequence[HfImageItem]]) -> None: + if data is None: + data = [None] super().__init__(data, "image") def get_image_size(self, item_idx: int) -> ImageSize: @@ -223,10 +227,12 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): def __init__( self, - data: Sequence[HfVideoItem], + data: Optional[Sequence[HfVideoItem]], metadata: Optional[Union[dict[str, Any], list[Optional[dict[str, Any]]]]] = None, ) -> None: + if data is None: + data = [None] super().__init__(data, "video") self.metadata = metadata @@ -385,6 +391,9 @@ class MultiModalDataParser: self, data: ModalityData[AudioItem], ) -> Optional[ModalityDataItems[Any, Any]]: + if data is None: + return AudioProcessorItems(None) + # also check single audio item with sampling rate if self._is_empty(data) or (isinstance(data, tuple) and self._is_empty(data[0])): @@ -420,6 +429,9 @@ class MultiModalDataParser: self, data: ModalityData[ImageItem], ) -> Optional[ModalityDataItems[Any, Any]]: + if data is None: + return ImageProcessorItems(None) + if self._is_empty(data): return None @@ -441,6 +453,9 @@ class MultiModalDataParser: self, data: ModalityData[VideoItem], ) -> Optional[ModalityDataItems[Any, Any]]: + if data is None: + return VideoProcessorItems(None) + if self._is_empty(data): return None diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 0531b7bd9f0a7..7471bfcb4d508 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1022,13 +1022,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: return self.apply(prompt, mm_data, hf_processor_mm_kwargs, - mm_hash_overrides=mm_hash_overrides) + mm_uuids=mm_uuids) def _get_data_parser(self) -> MultiModalDataParser: """ @@ -1076,7 +1075,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. 
""" mm_items = self.data_parser.parse_mm_data(mm_data) - for modality, items in mm_items.items(): self.validate_num_items(modality, len(items)) @@ -1364,8 +1362,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalHashes: """Create MM hashes to be returned (only used in V1). @@ -1376,30 +1373,30 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): model_id = self.info.model_id hashes: MultiModalHashes = {} - mm_hash_overrides = mm_hash_overrides or {} + mm_uuids = mm_uuids or {} for modality, items in mm_items.items(): - if modality in mm_hash_overrides: - mm_hashes = mm_hash_overrides[modality] - if isinstance(mm_hashes, str): - mm_hashes = [mm_hashes] + if modality in mm_uuids: + mm_uuids_per_modality = mm_uuids[modality] + if isinstance(mm_uuids_per_modality, str): + mm_uuids_per_modality = [mm_uuids_per_modality] # For None entries, compute a hash; otherwise, use provided ID. computed: list[str] = [] for i, item in enumerate(items): - mm_hash = mm_hashes[i] + item_uuid = mm_uuids_per_modality[i] - # NOTE: Even if a mm_hash is provided, we still compute a + # NOTE: Even if a item_uuid is provided, we still compute a # hash if `hf_processor_mm_kwargs` or `tokenization_kwargs` # are provided. This is because the processed multimodal # inputs can be different depending on the processor kwargs. - if mm_hash is None or \ + if item_uuid is None or \ hf_processor_mm_kwargs or \ tokenization_kwargs: # NOTE: use provided hash string to hash with kwargs # if available for better performance. - item = mm_hash if mm_hash is not None else item + item = item_uuid if item_uuid is not None else item computed.append( MultiModalHasher.hash_kwargs( model_id=model_id, @@ -1407,7 +1404,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): **hf_processor_mm_kwargs, **tokenization_kwargs)) else: - computed.append(mm_hash) + computed.append(item_uuid) hashes[modality] = computed else: hashes[modality] = [ @@ -1438,10 +1435,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ] for modality, items_is_cached in mm_is_cached.items() } - mm_missing_data = { - modality: [mm_data_items[modality][idx] for idx in idxs] - for modality, idxs in mm_missing_idxs.items() - } + mm_missing_data = {} + for modality, idxs in mm_missing_idxs.items(): + missing_modality_data = [] + for idx in idxs: + data = mm_data_items[modality][idx] + if data is None: + raise ValueError( + f"Cache miss for {modality} at index {idx} " + f"but data is not provided.") + else: + missing_modality_data.append(data) + mm_missing_data[modality] = missing_modality_data return self._to_mm_items(mm_missing_data) @@ -1514,8 +1519,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: ( prompt_ids, @@ -1539,7 +1543,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides) + mm_uuids=mm_uuids) mm_prompt_updates = self._get_mm_prompt_updates( mm_data_items, @@ -1562,8 +1566,7 @@ class 
BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: """ Apply the HF processor on the full prompt text, @@ -1578,13 +1581,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides) + mm_uuids=mm_uuids) mm_missing_data_items = self._get_cache_missing_items( cache=cache, @@ -1785,8 +1788,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. @@ -1815,7 +1817,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_items, hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) # NOTE: tokenization_kwargs are not required to init processor @@ -1901,8 +1903,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalEncDecInputs: """ Process multi-modal inputs to be used in vLLM. 
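The `mm_uuids` rename above also pins down the hashing rule in `_hash_mm_items`: a caller-supplied UUID is used verbatim only when no `hf_processor_mm_kwargs` or `tokenization_kwargs` are present; otherwise the UUID (or, failing that, the raw item) is folded into a kwargs-aware hash, since the same media can produce different processed inputs under different kwargs. A rough standalone sketch of that decision, with `hash_kwargs` as a simplified stand-in for `MultiModalHasher.hash_kwargs`:

```python
import hashlib
import json
from typing import Any, Optional


def hash_kwargs(**kwargs: Any) -> str:
    """Stand-in for MultiModalHasher.hash_kwargs (illustrative only)."""
    blob = json.dumps(kwargs, sort_keys=True, default=str)
    return hashlib.sha256(blob.encode()).hexdigest()


def hash_item(item: Any, item_uuid: Optional[str], model_id: str,
              processor_kwargs: dict[str, Any]) -> str:
    # A provided UUID is trusted as-is only when no kwargs can change the
    # processed output; otherwise hash the UUID (cheaper than hashing the
    # raw data) together with the kwargs.
    if item_uuid is None or processor_kwargs:
        key = item_uuid if item_uuid is not None else item
        return hash_kwargs(model_id=model_id, item=key, **processor_kwargs)
    return item_uuid


# Same image, same UUID, different processor kwargs -> different cache keys.
a = hash_item(b"...", "img-1", "my-model", {})
b = hash_item(b"...", "img-1", "my-model", {"min_pixels": 256})
assert a == "img-1" and b != a
```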
@@ -1917,7 +1918,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
             mm_data,
             hf_processor_mm_kwargs,
             tokenization_kwargs,
-            mm_hash_overrides=mm_hash_overrides,
+            mm_uuids=mm_uuids,
         )
 
         return self._get_enc_dec_inputs(
diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py
index ef1380bdb614c..df6e19da82ca2 100644
--- a/vllm/multimodal/video.py
+++ b/vllm/multimodal/video.py
@@ -1,12 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
 import base64
+import math
 from abc import abstractmethod
 from functools import partial
 from io import BytesIO
 from pathlib import Path
-from typing import Any
+from typing import Any, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -104,10 +104,12 @@ class OpenCVVideoBackend(VideoLoader):
         return api_pref
 
     @classmethod
-    def load_bytes(cls,
-                   data: bytes,
-                   num_frames: int = -1,
-                   **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
+    def load_bytes(
+        cls,
+        data: bytes,
+        num_frames: int = -1,
+        **kwargs,
+    ) -> tuple[npt.NDArray, dict[str, Any]]:
         import cv2
 
         backend = cls().get_cv2_video_api()
@@ -119,6 +121,15 @@ class OpenCVVideoBackend(VideoLoader):
         original_fps = cap.get(cv2.CAP_PROP_FPS)
         duration = total_frames_num / original_fps if original_fps > 0 else 0
 
+        # Use the transformers.video_utils.VideoMetadata format
+        metadata = {
+            "total_num_frames": total_frames_num,
+            "fps": original_fps,
+            "duration": duration,
+            "video_backend": "opencv"
+        }
+
+        # resample video to target num_frames
         full_read = num_frames == -1 or total_frames_num < num_frames
         if full_read:
             num_frames = total_frames_num
@@ -148,14 +159,88 @@ class OpenCVVideoBackend(VideoLoader):
         assert i == num_frames, (f"Expected reading {num_frames} frames, "
                                  f"but only loaded {i} frames from video.")
 
+        return frames, metadata
+
+
+@VIDEO_LOADER_REGISTRY.register("opencv_dynamic")
+class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
+
+    @classmethod
+    def load_bytes(
+        cls,
+        data: bytes,
+        num_frames: int = -1,
+        requested_fps: int = 2,
+        max_duration: int = 300,
+        **kwargs,
+    ) -> tuple[npt.NDArray, dict[str, Any]]:
+        import cv2
+
+        backend = cls().get_cv2_video_api()
+        cap = cv2.VideoCapture(BytesIO(data), backend, [])
+        if not cap.isOpened():
+            raise ValueError("Could not open video stream")
+
+        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        original_fps = cap.get(cv2.CAP_PROP_FPS)
+        duration = total_frames_num / original_fps if original_fps > 0 else 0
+
+        # Use the transformers.video_utils.VideoMetadata format
         metadata = {
             "total_num_frames": total_frames_num,
             "fps": original_fps,
             "duration": duration,
-            "video_backend": "opencv"
+            "video_backend": "opencv_dynamic"
         }
 
+        # resample video to target num_frames
+        max_frame_idx = total_frames_num - 1
+        duration = duration or round(max_frame_idx / original_fps) + 1
+
+        # Refer to:
+        # https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140
+        frame_indices: Union[range, list[int]]
+        if duration <= max_duration:
+            n = int(math.floor(duration * requested_fps))
+            frame_indices = sorted({
+                min(max_frame_idx,
+                    int(math.ceil(i * original_fps / requested_fps)))
+                for i in range(n)
+            })
+        else:
+            num_samples = int(max_duration * requested_fps)
+            if num_samples >= total_frames_num:
+                frame_indices = range(total_frames_num)
+            else:
+                target_seconds = np.linspace(0,
+                                             duration,
+                                             num_samples,
+                                             endpoint=True)
+                frame_indices = sorted({
+                    min(max_frame_idx, int(math.ceil(t
* original_fps))) + for t in target_seconds + }) + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frames = np.empty((len(frame_indices), height, width, 3), + dtype=np.uint8) + + i = 0 + for idx in range(total_frames_num): + ok = cap.grab() + if not ok: + break + if idx in frame_indices: + ret, frame = cap.retrieve() + if ret: + frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + i += 1 + + assert i == len(frame_indices), ( + f"Expected reading {len(frame_indices)} frames, " + f"but only loaded {i} frames from video.") + return frames, metadata diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 12d5e0bf08652..c5b6d91a62b6d 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -75,12 +75,12 @@ class CpuPlatform(Platform): def supported_dtypes(self) -> list[torch.dtype]: if self.get_cpu_architecture() == CpuArchEnum.POWERPC: return [torch.bfloat16, torch.float32] - elif sys.platform.startswith( - "darwin") and self.get_cpu_architecture() == CpuArchEnum.ARM: - # TODO: change this condition to check if the platform support bf16 - # instead of checking the OS. For instance M2 shall supports bf16 - # already. But we need to modify `cpu_extension.cmake` to activate - # the feature in the build. + elif (self.get_cpu_architecture() == CpuArchEnum.ARM + and sys.platform.startswith("darwin")): + if (subprocess.check_output( + ["sysctl -n hw.optional.arm.FEAT_BF16"], + shell=True).strip() == b"1"): + return [torch.bfloat16, torch.float16, torch.float32] return [torch.float16, torch.float32] # x86/aarch64 CPU has supported both bf16 and fp16 natively. return [torch.bfloat16, torch.float16, torch.float32] @@ -347,3 +347,7 @@ class CpuPlatform(Platform): @classmethod def opaque_attention_op(cls) -> bool: return True + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 1b0a298352cbf..8e3436a9e73c5 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -64,8 +64,7 @@ class CudaPlatformBase(Platform): if self.has_device_capability(80): # Ampere and Hopper or later NVIDIA GPUs. return [torch.bfloat16, torch.float16, torch.float32] - elif (not self.has_device_capability(80) - ) and self.has_device_capability(60): + if self.has_device_capability(60): # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported return [torch.float16, torch.float32] # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported, @@ -146,6 +145,7 @@ class CudaPlatformBase(Platform): # required block_size. 
use_flashmla = False use_cutlass_mla = False + use_flashinfer_mla = False if envs.VLLM_ATTENTION_BACKEND is None: # Default case @@ -164,6 +164,8 @@ class CudaPlatformBase(Platform): use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA") use_cutlass_mla = ( envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA") + use_flashinfer_mla = ( + envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA") from vllm.attention.ops.flashmla import is_flashmla_supported if use_flashmla and is_flashmla_supported()[0] \ @@ -177,22 +179,26 @@ class CudaPlatformBase(Platform): logger.info("Forcing kv cache block size to 128 for " "CUTLASS_MLA backend.") + if use_flashinfer_mla and cache_config.block_size not in [32, 64]: + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashInferMLA " + "backend.") + # lazy import to avoid circular import from vllm.config import CUDAGraphMode compilation_config = vllm_config.compilation_config if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 - and compilation_config.cudagraph_mode != CUDAGraphMode.NONE): + and compilation_config.cudagraph_mode + not in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE]): logger.info( - "Data Parallel: disabling cudagraphs since DP " - "with DeepEP high-throughput kernels are not CUDA Graph " - "compatible. The DeepEP low-latency kernels are CUDA Graph " - "compatible. Set the all_to_all backend to deepep_low_latency " - "to use those kernels instead.") - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - if model_config is not None: - model_config.enforce_eager = True + "Data Parallel with DeepEP high-throughput: using PIECEWISE " + "CUDA graphs and excluding MoE ops from capture. Set " + "VLLM_ALL2ALL_BACKEND=deepep_low_latency if you need MoE " + "graphs captured as well.") + compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE @classmethod def get_current_memory_usage(cls, @@ -203,18 +209,24 @@ class CudaPlatformBase(Platform): return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: - if cls.has_device_capability(80) and support_fa: - from transformers.utils import is_flash_attn_2_available - if is_flash_attn_2_available(): + def get_vit_attn_backend(cls, head_size: int, + dtype: torch.dtype) -> _Backend: + if dtype not in (torch.float16, torch.bfloat16): + return _Backend.XFORMERS + + if cls.has_device_capability(80): + FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + from vllm.attention.selector import is_attn_backend_supported + is_default_fa_supported = is_attn_backend_supported( + FLASH_ATTN_V1, head_size, dtype, allow_import_error=False) + if is_default_fa_supported: return _Backend.FLASH_ATTN - logger.warning_once( - "Current `vllm-flash-attn` has a bug inside vision " - "module, so we use xformers backend instead. 
You can " - "run `pip install flash-attn` to use flash-attention " - "backend.") - # Fallback for Volta/Turing GPUs or FA not supported - return _Backend.XFORMERS + else: + # Fallback to XFORMERS + return _Backend.XFORMERS + else: + # Fallback for Volta/Turing GPUs or FA not supported + return _Backend.XFORMERS @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, @@ -230,6 +242,9 @@ class CudaPlatformBase(Platform): use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( selected_backend is None and cls.is_device_capability(100) and block_size == 128) + use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( + selected_backend is None and cls.is_device_capability(100) + and block_size in [32, 64]) use_flashmla = selected_backend in [ _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1 ] or (selected_backend is None and is_flashmla_supported()[0]) @@ -254,6 +269,19 @@ class CudaPlatformBase(Platform): else: logger.warning( "Cutlass MLA backend is only supported on V1 engine") + if use_flashinfermla: + if use_v1: + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + set_kv_cache_layout("HND") + logger.info_once( + "Using FlashInfer MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "flashinfer_mla.FlashInferMLABackend") + else: + logger.warning( + "FlashInfer MLA backend is only supported on V1 engine" + ) if use_flashmla: if block_size != 64: logger.warning( @@ -513,7 +541,9 @@ class CudaPlatformBase(Platform): attention_backend = "FLASHMLA" # Only FlashMLA and CUTLASS_MLA support fp8 - if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]: + if attention_backend in [ + "FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA" + ]: supported = True else: supported = (not fp8_attention) @@ -532,6 +562,10 @@ class CudaPlatformBase(Platform): supported = flash_attn_supports_fp8() else: supported = True + elif attention_backend == "FLASHINFER": + supported = True + elif attention_backend == "TRITON_ATTN_VLLM_V1": + supported = cls.supports_fp8() return supported @classmethod @@ -554,6 +588,10 @@ class CudaPlatformBase(Platform): "You can use float16 instead by explicitly setting the " "`dtype` flag in CLI, for example: --dtype=half.") + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index fdd3764d2c35d..054d08c3a85be 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -48,8 +48,10 @@ class _Backend(enum.Enum): ROCM_AITER_MLA_VLLM_V1 = enum.auto() ROCM_AITER_FA = enum.auto() # used for ViT attn backend TORCH_SDPA = enum.auto() + TORCH_SDPA_VLLM_V1 = enum.auto() FLASHINFER = enum.auto() FLASHINFER_VLLM_V1 = enum.auto() + FLASHINFER_MLA = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 TRITON_MLA_VLLM_V1 = enum.auto() CUTLASS_MLA = enum.auto() @@ -190,7 +192,8 @@ class Platform: return device_id @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: + def get_vit_attn_backend(cls, head_size: int, + dtype: torch.dtype) -> _Backend: return _Backend.TORCH_SDPA @classmethod @@ -584,6 +587,13 @@ class Platform: """ raise NotImplementedError + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + """ + Returns if the hybrid kv cache is supported by the current platform. 
+ """ + return False + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index c6d14aa87c7f2..bb8bff48c7b95 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -171,19 +171,19 @@ class RocmPlatform(Platform): supported_quantization: list[str] = [ "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", - "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4" + "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao" ] @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: - if support_fa: - if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA - and on_gfx9()): - # Note: AITER FA is only supported for Qwen-VL models. - # TODO: Add support for other VL models in their model class. - return _Backend.ROCM_AITER_FA - if on_gfx9(): - return _Backend.FLASH_ATTN + def get_vit_attn_backend(cls, head_size: int, + dtype: torch.dtype) -> _Backend: + if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA + and on_gfx9()): + # Note: AITER FA is only supported for Qwen-VL models. + # TODO: Add support for other VL models in their model class. + return _Backend.ROCM_AITER_FA + if on_gfx9(): + return _Backend.FLASH_ATTN return _Backend.TORCH_SDPA @classmethod @@ -322,23 +322,35 @@ class RocmPlatform(Platform): @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: + from vllm.config.compilation import CUDAGraphMode + cache_config = vllm_config.cache_config + compilation_config = vllm_config.compilation_config + parallel_config = vllm_config.parallel_config + is_eager_execution = compilation_config == CUDAGraphMode.NONE + + use_v1 = envs.VLLM_USE_V1 + use_aiter_rms_norm = envs.VLLM_ROCM_USE_AITER and \ + envs.VLLM_ROCM_USE_AITER_RMSNORM + if cache_config and cache_config.block_size is None: cache_config.block_size = 16 - parallel_config = vllm_config.parallel_config if parallel_config.worker_cls == "auto": if vllm_config.speculative_config: - if not envs.VLLM_USE_V1: + if not use_v1: raise NotImplementedError( "Speculative decoding is not supported on vLLM V0.") parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" else: - if envs.VLLM_USE_V1: + if use_v1: parallel_config.worker_cls = \ "vllm.v1.worker.gpu_worker.Worker" else: parallel_config.worker_cls = "vllm.worker.worker.Worker" + # Aiter rms norm perform best when CUDA Graph capture is enabled. + if use_v1 and use_aiter_rms_norm and not is_eager_execution: + compilation_config.custom_ops.append("+rms_norm") @classmethod def verify_model_arch(cls, model_arch: str) -> None: @@ -486,3 +498,7 @@ class RocmPlatform(Platform): f"Your {gpu_name} GPU {compute_str}. " "You can use float16 instead by explicitly setting the " "`dtype` flag in CLI, for example: --dtype=half.") + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index c7b4ba34c602e..fe93e906064e4 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -165,7 +165,8 @@ class SamplingParams( the sampled token, so there may be up to `logprobs+1` elements in the response. When set to -1, return all `vocab_size` log probabilities.""" prompt_logprobs: Optional[int] = None - """Number of log probabilities to return per prompt token.""" + """Number of log probabilities to return per prompt token. + When set to -1, return all `vocab_size` log probabilities.""" # NOTE: This parameter is only exposed at the engine level for now. 
# It is not exposed in the OpenAI API server, as the OpenAI API does # not support returning only a list of token IDs. @@ -409,9 +410,11 @@ class SamplingParams( and self.logprobs < 0): raise ValueError( f"logprobs must be non-negative or -1, got {self.logprobs}.") - if self.prompt_logprobs is not None and self.prompt_logprobs < 0: - raise ValueError(f"prompt_logprobs must be non-negative, got " - f"{self.prompt_logprobs}.") + if (self.prompt_logprobs is not None and self.prompt_logprobs != -1 + and self.prompt_logprobs < 0): + raise ValueError( + f"prompt_logprobs must be non-negative or -1, got " + f"{self.prompt_logprobs}.") if (self.truncate_prompt_tokens is not None and (self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1)): diff --git a/vllm/tracing.py b/vllm/tracing.py index 6a287d82be5ff..7537e9901a044 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -119,6 +119,11 @@ class SpanAttributes: # forward, block/sync across workers, cpu-gpu sync time and sampling time. GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = ( "gen_ai.latency.time_in_model_execute") + GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = \ + "gen_ai.latency.time_in_model_prefill" + GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode" + GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = \ + "gen_ai.latency.time_in_model_inference" def contains_trace_headers(headers: Mapping[str, str]) -> bool: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 95e4ed1ccf07f..fd19d33ca0c89 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import enum import json import os import time from functools import cache, partial from pathlib import Path -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Literal, Optional, TypeVar, Union import huggingface_hub from huggingface_hub import get_safetensors_metadata, hf_hub_download @@ -27,6 +26,7 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm import envs from vllm.logger import init_logger +from vllm.transformers_utils.config_parser_base import ConfigParserBase from vllm.transformers_utils.utils import check_gguf_file if envs.VLLM_USE_MODELSCOPE: @@ -75,11 +75,12 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( eagle="EAGLEConfig", speculators="SpeculatorsConfig", nemotron="NemotronConfig", + olmo3="Olmo3Config", ovis="OvisConfig", ultravox="UltravoxConfig", step3_vl="Step3VLConfig", step3_text="Step3TextConfig", -) + qwen3_next="Qwen3NextConfig") _CONFIG_ATTRS_MAPPING: dict[str, str] = { "llm_config": "text_config", @@ -100,10 +101,163 @@ _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = { } -class ConfigFormat(str, enum.Enum): - AUTO = "auto" - HF = "hf" - MISTRAL = "mistral" +class HFConfigParser(ConfigParserBase): + + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE + config_dict, _ = PretrainedConfig.get_config_dict( + model, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + # Use custom model class if it's in our registry + model_type = config_dict.get("model_type") + if model_type is None: + model_type = "speculators" if 
config_dict.get( + "speculators_config") is not None else model_type + + if model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] + config = config_class.from_pretrained( + model, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + else: + try: + kwargs = _maybe_update_auto_config_kwargs( + kwargs, model_type=model_type) + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + except ValueError as e: + if (not trust_remote_code + and "requires you to execute the configuration file" + in str(e)): + err_msg = ( + "Failed to load the model config. If the model " + "is a custom model not yet available in the " + "HuggingFace transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + config = _maybe_remap_hf_config_attrs(config) + return config_dict, config + + +class MistralConfigParser(ConfigParserBase): + + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + # This function loads a params.json config which + # should be used when loading models in mistral format + config_dict = _download_mistral_config_file(model, revision) + if (max_position_embeddings := + config_dict.get("max_position_embeddings")) is None: + max_position_embeddings = _maybe_retrieve_max_pos_from_hf( + model, revision, **kwargs) + config_dict["max_position_embeddings"] = max_position_embeddings + + from vllm.transformers_utils.configs.mistral import adapt_config_dict + + config = adapt_config_dict(config_dict) + + # Mistral configs may define sliding_window as list[int]. Convert it + # to int and add the layer_types list[str] to make it HF compatible + if ((sliding_window := getattr(config, "sliding_window", None)) + and isinstance(sliding_window, list)): + pattern_repeats = config.num_hidden_layers // len(sliding_window) + layer_types = sliding_window * pattern_repeats + config.layer_types = [ + "full_attention" if layer_type is None else "sliding_attention" + for layer_type in layer_types + ] + config.sliding_window = next(filter(None, sliding_window), None) + + return config_dict, config + + +_CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = { + "hf": HFConfigParser, + "mistral": MistralConfigParser, +} + +ConfigFormat = Literal[ + "auto", + "hf", + "mistral", +] + + +def get_config_parser(config_format: str) -> ConfigParserBase: + """Get the config parser for a given config format.""" + if config_format not in _CONFIG_FORMAT_TO_CONFIG_PARSER: + raise ValueError(f"Unknown config format `{config_format}`.") + return _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format]() + + +def register_config_parser(config_format: str): + + """Register a customized vllm config parser. + When a config format is not supported by vllm, you can register a customized + config parser to support it. + Args: + config_format (str): The config parser format name. + Examples: + + >>> from vllm.transformers_utils.config import (get_config_parser, + register_config_parser) + >>> from vllm.transformers_utils.config_parser_base import ConfigParserBase + >>> + >>> @register_config_parser("custom_config_parser") + ... class CustomConfigParser(ConfigParserBase): + ... 
def parse(self,
+    ...               model: Union[str, Path],
+    ...               trust_remote_code: bool,
+    ...               revision: Optional[str] = None,
+    ...               code_revision: Optional[str] = None,
+    ...               **kwargs) -> tuple[dict, PretrainedConfig]:
+    ...         raise NotImplementedError
+    >>>
+    >>> type(get_config_parser("custom_config_parser"))
+    <class 'CustomConfigParser'>
+    """ # noqa: E501
+
+    def _wrapper(config_parser_cls):
+        if config_format in _CONFIG_FORMAT_TO_CONFIG_PARSER:
+            logger.warning(
+                "Config format `%s` is already registered, and will be "
+                "overwritten by the new parser class `%s`.", config_format,
+                config_parser_cls)
+        if not issubclass(config_parser_cls, ConfigParserBase):
+            raise ValueError("The config parser must be a subclass of "
+                             "`ConfigParserBase`.")
+        _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format] = config_parser_cls
+        logger.info("Registered config parser `%s` with config format `%s`",
+                    config_parser_cls, config_format)
+        return config_parser_cls
+
+    return _wrapper
 
 
 _R = TypeVar("_R")
@@ -350,7 +504,7 @@ def get_config(
     trust_remote_code: bool,
     revision: Optional[str] = None,
     code_revision: Optional[str] = None,
-    config_format: ConfigFormat = ConfigFormat.AUTO,
+    config_format: Union[str, ConfigFormat] = "auto",
     hf_overrides_kw: Optional[dict[str, Any]] = None,
     hf_overrides_fn: Optional[Callable[[PretrainedConfig],
                                        PretrainedConfig]] = None,
@@ -363,20 +517,22 @@
             kwargs["gguf_file"] = Path(model).name
             model = Path(model).parent
 
-    if config_format == ConfigFormat.AUTO:
+    if config_format == "auto":
         try:
             if is_gguf or file_or_path_exists(
                     model, HF_CONFIG_NAME, revision=revision):
-                config_format = ConfigFormat.HF
+                config_format = "hf"
             elif file_or_path_exists(model,
                                      MISTRAL_CONFIG_NAME,
                                      revision=revision):
-                config_format = ConfigFormat.MISTRAL
+                config_format = "mistral"
             else:
                 raise ValueError(
                     "Could not detect config format for no config file found. "
-                    "Ensure your model has either config.json (HF format) "
-                    "or params.json (Mistral format).")
+                    "With config_format 'auto', ensure your model has either "
+                    "config.json (HF format) or params.json (Mistral format). "
+                    "Otherwise, please specify your custom config format in "
+                    "the engine args to use a customized config parser.")
 
     except Exception as e:
         error_message = (
@@ -395,92 +551,14 @@
 
         raise ValueError(error_message) from e
 
-    if config_format == ConfigFormat.HF:
-        kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
-        config_dict, _ = PretrainedConfig.get_config_dict(
-            model,
-            revision=revision,
-            code_revision=code_revision,
-            token=_get_hf_token(),
-            **kwargs,
-        )
-        # Use custom model class if it's in our registry
-        model_type = config_dict.get("model_type")
-        if model_type is None:
-            model_type = "speculators" if config_dict.get(
-                "speculators_config") is not None else model_type
-
-        if model_type in _CONFIG_REGISTRY:
-            config_class = _CONFIG_REGISTRY[model_type]
-            config = config_class.from_pretrained(
-                model,
-                revision=revision,
-                code_revision=code_revision,
-                token=_get_hf_token(),
-                **kwargs,
-            )
-        else:
-            try:
-                kwargs = _maybe_update_auto_config_kwargs(
-                    kwargs, model_type=model_type)
-                config = AutoConfig.from_pretrained(
-                    model,
-                    trust_remote_code=trust_remote_code,
-                    revision=revision,
-                    code_revision=code_revision,
-                    token=_get_hf_token(),
-                    **kwargs,
-                )
-            except ValueError as e:
-                if (not trust_remote_code
-                        and "requires you to execute the configuration file"
-                        in str(e)):
-                    err_msg = (
-                        "Failed to load the model config.
If the model " - "is a custom model not yet available in the " - "HuggingFace transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - config = _maybe_remap_hf_config_attrs(config) - - elif config_format == ConfigFormat.MISTRAL: - # This function loads a params.json config which - # should be used when loading models in mistral format - config_dict = _download_mistral_config_file(model, revision) - if (max_position_embeddings := - config_dict.get("max_position_embeddings")) is None: - max_position_embeddings = _maybe_retrieve_max_pos_from_hf( - model, revision, **kwargs) - config_dict["max_position_embeddings"] = max_position_embeddings - - from vllm.transformers_utils.configs.mistral import adapt_config_dict - - config = adapt_config_dict(config_dict) - - # Mistral configs may define sliding_window as list[int]. Convert it - # to int and add the layer_types list[str] to make it HF compatible - if ((sliding_window := getattr(config, "sliding_window", None)) - and isinstance(sliding_window, list)): - pattern_repeats = config.num_hidden_layers // len(sliding_window) - layer_types = sliding_window * pattern_repeats - config.layer_types = [ - "full_attention" if layer_type is None else "sliding_attention" - for layer_type in layer_types - ] - config.sliding_window = next(filter(None, sliding_window), None) - else: - supported_formats = [ - fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO - ] - raise ValueError( - f"Unsupported config format: {config_format}. " - f"Supported formats are: {', '.join(supported_formats)}. " - f"Ensure your model uses one of these configuration formats " - f"or specify the correct format explicitly.") - + config_parser = get_config_parser(config_format) + config_dict, config = config_parser.parse( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + **kwargs, + ) # Special architecture mapping check for GGUF models if is_gguf: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: @@ -601,20 +679,21 @@ def get_hf_file_to_dict(file_name: str, @cache -def get_pooling_config(model: str, revision: Optional[str] = 'main'): +def get_pooling_config(model: str, + revision: Optional[str] = 'main') -> Optional[dict]: """ This function gets the pooling and normalize config from the model - only applies to sentence-transformers models. Args: - model (str): The name of the Hugging Face model. - revision (str, optional): The specific version - of the model to use. Defaults to 'main'. + model: The name of the Hugging Face model. + revision: The specific version of the model to use. + Defaults to 'main'. Returns: - dict: A dictionary containing the pooling - type and whether normalization is used. + A dictionary containing the pooling type and whether + normalization is used, or None if no pooling configuration is found. 
""" modules_file_name = "modules.json" @@ -914,7 +993,7 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: hf_config = get_config(model=model, trust_remote_code=trust_remote_code_val, revision=revision, - config_format=ConfigFormat.HF) + config_format="hf") if hf_value := hf_config.get_text_config().max_position_embeddings: max_position_embeddings = hf_value except Exception as e: diff --git a/vllm/transformers_utils/config_parser_base.py b/vllm/transformers_utils/config_parser_base.py new file mode 100644 index 0000000000000..c27177f74d4ba --- /dev/null +++ b/vllm/transformers_utils/config_parser_base.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional, Union + +from transformers import PretrainedConfig + + +class ConfigParserBase(ABC): + + @abstractmethod + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + raise NotImplementedError diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index f651ecb078b95..ca0d5def760a8 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -23,7 +23,9 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config +from vllm.transformers_utils.configs.olmo3 import Olmo3Config from vllm.transformers_utils.configs.ovis import OvisConfig +from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig, Step3VisionEncoderConfig, @@ -44,10 +46,12 @@ __all__ = [ "NemotronConfig", "NemotronHConfig", "Nemotron_Nano_VL_Config", + "Olmo3Config", "OvisConfig", "SpeculatorsConfig", "UltravoxConfig", "Step3VLConfig", "Step3VisionEncoderConfig", "Step3TextConfig", + "Qwen3NextConfig", ] diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 6aabf9e5262e6..444ed70de3d0c 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -46,6 +46,7 @@ class EAGLEConfig(PretrainedConfig): # Eagle model name should follow naming convention of # LlamaForCausalLM -> EagleLlamaForCausalLM # LlamaForCausalLM -> Eagle3LlamaForCausalLM + # LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3 if method == "eagle": assert self.model is not None, \ "model should not be None when method is eagle" @@ -53,6 +54,7 @@ class EAGLEConfig(PretrainedConfig): f"Eagle{arch}" if not arch.startswith("Eagle") \ else arch for arch in self.model.architectures ] + elif method == "eagle3": assert self.model is not None, \ "model should not be None when method is eagle3" diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 767c4ddae870d..d5ca2c7b4751a 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -74,10 +74,10 @@ class JAISConfig(PretrainedConfig): use_cache (`bool`, *optional*, defaults to `True`): Whether or not 
the model should return the last key/values attentions (not used by all models). - scale_attn_by_inverse_layer_idx (`bool`, *optional*, - defaults to `False`): - Whether to additionally scale attention weights by - `1 / layer_idx + 1`. + scale_attn_by_inverse_layer_idx + (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights + by `1 / layer_idx + 1`. reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): Whether to scale keys (K) prior to computing attention (dot-product) diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py index 8a9c660b882fd..5d9206e188322 100644 --- a/vllm/transformers_utils/configs/mistral.py +++ b/vllm/transformers_utils/configs/mistral.py @@ -157,6 +157,7 @@ def _remap_mistral_audio_args(config: dict) -> dict: encoder_attention_heads=encoder_args["n_heads"], vocab_size=encoder_args["vocab_size"], max_source_positions=encoder_args["max_source_positions"], + is_encoder_decoder=False, # Override WhisperConfig default ) } if quant_config: diff --git a/vllm/transformers_utils/configs/olmo3.py b/vllm/transformers_utils/configs/olmo3.py new file mode 100644 index 0000000000000..874507db43a7f --- /dev/null +++ b/vllm/transformers_utils/configs/olmo3.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers.configuration_utils import PretrainedConfig + + +class Olmo3Config(PretrainedConfig): + + model_type = "olmo3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + sliding_window=4096, + layer_types=None, + **kwargs, + ): + # This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM + # in vLLM. 
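+        # For example (illustrative): a checkpoint whose config.json declares
+        #     "architectures": ["Olmo3ForCausalLM"]
+        # is remapped below to ["Olmo2ForCausalLM"], so vLLM loads it with
+        # the existing Olmo2 model class.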
+ if "architectures" not in kwargs: + kwargs["architectures"] = ["Olmo2ForCausalLM"] + elif "Olmo3ForCausalLM" in kwargs["architectures"]: + kwargs["architectures"].remove("Olmo3ForCausalLM") + kwargs["architectures"].append("Olmo2ForCausalLM") + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.rms_norm_eps = rms_norm_eps + + self.sliding_window = sliding_window + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" + for i in range(self.num_hidden_layers) + ] diff --git a/vllm/transformers_utils/configs/qwen3_next.py b/vllm/transformers_utils/configs/qwen3_next.py new file mode 100644 index 0000000000000..c7af26acd1b9f --- /dev/null +++ b/vllm/transformers_utils/configs/qwen3_next.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen3-Next model configuration""" + +from transformers.configuration_utils import (PretrainedConfig, + layer_type_validation) +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Qwen3NextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a + Qwen3-Next model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of + Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the model. Defines the number of different tokens that can be represented by the + `inputs_ids`. 
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 5632):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 48):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 2):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `2`.
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        partial_rotary_factor (`float`, *optional*, defaults to 0.25):
+            Percentage of the query and keys which will have rotary embedding.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        head_dim (`int`, *optional*, defaults to 256):
+            Projection weights dimension in multi-head attention.
+        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
+            Kernel size of the convolution used in linear attention layers.
+        linear_key_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of each key head in linear attention.
+        linear_value_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of each value head in linear attention.
+        linear_num_key_heads (`int`, *optional*, defaults to 16):
+            Number of key heads used in linear attention layers.
+        linear_num_value_heads (`int`, *optional*, defaults to 32):
+            Number of value heads used in linear attention layers.
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer.
+        moe_intermediate_size (`int`, *optional*, defaults to 512):
+            Intermediate size of the routed expert.
+        shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
+            Intermediate size of the shared expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 10):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 512):
+            Number of routed experts.
+        norm_topk_prob (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the topk probabilities.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
+            Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock.
+            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
+            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+        layer_types (`list[str]`, *optional*):
+            Types of each layer (attention or linear).
+ + ```python + >>> from transformers import Qwen3NextModel, Qwen3NextConfig + + >>> # Initializing a Qwen3Next style configuration + >>> configuration = Qwen3NextConfig() + + >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration + >>> model = Qwen3NextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ # noqa: E501 + + model_type = "qwen3_next" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "colwise", + "layers.*.mlp.experts.*.up_proj": "colwise", + "layers.*.mlp.experts.*.down_proj": "rowwise", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=48, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=0.25, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + decoder_sparse_step=1, + moe_intermediate_size=512, + shared_expert_intermediate_size=512, + num_experts_per_tok=10, + num_experts=512, + norm_topk_prob=True, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + layer_types=None, + **kwargs, + ): + if mlp_only_layers is None: + mlp_only_layers = [] + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.head_dim = head_dim + rope_config_validation(self) + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "linear_attention" if bool((i + 1) % 4) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + + # linear attention part + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + + # MoE 
arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = mlp_only_layers + + +__all__ = ["Qwen3NextConfig"] diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index 87064cc12deda..aaf31d84d0c1a 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -20,10 +20,13 @@ class UltravoxConfig(transformers.PretrainedConfig): Args: audio_config (`Union[AutoConfig, dict]`, *optional*): - Custom audio config or dict + Custom audio config or dict. text_config (`Union[AutoConfig, dict]`, *optional*): - The config object of the text backbone. Can be any of `LlamaConfig` - or `MistralConfig`. + The config object of the text backbone. + audio_model_id (`str`, *optional*): + The model ID of the audio backbone. + text_model_id (`str`, *optional*): + The model ID of the text backbone. ignore_index (`int`, *optional*, defaults to -100): The ignore index for the loss function. audio_token_index (`int`, *optional*, defaults to 32000): @@ -34,16 +37,12 @@ class UltravoxConfig(transformers.PretrainedConfig): The initialization value for the layer normalization. projector_act (`str`, *optional*, defaults to `"swiglu"`): The activation function used by the multimodal projector. - text_model_lora_config (`LoraConfigSimplified`, *optional*): - The LoRA configuration for finetuning the text model. - audio_model_lora_config (`LoraConfigSimplified`, *optional*): - The LoRA configuration for finetuning the audio model. projector_ln_mid (`bool`, *optional*, defaults to `False`): Whether to apply layer normalization at the middle of the projector or at the end. Versions v0.4.1 and below use `False`, but v0.5 and above use `True`. """ - + wrapped_model_config: transformers.PretrainedConfig model_type = "ultravox" audio_token = "<|audio|>" is_composition = False @@ -60,15 +59,10 @@ class UltravoxConfig(transformers.PretrainedConfig): stack_factor: int = 8, norm_init: float = 0.4, projector_act: str = "swiglu", - text_model_lora_config: Optional[dict[str, Any]] = None, - audio_model_lora_config: Optional[dict[str, Any]] = None, projector_ln_mid: bool = False, **kwargs, ): self.ignore_index = ignore_index - - self.audio_model_id = audio_model_id - self.text_model_id = text_model_id self.audio_token_index = audio_token_index self.hidden_size = hidden_size @@ -77,36 +71,46 @@ class UltravoxConfig(transformers.PretrainedConfig): self.projector_act = projector_act self.projector_ln_mid = projector_ln_mid - if text_model_id is not None: - # Avoid circular import - from vllm.transformers_utils.config import get_config - - text_config_obj = get_config(text_model_id, - trust_remote_code=False) - else: + # N.B. May set the wrapped_model_config below. 
+ self.text_model_id = text_model_id + if text_model_id is None: text_config = text_config or {} - text_config_obj = transformers.CONFIG_MAPPING[text_config.get( - "model_type", "llama")](**text_config) + self.wrapped_model_config = transformers.CONFIG_MAPPING[ + text_config.get("model_type", "llama")](**text_config) - inner_text_config = text_config_obj.get_text_config() - - if audio_model_id is not None: - # Avoid circular import - from vllm.transformers_utils.config import get_config - - audio_config = get_config(audio_model_id, trust_remote_code=False) - else: + # N.B. May set the audio_config below. + self.audio_model_id = audio_model_id + if audio_model_id is None: + self.audio_model_id = None audio_config = audio_config or {} - audio_config = transformers.CONFIG_MAPPING[audio_config.get( + self.audio_config = transformers.CONFIG_MAPPING[audio_config.get( "model_type", "whisper")](**audio_config) - self.text_config = text_config_obj - self.audio_config = audio_config - self.text_model_lora_config = text_model_lora_config or {} - self.audio_model_lora_config = audio_model_lora_config or {} - - self.vocab_size = inner_text_config.vocab_size - self.initializer_range = inner_text_config.initializer_range - self.text_hidden_size = inner_text_config.hidden_size - super().__init__(**kwargs) + + def __setattr__(self, key, value): + # Since --hf-overrides are applied _after_ the UltravoxConfig is + # instantiated, load the configs implicitly when assigning text_model_id + # or audio_model_id. This allows: + # + # --hf-overrides.text_model_id=<quantized variant> + # + # to behave as intended. + if key == "text_model_id" and value is not None: + from vllm.transformers_utils.config import get_config + + self.wrapped_model_config = get_config(value, + trust_remote_code=False) + elif key == "audio_model_id" and value is not None: + from vllm.transformers_utils.config import get_config + + self.audio_config = get_config(value, trust_remote_code=False) + + return super().__setattr__(key, value) + + @property + def text_config(self) -> transformers.PretrainedConfig: + # When Ultravox wraps a multi-modal model (e.g. Gemma), we instantiate + # the full model, but the text config is the text config of the inner + # model. + return self.wrapped_model_config.get_text_config() diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py index 5896bde312657..d1d117b4e2cf4 100644 --- a/vllm/transformers_utils/processors/deepseek_vl2.py +++ b/vllm/transformers_utils/processors/deepseek_vl2.py @@ -25,6 +25,7 @@ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. import math +from typing import Any import torch import torchvision.transforms as T @@ -178,17 +179,15 @@ class DeepseekVLV2Processor(ProcessorMixin): prompt: str, images: list[Image.Image], inference_mode: bool = True, - **kwargs, + **kwargs: Any, ): """ Args: prompt (str): the formatted prompt; - conversations (list[dict]): conversations with a list of messages; images (list[ImageType]): the list of images; inference_mode (bool): if True, then remove the last eos token; - system_prompt (str): the system prompt; - **kwargs: + **kwargs: Additional keyword arguments. 
Returns: outputs (BaseProcessorOutput): the output of the processor, @@ -259,7 +258,7 @@ class DeepseekVLV2Processor(ProcessorMixin): text: str, images: list[Image.Image], inference_mode: bool = True, - **kwargs, + **kwargs: Any, ): """ diff --git a/vllm/transformers_utils/runai_utils.py b/vllm/transformers_utils/runai_utils.py new file mode 100644 index 0000000000000..b7bee1974de5b --- /dev/null +++ b/vllm/transformers_utils/runai_utils.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import shutil +import signal +import tempfile +from typing import Optional + +from vllm.logger import init_logger +from vllm.utils import PlaceholderModule + +logger = init_logger(__name__) + +SUPPORTED_SCHEMES = ['s3://', 'gs://'] + +try: + from runai_model_streamer import list_safetensors as runai_list_safetensors + from runai_model_streamer import pull_files as runai_pull_files +except (ImportError, OSError): + # see https://github.com/run-ai/runai-model-streamer/issues/26 + # OSError will be raised on arm64 platform + runai_model_streamer = PlaceholderModule( + "runai_model_streamer") # type: ignore[assignment] + runai_pull_files = runai_model_streamer.placeholder_attr("pull_files") + runai_list_safetensors = runai_model_streamer.placeholder_attr( + "list_safetensors") + + +def list_safetensors(path: str = "") -> list[str]: + """ + List full file names from object path and filter by allow pattern. + + Args: + path: The object storage path to list from. + + Returns: + list[str]: List of full object storage paths allowed by the pattern + """ + return runai_list_safetensors(path) + + +def is_runai_obj_uri(model_or_path: str) -> bool: + return model_or_path.lower().startswith(tuple(SUPPORTED_SCHEMES)) + + +class ObjectStorageModel: + """ + A class representing an ObjectStorage model mirrored into a + temporary directory. + + Attributes: + dir: The temporary created directory. + + Methods: + pull_files(): Pull model from object storage to the temporary directory. + """ + + def __init__(self) -> None: + for sig in (signal.SIGINT, signal.SIGTERM): + existing_handler = signal.getsignal(sig) + signal.signal(sig, self._close_by_signal(existing_handler)) + + self.dir = tempfile.mkdtemp() + + def __del__(self): + self._close() + + def _close(self) -> None: + if os.path.exists(self.dir): + shutil.rmtree(self.dir) + + def _close_by_signal(self, existing_handler=None): + + def new_handler(signum, frame): + self._close() + if existing_handler: + existing_handler(signum, frame) + + return new_handler + + def pull_files(self, + model_path: str = "", + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None) -> None: + """ + Pull files from object storage into the temporary directory. + + Args: + model_path: The object storage path of the model. + allow_pattern: A list of patterns of which files to pull. + ignore_pattern: A list of patterns of which files not to pull. 
+ + """ + if not model_path.endswith("/"): + model_path = model_path + "/" + runai_pull_files(model_path, self.dir, allow_pattern, ignore_pattern) diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/transformers_utils/s3_utils.py index f95aae7815e0b..d17c1afe9b504 100644 --- a/vllm/transformers_utils/s3_utils.py +++ b/vllm/transformers_utils/s3_utils.py @@ -2,12 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import fnmatch -import os -import shutil -import signal -import tempfile -from pathlib import Path -from typing import Optional +from typing import Any, Optional from vllm.utils import PlaceholderModule @@ -31,7 +26,7 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: ] -def glob(s3=None, +def glob(s3: Optional[Any] = None, path: str = "", allow_pattern: Optional[list[str]] = None) -> list[str]: """ @@ -56,7 +51,7 @@ def glob(s3=None, def list_files( - s3, + s3: Any, path: str, allow_pattern: Optional[list[str]] = None, ignore_pattern: Optional[list[str]] = None @@ -93,70 +88,3 @@ def list_files( paths = _filter_ignore(paths, ignore_pattern) return bucket_name, prefix, paths - - -class S3Model: - """ - A class representing a S3 model mirrored into a temporary directory. - - Attributes: - s3: S3 client. - dir: The temporary created directory. - - Methods: - pull_files(): Pull model from S3 to the temporary directory. - """ - - def __init__(self) -> None: - self.s3 = boto3.client('s3') - for sig in (signal.SIGINT, signal.SIGTERM): - existing_handler = signal.getsignal(sig) - signal.signal(sig, self._close_by_signal(existing_handler)) - - self.dir = tempfile.mkdtemp() - - def __del__(self): - self._close() - - def _close(self) -> None: - if os.path.exists(self.dir): - shutil.rmtree(self.dir) - - def _close_by_signal(self, existing_handler=None): - - def new_handler(signum, frame): - self._close() - if existing_handler: - existing_handler(signum, frame) - - return new_handler - - def pull_files(self, - s3_model_path: str = "", - allow_pattern: Optional[list[str]] = None, - ignore_pattern: Optional[list[str]] = None) -> None: - """ - Pull files from S3 storage into the temporary directory. - - Args: - s3_model_path: The S3 path of the model. - allow_pattern: A list of patterns of which files to pull. - ignore_pattern: A list of patterns of which files not to pull. 
- - """ - if not s3_model_path.endswith("/"): - s3_model_path = s3_model_path + "/" - - bucket_name, base_dir, files = list_files(self.s3, s3_model_path, - allow_pattern, - ignore_pattern) - if len(files) == 0: - return - - for file in files: - destination_file = os.path.join( - self.dir, - file.removeprefix(base_dir).lstrip("/")) - local_dir = Path(destination_file).parent - os.makedirs(local_dir, exist_ok=True) - self.s3.download_file(bucket_name, file, destination_file) diff --git a/vllm/transformers_utils/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py index ae8220f9b9dc5..6b519cccd3cc6 100644 --- a/vllm/transformers_utils/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group.py @@ -5,7 +5,8 @@ from typing import Optional from typing_extensions import assert_never -from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig +from vllm.config import ModelConfig, SchedulerConfig +from vllm.config.lora import LoRAConfig from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens, get_lora_tokenizer, diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 0fcf5d15afd1d..828536e6408b1 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -7,8 +7,10 @@ from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder, if HAS_TRITON: import triton import triton.language as tl + import triton.language.extra.libdevice as tldevice else: triton = TritonPlaceholder() tl = TritonLanguagePlaceholder() + tldevice = TritonLanguagePlaceholder() -__all__ = ["HAS_TRITON", "triton", "tl"] +__all__ = ["HAS_TRITON", "triton", "tl", "tldevice"] diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 9c78e56d580e0..f13381ecd9ff3 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -78,6 +78,7 @@ if TYPE_CHECKING: from argparse import Namespace from vllm.config import ModelConfig, VllmConfig + from vllm.sequence import IntermediateTensors logger = init_logger(__name__) @@ -162,6 +163,12 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" +MB_bytes = 1_000_000 +"""The number of bytes in one megabyte (MB).""" + +MiB_bytes = 1 << 20 +"""The number of bytes in one mebibyte (MiB).""" + GB_bytes = 1_000_000_000 """The number of bytes in one gigabyte (GB).""" @@ -1472,7 +1479,8 @@ def current_stream() -> torch.cuda.Stream: # is hurting performance. 
Therefore creating a dedicated stream # per process if current_platform.is_rocm(): - _current_stream_tls.value = torch.cuda.Stream() + # torch.cuda.set_stream here is an alias of _patched_set_stream + torch.cuda.set_stream(torch.cuda.Stream()) elif current_platform.is_cpu(): _current_stream_tls.value = _StreamPlaceholder() else: @@ -2074,6 +2082,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, return await task(*args, **kwargs) +@lru_cache def supports_kw( callable: Callable[..., object], kw_name: str, @@ -2278,7 +2287,8 @@ def weak_ref_tensor(tensor: Any) -> Any: def weak_ref_tensors( - tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] + tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor], + IntermediateTensors] ) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: """ Convenience function to create weak references to tensors, @@ -2290,6 +2300,15 @@ return [weak_ref_tensor(t) for t in tensors] if isinstance(tensors, tuple): return tuple(weak_ref_tensor(t) for t in tensors) + + # For IntermediateTensors used in pipeline parallelism + from vllm.sequence import IntermediateTensors + if isinstance(tensors, IntermediateTensors): + ret = IntermediateTensors({ + key: weak_ref_tensor(val) + for key, val in tensors.tensors.items() + }) + return ret raise ValueError("Invalid type for tensors") @@ -2779,7 +2798,10 @@ def memory_profiling( result.torch_peak_increase = diff_profile.torch_peak result.non_torch_increase = diff_from_create.non_torch_memory result.profile_time = diff_profile.timestamp - result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + non_torch_memory = result.non_torch_increase + peak_activation_memory = result.torch_peak_increase + result.non_kv_cache_memory = non_torch_memory + peak_activation_memory + result.weights_memory # noqa # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 @@ -3249,7 +3271,7 @@ def check_use_alibi(model_config: ModelConfig) -> bool: and getattr(cfg.attn_config, "alibi", False))))) -def sha256(input) -> int: +def sha256(input) -> bytes: """Hash any picklable Python object using SHA-256. The input is serialized using pickle before hashing, which allows @@ -3260,16 +3282,15 @@ input: Any picklable Python object. Returns: - An integer representing the SHA-256 hash of the serialized input. + Bytes representing the SHA-256 hash of the serialized input. """ input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - return int.from_bytes(hashlib.sha256(input_bytes).digest(), - byteorder="big") + return hashlib.sha256(input_bytes).digest() -def sha256_cbor_64bit(input) -> int: +def sha256_cbor(input) -> bytes: """ - Hash objects using CBOR serialization and SHA-256, then truncate to 64bits. + Hash objects using CBOR serialization and SHA-256. This option is useful for non-Python-dependent serialization and hashing. @@ -3280,17 +3301,13 @@ Custom classes must implement CBOR serialization methods. Returns: - An integer in the range [0, 2^64-1] representing the lower 64 bits - of the SHA-256 hash of the CBOR serialized input. + Bytes representing the SHA-256 hash of the CBOR serialized input.
""" input_bytes = cbor2.dumps(input, canonical=True) - full_hash = int.from_bytes(hashlib.sha256(input_bytes).digest(), - byteorder="big") - - return full_hash & ((1 << 64) - 1) + return hashlib.sha256(input_bytes).digest() -def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]: +def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: """Get a hash function by name, or raise an error if the function is not found. Args: @@ -3300,10 +3317,8 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]: """ if hash_fn_name == "sha256": return sha256 - if hash_fn_name == "sha256_cbor_64bit": - return sha256_cbor_64bit - if hash_fn_name == "builtin": - return hash + if hash_fn_name == "sha256_cbor": + return sha256_cbor raise ValueError(f"Unsupported hash function: {hash_fn_name}") @@ -3366,7 +3381,7 @@ def has_triton_kernels() -> bool: def set_process_title(name: str, suffix: str = "", - append: bool = False) -> None: + prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None: """ Set the current process title to a specific name with an optional suffix. @@ -3374,15 +3389,11 @@ def set_process_title(name: str, Args: name: The title to assign to the current process. suffix: An optional suffix to append to the base name. - append: Whether to append to the existing process title. + prefix: A prefix to prepend to the front separated by `::`. """ if suffix: name = f"{name}_{suffix}" - if append: - name = f"{setproctitle.getproctitle()}_{name}" - else: - name = f"{envs.VLLM_PROCESS_NAME_PREFIX}::{name}" - setproctitle.setproctitle(name) + setproctitle.setproctitle(f"{prefix}::{name}") def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index fab134733d4fd..83ec65c9b4594 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -200,11 +200,6 @@ def use_trtllm_attention( logger.info_once("Using TRTLLM attention (query is quantized).") return True - # TRTLLM prefill attention does not support FP8 kv cache with - # non-quantized query - if is_prefill and kv_cache_dtype.startswith("fp8"): - return False - # If sinks are being used, we must use TRTLLM attention as it's # the only backend that supports them if has_sinks: @@ -353,6 +348,12 @@ def flashinfer_scaled_fp8_mm( return output +@functools.cache +def flashinfer_disable_q_quantization() -> bool: + """Cache result which only depends on the environment""" + return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION + + __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index ced8234a7b433..6627164c98798 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -317,8 +317,8 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device) -> None: - self.kv_cache_spec = kv_cache_spec - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.scheduler_config = vllm_config.scheduler_config # For reorder @@ -641,10 +641,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): attn_metadata: TorchSDPAMetadata, attn_type: str = AttentionType.DECODER, ) -> None: - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = 
value.repeat_interleave(self.num_queries_per_kv, dim=1) - attn_masks = attn_metadata.get_attn_bias(attn_type) if attn_masks is None: if self.alibi_slopes is not None: @@ -665,6 +661,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-3) + value = value.repeat_interleave(self.num_queries_per_kv, dim=-3) + causal_attn = (attn_type == AttentionType.DECODER) seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 3cc67acd04c6b..20f1904b3be6f 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -177,12 +177,11 @@ class FlashAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config self.compilation_config = vllm_config.compilation_config - self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 06a853007a578..98a4cf38bc195 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import (supports_trtllm_attention, +from vllm.utils.flashinfer import (flashinfer_disable_q_quantization, + supports_trtllm_attention, use_trtllm_attention) from vllm.v1.attention.backends.flash_attn import use_cascade_attention # yapf conflicts with isort for this block @@ -48,8 +49,89 @@ FP4_DTYPE = torch.uint8 logger = init_logger(__name__) -class FlashInferBackend(AttentionBackend): +@triton.jit +def _trtllm_prefill_attn_kvfp8_dequant( + kv_cache_ptr, + block_tables_prefill_ptr, + block_table_stride, + mock_kv_cache_ptr, + k_scale_ptr, + v_scale_ptr, + K_CACHE_STRIDE: tl.constexpr, + KV_CACHE_STRIDE: tl.constexpr, +): + batch_idx = tl.program_id(0).to(tl.int64) + mock_block_table_idx = tl.program_id(1).to(tl.int64) + orig_page_num = tl.load(block_tables_prefill_ptr + + batch_idx * block_table_stride + + mock_block_table_idx).to(tl.int64) + if orig_page_num <= 0: + return + dequant_dtype = mock_kv_cache_ptr.dtype.element_ty + # Dequantize K + k_scale_val = tl.load(k_scale_ptr) + offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + fp8_vals = tl.load(kv_cache_ptr + offset) + dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val + mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx + + 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + dequantized_vals = dequantized_vals.to(dequant_dtype) + tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) + + # Dequantize V + v_scale_val = tl.load(v_scale_ptr) + offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + + tl.arange(0, K_CACHE_STRIDE)) + fp8_vals = tl.load(kv_cache_ptr + offset) + 
dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val + mock_cache_offset = ( + (batch_idx * block_table_stride + mock_block_table_idx + 1) * + KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)) + dequantized_vals = dequantized_vals.to(dequant_dtype) + tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) + + +def trtllm_prefill_attn_kvfp8_dequant( + kv_cache: torch.Tensor, + block_tables_prefill: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + dequant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, num_of_page_per_token = block_tables_prefill.shape + s = kv_cache.shape + assert s[1] == 2 + assert dequant_dtype in (torch.bfloat16, torch.float16) + k_cache_stride = s[2] * s[3] * s[4] + kv_cache_stride = k_cache_stride * s[1] + new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4]) + # mock kv cache contains just the pages needed by this prefill + mock_kv_cache = torch.empty(new_s, + dtype=dequant_dtype, + device=kv_cache.device) + # we simply sequentially index the pages needed by this prefill + mock_block_table = torch.arange( + start=1, + end=batch_size * num_of_page_per_token + 1, + dtype=torch.int32, + device=block_tables_prefill.device, + ).reshape(batch_size, num_of_page_per_token) + grid = (batch_size, num_of_page_per_token) + _trtllm_prefill_attn_kvfp8_dequant[grid]( + kv_cache, + block_tables_prefill, + num_of_page_per_token, + mock_kv_cache, + k_scale, + v_scale, + k_cache_stride, + kv_cache_stride, + ) + return mock_kv_cache, mock_block_table + + +class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True @classmethod @@ -122,7 +204,6 @@ class FlashInferBackend(AttentionBackend): @dataclass class FlashInferMetadata: - num_actual_tokens: int # Number of tokens excluding padding. # The data type of the query @@ -163,11 +244,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.device = device - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.cache_config = vllm_config.cache_config self.model_config = vllm_config.model_config - self.kv_cache_spec = kv_cache_spec self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) @@ -177,8 +256,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.kv_cache_spec.block_size) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req - self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\ - decode_mode() == CUDAGraphMode.FULL + self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\ + decode_mode() == CUDAGraphMode.FULL) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. 
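The Triton kernel above is easier to follow next to a plain-PyTorch reference. The sketch below is our own illustration (the name `mock_kv_cache_reference` is ours, with Python loops standing in for the launch grid); it assumes a KV cache of shape `[num_pages, 2, page_size, num_kv_heads, head_dim]` (matching the `s[1] == 2` assert in the wrapper) stored in a dtype PyTorch can cast, e.g. `torch.float8_e4m3fn`. It builds the same sequentially indexed mock cache and block table, dequantizing only the pages this prefill touches; slot 0 is left as padding, mirroring the kernel's `orig_page_num <= 0` early exit.

```python
import torch


def mock_kv_cache_reference(
    kv_cache: torch.Tensor,      # [num_pages, 2, page_size, n_kv_heads, head_dim]
    block_tables: torch.Tensor,  # [batch, pages_per_req], int
    k_scale: torch.Tensor,       # scalar tensor
    v_scale: torch.Tensor,       # scalar tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    batch, pages_per_req = block_tables.shape
    # Slot 0 is reserved as padding, matching the kernel's early exit.
    mock = torch.zeros((batch * pages_per_req + 1, *kv_cache.shape[1:]),
                       dtype=torch.bfloat16, device=kv_cache.device)
    # Pages are re-indexed sequentially: page i of request b lands at
    # slot b * pages_per_req + i + 1, so the mock block table is just
    # an arange, exactly as in trtllm_prefill_attn_kvfp8_dequant.
    mock_table = torch.arange(
        1, batch * pages_per_req + 1,
        dtype=torch.int32, device=block_tables.device,
    ).reshape(batch, pages_per_req)
    for b in range(batch):
        for i in range(pages_per_req):
            page = int(block_tables[b, i])
            if page <= 0:
                continue  # padding entry in the real block table
            slot = b * pages_per_req + i + 1
            # Dequantize K and V with their separate scales.
            mock[slot, 0] = (kv_cache[page, 0].float() * k_scale).to(torch.bfloat16)
            mock[slot, 1] = (kv_cache[page, 1].float() * v_scale).to(torch.bfloat16)
    return mock, mock_table
```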
@@ -194,20 +273,21 @@ FlashInferBackend.validate_head_size(self.head_dim) self.page_size = self.kv_cache_spec.block_size - self.enable_fusion = ( - self.compilation_config.pass_config.enable_attn_fusion) - self.q_data_type = self.model_config.dtype self.cache_dtype = self.cache_config.cache_dtype if self.cache_dtype.startswith("fp8"): self.kv_cache_dtype = ( FlashInferBackend.get_fp8_dtype_for_flashinfer( self.cache_dtype)) - # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled - if self.enable_fusion: - self.q_data_type = self.kv_cache_dtype else: + assert self.kv_cache_spec.dtype == self.model_config.dtype self.kv_cache_dtype = self.kv_cache_spec.dtype + if supports_trtllm_attention()[0] and \ + not flashinfer_disable_q_quantization(): + self.q_data_type = self.kv_cache_dtype + else: + self.q_data_type = self.model_config.dtype + self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers @@ -218,7 +298,11 @@ self.window_left = self.global_hyperparameters.window_left self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap self.has_sinks = self.global_hyperparameters.has_sinks - + if self.has_sinks and not supports_trtllm_attention()[0]: + raise NotImplementedError( + "FlashInfer backend currently does not support attention " + "sinks; please use TRTLLM on Blackwell or FlashAttention on " + "earlier GPUs.") # Preparing persistent buffers (device-side) self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, dtype=torch.int32, @@ -410,7 +494,11 @@ self.q_data_type, is_prefill=False, has_sinks=self.has_sinks) - + if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm): + raise NotImplementedError( + "FlashInfer backend currently does not support attention " + "sinks; please use TRTLLM on Blackwell or FlashAttention on " + "earlier GPUs.") attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, q_data_type=self.q_data_type, @@ -543,22 +631,6 @@ ) return attn_metadata - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): - """ - This method builds the metadata for full cudagraph capture. - Currently, only decode is supported for full cudagraphs with FlashInfer. - """ - m = common_attn_metadata - - assert m.num_reqs == m.num_actual_tokens, \ - "FlashInfer only supports decode-only full CUDAGraph capture. " \ - "Make sure all cudagraph capture sizes <= max_num_seq." - - m.max_query_len = 1 # decode-only - - return self.build(0, m) - def use_cascade_attention(self, *args, **kwargs) -> bool: if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: # TODO: The cascade wrapper currently does not support setting @@ -668,8 +740,6 @@ # The attn+quant fusion happens when output_scale is provided. if output_scale is None: - assert attn_metadata.q_data_type != FP8_DTYPE, \ - "Query can only be FP8 if output fusion happened."
assert output_block_scale is None, "output_block_scale "\ "is not supported when fusion has not happened" else: @@ -697,7 +767,8 @@ class FlashInferImpl(AttentionImpl): elif output.dtype == FP4_DTYPE: self.o_sf_scale = layer._o_scale_float - # Insert FP8 quant for query + # Insert FP8 quant for query + if attn_metadata.q_data_type == FP8_DTYPE: num_tokens, num_heads, head_size = query.shape query, _ = ops.scaled_fp8_quant( query.reshape( @@ -806,11 +877,29 @@ class FlashInferImpl(AttentionImpl): assert self.o_sf_scale is None out = output[num_decode_tokens:] + if attn_metadata.q_data_type != FP8_DTYPE \ + and self.kv_cache_dtype.startswith("fp8"): + # TRTLLM prefill attention does not support BF16 Q + # and fp8 kv cache. So to enable prefill attention + # with fp8 kv cache, we can construct a mock block + # and mock kv cache with BF16 KV involved in the prefill + mock_kv_cache, mock_block_table = ( + trtllm_prefill_attn_kvfp8_dequant( + kv_cache_permute, + block_tables_prefill, + layer._k_scale, + layer._v_scale, + attn_metadata.q_data_type, + )) + else: + mock_kv_cache = kv_cache_permute + mock_block_table = block_tables_prefill + trtllm_batch_context_with_kv_cache( query=prefill_query, - kv_cache=kv_cache_permute, + kv_cache=mock_kv_cache, workspace_buffer=workspace_buffer, - block_tables=block_tables_prefill, + block_tables=mock_block_table, seq_lens=seq_lens_prefill, max_q_len=attn_metadata.max_q_len, max_kv_len=attn_metadata.max_seq_len, @@ -848,7 +937,7 @@ class FlashInferImpl(AttentionImpl): decode_query = decode_query.contiguous() workspace_buffer = decode_wrapper._float_workspace_buffer block_tables_decode = attn_metadata.\ - block_table_tensor[:num_decode_tokens] + block_table_tensor[:num_decode_tokens] seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index d5b1c15e68d0e..cb983494216a7 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -516,10 +516,11 @@ class FlexAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config - self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config) diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py new file mode 100644 index 0000000000000..74eb9ae9d3254 --- /dev/null +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -0,0 +1,320 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Backend for GatedDeltaNet attention.""" +from dataclasses import dataclass +from typing import ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class GDNAttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> 
type["GDNAttentionMetadataBuilder"]: + return GDNAttentionMetadataBuilder + + +@dataclass +class GDNAttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + num_spec_decodes: int + num_spec_decode_tokens: int + + has_initial_state: Optional[torch.Tensor] = None + + spec_query_start_loc: Optional[ + torch.Tensor] = None # shape: [num_spec_decodes + 1,] + non_spec_query_start_loc: Optional[ + torch.Tensor] = None # shape: [batch - num_spec_decodes + 1,] + + spec_state_indices_tensor: Optional[ + torch.Tensor] = None # shape: [batch, num_spec] + non_spec_state_indices_tensor: Optional[ + torch.Tensor] = None # shape: [batch - num_spec_decodes,] + spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,] + spec_token_masks: Optional[ + torch. + Tensor] = None # shape: [num_prefill_tokens + num_decode_tokens,] + num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,] + + +class GDNAttentionMetadataBuilder( + AttentionMetadataBuilder[GDNAttentionMetadata]): + + cudagraph_support = AttentionCGSupport.UNIFORM_BATCH + + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.speculative_config = vllm_config.speculative_config + self.kv_cache_spec = kv_cache_spec + if self.speculative_config: + self.num_spec = self.speculative_config.num_speculative_tokens # noqa: E501 + else: + self.num_spec = 0 + self.use_spec_decode = self.num_spec > 0 + self.reorder_batch_threshold = self.num_spec + 1 # type: ignore[misc] + + self.use_full_cuda_graph = \ + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + self.decode_cudagraph_max_bs = min( + self.vllm_config.scheduler_config.max_num_seqs, + self.compilation_config.max_capture_size) + + self.spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, self.num_spec + 1), + dtype=torch.int32, + device=device, + ) + self.non_spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + self.spec_sequence_masks = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.bool, + device=device, + ) + self.spec_token_masks = torch.empty( + (self.decode_cudagraph_max_bs * (self.num_spec + 1), ), + dtype=torch.bool, + device=device, + ) + self.spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1, ), + dtype=torch.int32, + device=device, + ) + self.non_spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1, ), + dtype=torch.int32, + device=device, + ) + self.num_accepted_tokens = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + + def build( # type: ignore[override] + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + num_accepted_tokens: Optional[torch.Tensor] = None, + num_draft_tokens: Optional[torch.Tensor] = None, + fast_build: bool = False, + ) -> GDNAttentionMetadata: + m = common_attn_metadata + + query_start_loc = m.query_start_loc + context_lens = m.num_computed_tokens_cpu + context_lens_tensor = context_lens.to(query_start_loc.device) + seq_lens_tensor = m.seq_lens + + if (not self.use_spec_decode or num_draft_tokens is None + or num_draft_tokens.sum().item() == 0): + spec_sequence_masks = None + else: + spec_sequence_masks = 
(num_draft_tokens > 0) & ( + context_lens_tensor + + (num_draft_tokens + 1) == seq_lens_tensor) + if spec_sequence_masks.sum().item() == 0: + spec_sequence_masks = None + + if spec_sequence_masks is None: + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(m, decode_threshold=1)) + num_spec_decodes = 0 + num_spec_decode_tokens = 0 + spec_token_masks = None + spec_state_indices_tensor = None + non_spec_state_indices_tensor = m.block_table_tensor[:, 0] + spec_query_start_loc = None + non_spec_query_start_loc = query_start_loc + num_accepted_tokens = None + else: + num_spec_decodes = spec_sequence_masks.sum().item() + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + non_spec_query_lens = query_lens[~spec_sequence_masks] + num_decodes = (non_spec_query_lens == 1).sum().item() + num_prefills = non_spec_query_lens.size(0) - num_decodes + num_decode_tokens = num_decodes + num_prefill_tokens = non_spec_query_lens.sum().item( + ) - num_decode_tokens + + if num_prefills == 0 and num_decodes == 0: + spec_token_masks = torch.ones( + (min(num_spec_decodes * + (self.num_spec + 1), query_start_loc[-1].item())), + dtype=torch.bool, + device=query_start_loc.device) + spec_state_indices_tensor = m.block_table_tensor[:, :self. + num_spec + 1] + non_spec_state_indices_tensor = None + spec_query_start_loc = query_start_loc + non_spec_query_start_loc = None + else: + spec_token_masks = torch.repeat_interleave( + spec_sequence_masks, query_lens) + spec_state_indices_tensor = m.block_table_tensor[ + spec_sequence_masks, :self.num_spec + 1] + non_spec_state_indices_tensor = \ + m.block_table_tensor[~spec_sequence_masks, 0] + + spec_query_start_loc = torch.zeros( + num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device) + torch.cumsum(query_lens[spec_sequence_masks], + dim=0, + out=spec_query_start_loc[1:]) + non_spec_query_start_loc = torch.zeros( + query_lens.size(0) - num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device) + torch.cumsum(query_lens[~spec_sequence_masks], + dim=0, + out=non_spec_query_start_loc[1:]) + + num_spec_decode_tokens = min( + num_spec_decodes * (self.num_spec + 1), + spec_token_masks.size(0)) + assert num_accepted_tokens is not None + num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] + + if num_prefills > 0: + has_initial_state = context_lens_tensor > 0 + if spec_sequence_masks is not None: + has_initial_state = has_initial_state[~spec_sequence_masks] + else: + has_initial_state = None + + # prepare tensors for cudagraph + if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0 + and num_spec_decodes <= self.decode_cudagraph_max_bs + and m.num_actual_tokens <= self.decode_cudagraph_max_bs): + num_total_tokens = self.vllm_config.pad_for_cudagraph( + m.num_actual_tokens) + batch_size = num_total_tokens // (self.num_spec + 1) + + self.spec_state_indices_tensor[:num_spec_decodes].copy_( + spec_state_indices_tensor, non_blocking=True) + spec_state_indices_tensor = self.spec_state_indices_tensor[: + batch_size] + spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID) + + self.spec_sequence_masks[:num_spec_decodes].copy_( + spec_sequence_masks, non_blocking=True) + spec_sequence_masks = self.spec_sequence_masks[:batch_size] + spec_sequence_masks[num_spec_decodes:].fill_(False) + + assert spec_token_masks is not None + self.spec_token_masks[:spec_token_masks.size(0)].copy_( + spec_token_masks, non_blocking=True) + spec_token_masks = 
self.spec_token_masks[:m.num_actual_tokens] + spec_token_masks[spec_token_masks.size(0):].fill_(False) + + self.spec_query_start_loc[:num_spec_decodes + 1].copy_( + spec_query_start_loc, non_blocking=True) + spec_num_query_tokens = spec_query_start_loc[ + -1] # type: ignore[index] + spec_query_start_loc = self.spec_query_start_loc[:batch_size + 1] + spec_query_start_loc[num_spec_decodes + + 1:].fill_(spec_num_query_tokens) + + self.num_accepted_tokens[:num_spec_decodes].copy_( + num_accepted_tokens, non_blocking=True) + num_accepted_tokens = self.num_accepted_tokens[:batch_size] + num_accepted_tokens[num_spec_decodes:].fill_(1) + + if (self.use_full_cuda_graph and num_prefills == 0 + and num_spec_decodes == 0 + and num_decodes <= self.decode_cudagraph_max_bs): + num_total_tokens = self.vllm_config.pad_for_cudagraph( + m.num_actual_tokens) + batch_size = num_total_tokens + + self.non_spec_state_indices_tensor[:num_decodes].copy_( + non_spec_state_indices_tensor, non_blocking=True) + non_spec_state_indices_tensor = \ + self.non_spec_state_indices_tensor[:batch_size] + non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID) + + self.non_spec_query_start_loc[:num_decodes + 1].copy_( + non_spec_query_start_loc, non_blocking=True) + non_spec_num_query_tokens = non_spec_query_start_loc[ + -1] # type: ignore[index] + non_spec_query_start_loc = \ + self.non_spec_query_start_loc[:batch_size + 1] + non_spec_query_start_loc[num_decodes + + 1:].fill_(non_spec_num_query_tokens) + + attn_metadata = GDNAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_spec_decodes=num_spec_decodes, + num_spec_decode_tokens=num_spec_decode_tokens, + has_initial_state=has_initial_state, + spec_query_start_loc=spec_query_start_loc, + non_spec_query_start_loc=non_spec_query_start_loc, + spec_state_indices_tensor=spec_state_indices_tensor, + non_spec_state_indices_tensor=non_spec_state_indices_tensor, + spec_sequence_masks=spec_sequence_masks, + spec_token_masks=spec_token_masks, + num_accepted_tokens=num_accepted_tokens, + ) + return attn_metadata + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata): + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with Mamba. + """ + m = common_attn_metadata + + assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens + and ((m.num_reqs + 1) * (self.num_spec + 1) + >= m.num_actual_tokens)), \ + "GDN only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + num_accepted_tokens = torch.full((m.num_reqs, ), + m.max_query_len, + dtype=torch.int32, + device=m.query_start_loc.device) + num_drafted_tokens = torch.full((m.num_reqs, ), + self.num_spec, + dtype=torch.int32, + device=m.query_start_loc.device) + + # Fixes query-start loc for spec-sequence-indices. 
+ m.query_start_loc = torch.arange(0, + m.num_actual_tokens + 1, + step=m.max_query_len, + device=m.query_start_loc.device, + dtype=torch.int32) + m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full( + (m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu')) + + return self.build(0, m, num_accepted_tokens, num_drafted_tokens) diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index ac0034b5dcf06..3ff201d83a79b 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -39,8 +39,8 @@ class LinearAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec def build(self, common_prefix_len: int, diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index f3e6cd7430e0b..359bad1ea9dee 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -16,9 +16,58 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, from vllm.v1.kv_cache_interface import AttentionSpec -def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, - chunk_size: int, - total_seqlens: int): +def _query_start_loc_to_chunk_indices_offsets( + query_start_loc: torch.Tensor, chunk_size: int, + total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + query_start_loc (torch.Tensor): 1D tensor of cumulative sequence + lengths, shape (num_seqs + 1,). + The first element should be 0. Each entry represents the starting + index of a sequence in the flattened token array. + chunk_size (int): The size of each physical mamba chunk + (number of tokens per chunk). + total_seqlens (int): The total number of tokens in the batch. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - chunk_indices (torch.Tensor): 1D tensor of indices + indicating the physical chunk for each logical chunk. + - chunk_offsets (torch.Tensor): 1D tensor of offsets + indicating the starting index of each logical chunk within + its physical chunk. + + This function computes the chunk indices and offsets for the given + query_start_loc and chunk_size. Both are tensors of integers with length N, + where N is the number of logical (pseudo) chunks. + A logical chunk is a sequence of tokens that are all part of the same + sequence and are all in the same physical mamba chunk. + In other words, a logical chunk changes every time we cross a sequence + boundary or a physical mamba chunk boundary. + Logical chunks are needed to handle batched requests with initial states + (see _state_passing_fwd and _chunk_scan_fwd). + The chunk_indices tensor contains the index of the physical chunk for each + logical chunk. + The chunk_offsets tensor contains the offset (AKA starting index) of the + logical chunk in the physical chunk. + + Example: + query_start_loc = [0, 5, 10] + chunk_size = 8 + total_seqlens = 10 + -> chunk_indices = [0, 0, 1] + -> chunk_offsets = [0, 5, 0] + + In this example, we have 2 sequences, each with 5 tokens. The physical + chunk size is 8 tokens. 
+ We have three logical chunks: + - the first logical chunk starts at token 0 in the first physical chunk + and contains all 5 tokens from the first sequence + - the second logical chunk starts at token 5 in the first physical chunk + and contains the first 3 tokens from the second sequence + - the third logical chunk starts at token 0 in the second physical chunk + and contains the remaining 2 tokens from the second sequence + """ cu_seqlens = query_start_loc[1:] # remove prepended 0 diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 07ef7cb69a160..9970331a6042c 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -22,12 +22,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec - self.device = device - self.vllm_config = vllm_config - self.layer_names = layer_names + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + assert isinstance(kv_cache_spec, MambaSpec) self.compilation_config = vllm_config.compilation_config self.decode_cudagraph_max_bs = min( self.vllm_config.scheduler_config.max_num_seqs, @@ -52,4 +49,4 @@ m.max_query_len = 1 # decode-only - return self.build(0, m) \ No newline at end of file + return self.build(0, m) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index ec1216a16bc46..a990cb2f1a972 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -381,6 +381,7 @@ class MLACommonMetadata(Generic[D]): num_reqs: int max_query_len: int + max_seq_len: int num_actual_tokens: int # Number of tokens excluding padding. query_start_loc: torch.Tensor @@ -443,11 +444,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata self.kv_cache_spec = kv_cache_spec - self.device = device scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config + cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + self.device = device + + self.num_heads = self.model_config.get_num_attention_heads( parallel_config) self.mla_dims = get_mla_dims(self.model_config) @@ -460,7 +463,7 @@ self.dcp_world_size = 1 self.dcp_rank = 0 - # Dont try to access the runner on AMD + # Don't try to access the runner on AMD if self.aot_schedule: self.page_size = self.kv_cache_spec.block_size @@ -581,7 +584,6 @@ window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters.logits_soft_cap, q_data_type=self.model_config.dtype, - kv_data_type=self.kv_cache_spec.dtype, ) # Prepare context prefills @@ -602,16 +604,17 @@ logits_soft_cap=self._global_hyperparameters.
logits_soft_cap, + q_data_type=self.model_config.dtype, - kv_data_type=self.kv_cache_spec.dtype, ) prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def _build_decode( - self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata: + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, @@ -624,11 +627,12 @@ Currently, only decode is supported for full cudagraphs with MLA. """ m = common_attn_metadata - assert m.num_reqs == m.num_actual_tokens, \ + assert m.num_reqs <= (m.num_actual_tokens * + self.reorder_batch_threshold), \ "MLA only supports decode-only full CUDAGraph capture. " \ "Make sure all cudagraph capture sizes <= max_num_seq." - assert m.max_query_len == 1 # decode-only + assert m.max_query_len <= self.reorder_batch_threshold # decode only return self.build(0, m) @@ -639,6 +643,7 @@ num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because @@ -819,11 +824,13 @@ seq_lens_device=seq_lens[:num_decodes], query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], query_start_loc_device=query_start_loc[:num_decodes + 1], + num_decode_tokens=num_decode_tokens, ) attn_metadata = self.metadata_cls( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, + max_seq_len=max_seq_len, num_actual_tokens=num_tokens, query_start_loc=query_start_loc, slot_mapping=slot_mapping, @@ -1313,7 +1320,7 @@ k_scale: torch.Tensor, dcp_world_size: int, ): - assert k_scale is None, "DCP not support sacled kvcache now." + assert k_scale is None, "DCP does not support scaled kvcache yet."
assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill assert prefill_metadata.chunked_context is not None diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 95dce8d8e2eef..6017445402eca 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -76,6 +76,7 @@ g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): + can_return_lse_for_decode: bool = True def __init__( self, @@ -138,7 +139,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): workspace: torch.Tensor, sm_scale: float, num_kv_splits: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: assert (q_nope.ndim == 3 ), f"q_nope must be a 3D tensor, but got {q_nope.ndim}" assert ( @@ -193,9 +194,13 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype) else q_nope.dtype) out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype) + lse = (torch.empty( + (B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device) + if self.need_to_return_lse_for_decode else torch.Tensor()) ops.sm100_cutlass_mla_decode( out, + lse, q_nope, q_pe, kv_c_and_k_pe_cache, @@ -205,7 +210,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): sm_scale, num_kv_splits, ) - return out[:, :H].contiguous() + returned_lse = lse[:, :H].contiguous( + ) if self.need_to_return_lse_for_decode else lse + return out[:, :H].contiguous(), returned_lse def _sm100_forward_decode( self, @@ -213,7 +220,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -226,13 +233,18 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): q_nope = q_nope.clone() q_pe = q_pe.clone() - o = self._sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata.decode.seq_lens, - attn_metadata.decode.block_table, - self._workspace.get_buf(), - self.scale, self._num_kv_splits) + o, lse = self._sm100_cutlass_mla_decode( + q_nope, + q_pe, + kv_c_and_k_pe_cache, + attn_metadata.decode.seq_lens, + attn_metadata.decode.block_table, + self._workspace.get_buf(), + self.scale, + self._num_kv_splits, + ) - return o + return o, (lse if self.need_to_return_lse_for_decode else None) # TODO: Currently we leave it here only for backup in case something is # wrong with the new SM100 CUTLASS MLA kernel @@ -286,4 +298,4 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): attn_metadata), None return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata), None + attn_metadata) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index e2a63c2f577e0..472095e13615b 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -11,17 +11,23 @@ from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, get_flash_attn_version) from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, 
MLACommonDecodeMetadata, MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata logger = init_logger(__name__) +# NOTE(matt): This is an arbitrary number, copied from +# woosuk's implementation in standard FlashAttention backend +_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16 + class FlashAttnMLABackend(MLACommonBackend): @@ -48,6 +54,7 @@ class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): max_query_len: int max_seq_len: int scheduler_metadata: Optional[torch.Tensor] = None + max_num_splits: int = 0 @dataclass @@ -57,14 +64,46 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): class FlashAttnMLAMetadataBuilder( MLACommonMetadataBuilder[FlashAttnMLAMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH + reorder_batch_threshold: ClassVar[int] = 512 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): super().__init__(kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata) + self.max_num_splits = 0 # No upper bound on the number of splits. self.fa_aot_schedule = (get_flash_attn_version() == 3) + self.use_full_cuda_graph = \ + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + + if self.use_full_cuda_graph and self.fa_aot_schedule: + self.max_cudagraph_size = self.compilation_config.max_capture_size + + if self.max_cudagraph_size > 992: + # This condition derives from FA3's internal heuristic. + # TODO(woosuk): Support larger cudagraph sizes. + raise ValueError( + "Capture size larger than 992 is not supported for " + "full cuda graph.") + + self.scheduler_metadata = torch.zeros( + vllm_config.scheduler_config.max_num_seqs + 1, + dtype=torch.int32, + device=self.device, + ) + # When using cuda graph, we need to set the upper bound of the + # number of splits so that large enough intermediate buffers are + # pre-allocated during capture. + self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + + # TODO(lucas): Until we add support for the DCP custom masking we need + # to restrict decodes to q_len == 1 when DCP is enabled. 
+ self.__class__.reorder_batch_threshold = 1 \ + if get_dcp_group().world_size > 1 else self.reorder_batch_threshold + def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): if self.fa_aot_schedule: @@ -81,14 +120,16 @@ class FlashAttnMLAMetadataBuilder( page_size=self.page_size, cu_seqlens_q=cu_query_lens, causal=causal, + num_splits=self.max_num_splits, ) return None - def _build_decode( - self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor - ) -> FlashAttnMLADecodeMetadata: + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> FlashAttnMLADecodeMetadata: query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) max_query_len = query_lens_cpu.max().item() max_seq_len = seq_lens_cpu.max().item() @@ -102,6 +143,29 @@ class FlashAttnMLAMetadataBuilder( causal=True, ) + # For FA3 + full cudagraph + max_num_splits = 0 + if self.use_full_cuda_graph and scheduler_metadata is not None: + n = scheduler_metadata.shape[0] + # Ensure the persistent buffer is large enough + assert n <= self.scheduler_metadata.shape[0], \ + f"Scheduler metadata size {n} exceeds buffer size " + \ + f"{self.scheduler_metadata.shape[0]}" + self.scheduler_metadata[:n] = scheduler_metadata + # NOTE(woosuk): We should zero out the rest of the scheduler + # metadata to guarantee the correctness. Otherwise, some thread + # blocks may use the invalid scheduler metadata and overwrite the + # output buffer. + self.scheduler_metadata[n:] = 0 + scheduler_metadata = self.scheduler_metadata[:n] + + if num_decode_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. + max_num_splits = self.max_num_splits + return FlashAttnMLADecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, @@ -109,10 +173,12 @@ class FlashAttnMLAMetadataBuilder( max_query_len=max_query_len, max_seq_len=max_seq_len, scheduler_metadata=scheduler_metadata, + max_num_splits=max_num_splits, ) class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): + can_return_lse_for_decode: bool = True def __init__( self, @@ -175,20 +241,33 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] - o = flash_attn_varlen_func( + # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the + # kernel uses this to calculate grid dimensions. Ensure it's at least 1 + # to prevent invalid grid configuration during graph capture. 
+ max_seqlen_q = max(attn_metadata.decode.max_query_len, 1) + + attn_out = flash_attn_varlen_func( q=q_pe, k=k_pe_cache.unsqueeze(-2), # Add head dim of 1 v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 q_v=q_nope, - max_seqlen_q=attn_metadata.decode.max_query_len, + max_seqlen_q=max_seqlen_q, cu_seqlens_q=attn_metadata.decode.query_start_loc, max_seqlen_k=attn_metadata.decode.max_seq_len, seqused_k=attn_metadata.decode.seq_lens, block_table=attn_metadata.decode.block_table, softmax_scale=self.scale, causal=True, + return_softmax_lse=self.need_to_return_lse_for_decode, fa_version=3, # only version 3 is supported scheduler_metadata=attn_metadata.decode.scheduler_metadata, + num_splits=attn_metadata.decode.max_num_splits, ) - return self._v_up_proj(o) + if self.need_to_return_lse_for_decode: + o, lse = attn_out + # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ] + return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ] + else: + o = attn_out + return o, None diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py new file mode 100644 index 0000000000000..701248670f72e --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +import torch +from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla + +from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) + +logger = init_logger(__name__) + +FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 + + +class FlashInferMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHINFER_MLA" + + @staticmethod + def get_impl_cls() -> type["FlashInferMLAImpl"]: + return FlashInferMLAImpl + + +g_fi_workspace = torch.zeros( + FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device="cuda", +) + + +class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashInferMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferMLAImpl") + + self._workspace_buffer = g_fi_workspace + self.bmm1_scale: Optional[float] = None + self.bmm2_scale: Optional[float] = None + + def _forward_decode( + self, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert kv_c_and_k_pe_cache.numel() 
> 0 + assert attn_metadata.decode is not None + + if isinstance(q, tuple): + q_nope, q_pe = q + q = torch.cat([q_nope, q_pe], dim=-1) + + # trtllm API requires extra dimension q_len_per_request for MTP + q = q.unsqueeze(1) + + if self.bmm1_scale is None: + self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * + self.scale) + if self.bmm2_scale is None: + self.bmm2_scale = layer._v_scale_float + + o = trtllm_batch_decode_with_kv_cache_mla( + query=q, + kv_cache=kv_c_and_k_pe_cache.unsqueeze(1), + workspace_buffer=self._workspace_buffer, + qk_nope_head_dim=self.qk_nope_head_dim, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim, + block_tables=attn_metadata.decode.block_table, + seq_lens=attn_metadata.decode.seq_lens, + max_seq_len=attn_metadata.max_seq_len, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, + ) + + # TODO: Return LSE pending support from Flashinfer API: + # https://github.com/flashinfer-ai/flashinfer/pull/1566 + return o, None diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 1824bbadb6a1a..549af1a062252 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, is_flashmla_supported) from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.platforms.cuda import CudaPlatform from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonDecodeMetadata, MLACommonImpl, @@ -62,7 +63,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): super().__init__(kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata) - self.compilation_config = vllm_config.compilation_config self.num_q_heads = vllm_config.model_config.get_num_attention_heads( vllm_config.parallel_config) @@ -85,10 +85,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): device=self.device, dtype=torch.int32) - def _build_decode( - self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata: + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( seq_lens_device, @@ -157,6 +159,16 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): assert is_flashmla_supported(), \ "FlashMLA is not supported on this device" + # disallow FlashMLA on NVIDIA Blackwell (SM 10.0+) GPUs + # context: + # https://github.com/deepseek-ai/FlashMLA/issues/83 + # https://github.com/vllm-project/vllm/issues/24513 + if CudaPlatform.has_device_capability(100): + raise NotImplementedError( + "FlashMLA is temporarily disabled on Blackwell (SM 10.0). " + "Please use CUTLASS_MLA or TRITON_MLA instead. 
" + "Example: `export VLLM_ATTENTION_BACKEND=CUTLASS_MLA`") + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index fc6b1998e8eb0..db27a34d8959a 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -104,10 +104,12 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): dtype=torch.int32, device=device) - def _build_decode( - self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata: + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens_device + page_size - 1) // page_size device = self.device diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 173a0a255e491..a4e2758bd311f 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -236,11 +236,11 @@ class AiterFlashAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config - self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config) @@ -248,7 +248,6 @@ class AiterFlashAttentionMetadataBuilder( self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. 
self.aot_sliding_window: Optional[tuple[int, int]] = None diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index fcbf0c7b53560..f5ad65b02b4d4 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -45,8 +45,8 @@ class ShortConvAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec def build(self, common_prefix_len: int, diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index d0b163fc9bed2..6c7feab57be83 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -160,7 +160,8 @@ class TreeAttentionMetadataBuilder( vllm_config: VllmConfig, device: torch.device, ): - self.kv_cache_spec = kv_cache_spec + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.block_size = kv_cache_spec.block_size spec_config = vllm_config.speculative_config diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 104cebb45d740..c294a5a73cbdd 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -15,6 +15,8 @@ from vllm.attention.ops.chunked_prefill_paged_decode import ( from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym) from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import (AttentionCGSupport, @@ -66,9 +68,9 @@ class TritonAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.device = device + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( @@ -202,6 +204,9 @@ def use_aiter_unified_attention() -> bool: class TritonAttentionImpl(AttentionImpl): + def fused_output_quant_supported(self, quant_key: QuantKey): + return quant_key == kFp8StaticTensorSym + def __init__( self, num_heads: int, @@ -297,9 +302,9 @@ class TritonAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." 
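For context on the fused-output-quantization hook introduced for TritonAttentionImpl above: fused_output_quant_supported returning true for kFp8StaticTensorSym means the Triton kernel may apply a static per-tensor scale on its store path instead of requiring a separate quantization pass. Below is a rough eager-mode emulation of that contract, a sketch rather than the kernel itself (the clamp bound is the FP8-E4M3 finite maximum):

from typing import Optional

import torch


def emulate_fused_fp8_output(attn_out: torch.Tensor,
                             output_scale: Optional[torch.Tensor]):
    # Unfused path: the kernel writes the plain bf16/fp32 output.
    if output_scale is None:
        return attn_out
    # Static, per-tensor, symmetric FP8: divide by the scale and saturate
    # to the E4M3 representable range before converting the dtype.
    q = (attn_out / output_scale).clamp(min=-448.0, max=448.0)
    return q.to(torch.float8_e4m3fn)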
- if output_scale is not None or output_block_scale is not None: + if output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" + "fused block_scale output quantization is not yet supported" " for TritonAttentionImpl") if attn_metadata is None: @@ -394,6 +399,7 @@ class TritonAttentionImpl(AttentionImpl): alibi_slopes=self.alibi_slopes, sliding_window=self.sliding_window[0], sm_scale=self.scale, + output_scale=output_scale, sinks=self.sinks, ) @@ -419,6 +425,6 @@ class TritonAttentionImpl(AttentionImpl): k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), sinks=self.sinks, - ) + output_scale=output_scale) return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 79fcd928393f9..c1814b4ba27cc 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -72,6 +72,9 @@ class CommonAttentionMetadata: logits_indices_padded: Optional[torch.Tensor] = None num_logits_indices: Optional[int] = None + # Needed by CrossAttentionBuilder + encoder_seq_lens: Optional[np.ndarray] = None + @dataclass class UbatchSlice: @@ -193,6 +196,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.kv_cache_spec = kv_cache_spec + self.layer_names = layer_names + self.vllm_config = vllm_config + self.device = device @abstractmethod def build(self, @@ -542,7 +548,14 @@ def make_local_attention_virtual_batches( 1) batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), local_blocks * pages_per_local_batch) - block_table_local = block_table[batch_indices, block_indices]\ + + # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance + # regression when using numpy arrays (batch and block indices) to index into + # torch tensor (block_table). As a workaround, convert numpy arrays to torch + # tensor first, which recovers perf. 
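To make the NOTE above concrete before the actual workaround below, here is a standalone sketch; the shapes are made up, and the regression claim comes from the linked PyTorch PR rather than being re-measured here:

import numpy as np
import torch

block_table = torch.arange(12, dtype=torch.int32).reshape(3, 4)
batch_indices = np.array([0, 1, 2], dtype=np.int32)
block_indices = np.array([1, 2, 3], dtype=np.int32)

# Regressed path: advanced indexing directly with numpy arrays pays a
# conversion cost inside every indexing call.
slow = block_table[batch_indices, block_indices]

# Workaround: convert the index arrays to torch tensors once, up front.
fast = block_table[torch.from_numpy(batch_indices),
                   torch.from_numpy(block_indices)]

assert torch.equal(slow, fast)  # identical result, cheaper dispatch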
+ batch_indices_torch = torch.from_numpy(batch_indices) + block_indices_torch = torch.from_numpy(block_indices) + block_table_local = block_table[batch_indices_torch, block_indices_torch]\ .view(virtual_batches, -1) query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index eb738e1120e84..371220a9f72b7 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -202,8 +202,9 @@ class XFormersAttentionMetadataBuilder( vllm_config: VllmConfig, device: torch.device, ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + assert XFORMERS_AVAILABLE - self.kv_cache_spec = kv_cache_spec self.block_size = kv_cache_spec.block_size self._num_decodes = 0 self._num_decode_tokens = 0 diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index b537cac8e1d72..d1e1c1c8d0382 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -9,7 +9,11 @@ from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - FreeKVCacheBlockQueue, KVCacheBlock) + ExternalBlockHash, + FreeKVCacheBlockQueue, KVCacheBlock, + get_block_hash, + make_block_hash_with_group_id, + maybe_convert_block_hash) from vllm.v1.request import Request logger = init_logger(__name__) @@ -84,8 +88,10 @@ class BlockPool: """ cached_blocks = [] for group_id in kv_cache_group_ids: + block_hash_with_group_id = make_block_hash_with_group_id( + block_hash, group_id) cached_blocks_one_group = self.cached_block_hash_to_block.get( - BlockHashWithGroupId(block_hash, group_id)) + block_hash_with_group_id) if not cached_blocks_one_group: return None first_block = next(iter(cached_blocks_one_group.values())) @@ -124,28 +130,29 @@ class BlockPool: assert len(request.block_hashes) >= num_full_blocks new_block_hashes = request.block_hashes[num_cached_blocks:] - new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events - else None) + new_hashes: Optional[list[ExternalBlockHash]] = ( + [] if self.enable_kv_cache_events else None) for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None block_hash = new_block_hashes[i] # Update and added the full block to the cache. - block_hash_with_group_id = BlockHashWithGroupId( + block_hash_with_group_id = make_block_hash_with_group_id( block_hash, kv_cache_group_id) blk.block_hash = block_hash_with_group_id self.cached_block_hash_to_block[block_hash_with_group_id][ blk.block_id] = blk if new_hashes is not None: - new_hashes.append(block_hash.hash_value) + new_hashes.append(maybe_convert_block_hash(block_hash)) if self.enable_kv_cache_events: if num_cached_blocks == 0: - parent_block_hash = None + parent_block_hash: Optional[ExternalBlockHash] = None else: parent_block = blocks[num_cached_blocks - 1] assert parent_block.block_hash is not None - parent_block_hash = parent_block.block_hash.get_hash_value() + parent_block_hash = maybe_convert_block_hash( + get_block_hash(parent_block.block_hash)) self.kv_event_queue.append( BlockStored( @@ -220,7 +227,9 @@ class BlockPool: # we disable hybrid kv cache manager when kv cache event is # enabled, so there is only one group. 
self.kv_event_queue.append( - BlockRemoved(block_hashes=[block_hash.get_hash_value()], + BlockRemoved(block_hashes=[ + maybe_convert_block_hash(get_block_hash(block_hash)) + ], medium=MEDIUM_GPU)) return True diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index bd2ec036834b2..eadea15a2e5e3 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -86,7 +86,7 @@ class EncoderCacheManager: Returns: True if the encoder output for this input is already cached """ - mm_hash = request.mm_hashes[input_id] + mm_hash = request.mm_features[input_id].identifier # Not cached at all if mm_hash not in self.cached: return False @@ -167,7 +167,7 @@ class EncoderCacheManager: This method assumes can_allocate() returned True for the same input. """ - mm_hash = request.mm_hashes[input_id] + mm_hash = request.mm_features[input_id].identifier request_id = request.request_id if mm_hash not in self.cached: self.cached[mm_hash] = set() @@ -193,8 +193,8 @@ class EncoderCacheManager: """ return { input_id - for input_id in range(len(request.mm_hashes)) - if request.mm_hashes[input_id] in self.cached + for input_id in range(len(request.mm_features)) + if request.mm_features[input_id].identifier in self.cached } def free_encoder_input(self, request: Request, input_id: int) -> None: @@ -208,7 +208,7 @@ class EncoderCacheManager: `can_allocate`). """ req_id = request.request_id - mm_hash = request.mm_hashes[input_id] + mm_hash = request.mm_features[input_id].identifier # The mm_hash not in cache or the req_id set is empty if not self.cached.get(mm_hash, None): return diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index aff1183e499a4..f939da8c5b5c3 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -6,11 +6,12 @@ import os from collections import defaultdict, deque from collections.abc import Iterable, Sequence from dataclasses import astuple, dataclass -from typing import Any, Callable, NamedTuple, Optional +from typing import Any, Callable, NewType, Optional, Union +from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit +from vllm.utils import GiB_bytes, cdiv, sha256_cbor from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, @@ -18,59 +19,78 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request -logger = init_logger(__name__) +# BlockHash represents the hash of a single KV-cache block used for +# prefix caching. Treating it as a distinct type from ``bytes`` helps +# catch accidental misuse when passing around raw byte strings. +BlockHash = NewType("BlockHash", bytes) + +# ``BlockHashWithGroupId`` combines a ``BlockHash`` with its KV cache group ID. +# It is represented as raw bytes for compactness and efficiency. The helper +# functions below pack/unpack the ``BlockHash`` and group id into/from the key. +BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes) + +# ExternalBlockHash is used for reproducible prefix-cache block hashing. +# It's a union of ``bytes`` and ``int`` to keep backward compatibility +# after we default block hashing to use sha256 bytes. 
+ExternalBlockHash = Union[bytes, int]
 
 
-class BlockHash(NamedTuple):
-    """Hash value of a block (int), the token IDs in the block, and extra keys.
-    We keep a tuple of token IDs and extra keys to reduce the likelihood of
-    hash collisions when the hash value is the same. By using SHA256 however,
-    hash collisions are practically impossible.
+def make_block_hash_with_group_id(block_hash: BlockHash,
+                                  group_id: int) -> BlockHashWithGroupId:
+    """Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``.
+
+    The group id is encoded using 4 bytes in big-endian order and appended to
+    the block hash bytes. This representation avoids creating tuples while
+    still allowing us to recover both components when needed.
     """
-    # Hash value of the block in an integer.
-    hash_value: int
-    # Token IDs in the block.
-    token_ids: tuple[int, ...]
-    # Extra keys for the block.
-    extra_keys: Optional[Any] = None
+    return BlockHashWithGroupId(block_hash +
+                                group_id.to_bytes(4, "big", signed=False))
 
 
-class BlockHashWithGroupId(NamedTuple):
-    # The hash value for the contents (e.g., token_ids) of a block without group
-    # ID. The value is the same for blocks representing the same tokens but for
-    # different groups.
-    block_hash: BlockHash
-    # The KV cache group ID.
-    group_id: int
+def get_block_hash(key: BlockHashWithGroupId) -> BlockHash:
+    """Extract the ``BlockHash`` from a ``BlockHashWithGroupId``."""
+    return BlockHash(key[:-4])
 
-    def get_hash_value(self) -> int:
-        return self.block_hash.hash_value
 
+def get_group_id(key: BlockHashWithGroupId) -> int:
+    """Extract the group id from a ``BlockHashWithGroupId``."""
+    return int.from_bytes(key[-4:], "big", signed=False)
+
+
+def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash:
+    if not envs.VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES:
+        return hash_bytes
+    return int.from_bytes(hash_bytes, byteorder="big") & ((1 << 64) - 1)
+
+
+logger = init_logger(__name__)
 
 # The hash seed for the first block of any prefix block sequence.
 #
 # We use a random value to avoid hash collisions or PYTHONHASHSEED environment
-# variable if set such that processes can share the seed if needed.
-# This aligns with the behavior of Python's hash() function, which also uses
-# a random seed if PYTHONHASHSEED is not set.
+# variable if set such that processes can share the seed if needed. This aligns
+# with the behavior of Python's hash() function, which also uses a random seed
+# if PYTHONHASHSEED is not set.
 #
 # The function `init_none_hash` initializes this variable globally.
-NONE_HASH: int
+NONE_HASH: BlockHash
 
 
-def init_none_hash(hash_fn: Callable):
+def init_none_hash(hash_fn: Callable[[Any], bytes]):
     global NONE_HASH
 
     hash_seed = os.getenv("PYTHONHASHSEED")
-    if hash_seed is None and hash_fn is sha256_cbor_64bit:
+    if hash_seed is None and hash_fn is sha256_cbor:
         logger.warning(
             "PYTHONHASHSEED is not set. This will lead to non-reproducible "
-            "block-hashes when using sha256_cbor_64bit as the hash function."
+            "block-hashes when using sha256_cbor as the hash function. "
             "Consider setting PYTHONHASHSEED to a fixed value for "
             "reproducibility.")
 
-    NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big")
-                 if hash_seed is None else hash_fn(hash_seed))
+    if hash_seed is None:
+        NONE_HASH = BlockHash(os.urandom(32))
+    else:
+        NONE_HASH = BlockHash(hash_fn(hash_seed))
 
 
 class PrefixCachingMetrics:
@@ -142,8 +162,8 @@ class KVCacheBlock:
     block_id: int
     # Reference count.
     ref_cnt: int = 0
-    # The hash of the block composed of (block hash, tuple of token IDs).
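An aside on the packing scheme introduced above: it is plain byte arithmetic, so the round-trip behavior of the helpers can be checked in isolation (the hash input here is illustrative):

import hashlib

block_hash = hashlib.sha256(b"parent-hash|token-ids").digest()  # 32 bytes
key = block_hash + (7).to_bytes(4, "big", signed=False)         # group id 7

assert key[:-4] == block_hash                              # get_block_hash
assert int.from_bytes(key[-4:], "big", signed=False) == 7  # get_group_id

# External (KV-event) form when integer hashes are requested: keep the low
# 64 bits of the big-endian value, mirroring maybe_convert_block_hash.
as_int = int.from_bytes(block_hash, byteorder="big") & ((1 << 64) - 1)
assert 0 <= as_int < 2**64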
- # It is only available when the block is full. + # The hash key (block hash + group id) of the block, only available + # when the block is full and cached. _block_hash: Optional[BlockHashWithGroupId] = None # Used to construct a doubly linked list for free blocks. @@ -177,7 +197,7 @@ class KVCacheBlock: if self.next_free_block else None) return (f"KVCacheBlock(block_id={self.block_id}, " f"ref_cnt={self.ref_cnt}, " - f"_block_hash={self._block_hash}, " + f"_block_hash={self._block_hash!r}, " f"prev_free_block={prev_block_id}, " f"next_free_block={next_block_id})") @@ -398,9 +418,9 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. # Request with provided cache salt need to include the salt. - return bool(request.mm_hashes) or (request.lora_request - is not None) or (request.cache_salt - is not None) + return bool(request.mm_features) or (request.lora_request + is not None) or (request.cache_salt + is not None) def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, @@ -422,32 +442,28 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, """ extra_keys: list[Any] = [] - mm_positions, mm_hashes = request.mm_positions, request.mm_hashes - if not mm_positions: + mm_features = request.mm_features + if not mm_features: return extra_keys, start_mm_idx - if mm_positions and len(mm_positions) != len(mm_hashes): - raise ValueError( - "The number of multi-modal positions and hashes must match. This " - "is likely because you did not enable MM hashing. " - "Please set `mm_processor_cache_gb > 0`.") - - # Note that we assume mm_positions is sorted by offset. + # Note that we assume mm_features are sorted by mm_position.offset. # We do not need to check all mm inputs if the start token index is out of # range. This usually happens in the late prefill phase and decoding phase. - if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx: + last_pos = mm_features[-1].mm_position + if last_pos.offset + last_pos.length < start_token_idx: return extra_keys, start_mm_idx # Support start_mm_idx == -1 to indicate the last mm input. if start_mm_idx < 0: - assert -start_mm_idx <= len(mm_positions) - start_mm_idx = len(mm_positions) + start_mm_idx + assert -start_mm_idx <= len(mm_features) + start_mm_idx = len(mm_features) + start_mm_idx curr_mm_idx = start_mm_idx - while mm_positions and curr_mm_idx < len(mm_positions): - assert mm_hashes[curr_mm_idx] is not None - offset = mm_positions[curr_mm_idx].offset - length = mm_positions[curr_mm_idx].length + while mm_features and curr_mm_idx < len(mm_features): + mm_feature = mm_features[curr_mm_idx] + assert mm_feature.identifier is not None + offset = mm_feature.mm_position.offset + length = mm_feature.mm_position.length if end_token_idx > offset: if start_token_idx > offset + length: # This block has passed the current mm input. @@ -455,7 +471,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, continue # The block contains the current mm input. 
- extra_keys.append(mm_hashes[curr_mm_idx]) + extra_keys.append(mm_feature.identifier) if end_token_idx >= offset + length: # If this block contains the end of the current mm input, @@ -517,15 +533,14 @@ def generate_block_hash_extra_keys( def hash_block_tokens( - hash_function: Callable, - parent_block_hash: Optional[int], + hash_function: Callable[[Any], bytes], + parent_block_hash: Optional[BlockHash], curr_block_token_ids: Sequence[int], extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing hash values for the same block contents. - Args: hash_function: The hash function used to compute block hash. parent_block_hash: The hash of the parent block. None @@ -533,7 +548,6 @@ def hash_block_tokens( curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. extra_keys: Extra keys for the block. - Returns: The hash value of the block and the token ids in the block. The entire tuple is used as the hash key of the block. @@ -544,26 +558,16 @@ def hash_block_tokens( curr_block_token_ids_tuple = tuple(curr_block_token_ids) return BlockHash( hash_function( - (parent_block_hash, curr_block_token_ids_tuple, extra_keys)), - curr_block_token_ids_tuple, extra_keys) + (parent_block_hash, curr_block_token_ids_tuple, extra_keys))) def get_request_block_hasher( block_size: int, - caching_hash_fn: Callable[[Any], - int]) -> Callable[[Request], list[BlockHash]]: + caching_hash_fn: Callable[[Any], bytes], +) -> Callable[[Request], list[BlockHash]]: """ Returns a function which computes the list of un-computed block hashes - of a request. - - Each request holds a list of its block hashes (request.block_hashes). - When a request is created, it calls the below function to compute - the hashes of all full blocks of the request's initial tokens. - The hashes are then stored in request.block_hashes. - Later, whenever new tokens are appended to the request, it calls - the below function again to compute any new full blocks of tokens. - The returned new hashes are appended to request.block_hashes. - """ + of a request.""" def request_block_hasher(request: Request) -> list[BlockHash]: start_token_idx = len(request.block_hashes) * block_size @@ -577,8 +581,8 @@ def get_request_block_hasher( # last mm input. 
curr_mm_idx = -1 - prev_block_hash_value = request.block_hashes[-1].hash_value \ - if request.block_hashes else None + prev_block_hash_value = (request.block_hashes[-1] + if request.block_hashes else None) new_block_hashes: list[BlockHash] = [] while True: end_token_idx = start_token_idx + block_size @@ -598,7 +602,7 @@ def get_request_block_hasher( new_block_hashes.append(block_hash) start_token_idx += block_size - prev_block_hash_value = block_hash.hash_value + prev_block_hash_value = block_hash return new_block_hashes diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index b5cd6c5c8af51..56ab396d6d937 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -6,6 +6,8 @@ from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING, Optional +from vllm import bc_linter_include + if TYPE_CHECKING: import numpy as np import numpy.typing as npt @@ -13,20 +15,19 @@ if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata) from vllm.lora.request import LoRARequest - from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange + from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.request import Request +@bc_linter_include @dataclass class NewRequestData: req_id: str prompt_token_ids: list[int] - mm_kwargs: list[MultiModalKwargsItem] - mm_hashes: list[str] - mm_positions: list[PlaceholderRange] + mm_features: list[MultiModalFeatureSpec] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] block_ids: tuple[list[int], ...] @@ -42,9 +43,7 @@ class NewRequestData: return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, - mm_kwargs=request.mm_kwargs, - mm_hashes=request.mm_hashes, - mm_positions=request.mm_positions, + mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, block_ids=block_ids, @@ -56,9 +55,7 @@ class NewRequestData: return (f"NewRequestData(" f"req_id={self.req_id}," f"prompt_token_ids={self.prompt_token_ids}," - f"mm_kwargs={self.mm_kwargs}," - f"mm_hashes={self.mm_hashes}," - f"mm_positions={self.mm_positions}," + f"mm_features={self.mm_features}," f"sampling_params={self.sampling_params}," f"block_ids={self.block_ids}," f"num_computed_tokens={self.num_computed_tokens}," @@ -70,9 +67,7 @@ class NewRequestData: return (f"NewRequestData(" f"req_id={self.req_id}," f"prompt_token_ids_len={len(self.prompt_token_ids)}," - f"mm_kwargs={self.mm_kwargs}," - f"mm_hashes={self.mm_hashes}," - f"mm_positions={self.mm_positions}," + f"mm_features={self.mm_features}," f"sampling_params={self.sampling_params}," f"block_ids={self.block_ids}," f"num_computed_tokens={self.num_computed_tokens}," @@ -80,6 +75,7 @@ class NewRequestData: ")") +@bc_linter_include @dataclass class CachedRequestData: @@ -109,6 +105,7 @@ class CachedRequestData: ) +@bc_linter_include @dataclass class SchedulerOutput: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 31f7e9c70f8b3..c1e59423e9a18 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -144,8 +144,8 @@ class Scheduler(SchedulerInterface): ) # NOTE(woosuk): Here, "encoder" includes the vision encoder (and - # projector if needed). Currently, we assume that the encoder also - # has the Transformer architecture (e.g., ViT). 
+ # projector if needed) for MM models as well as encoder-decoder + # transformers. self.max_num_encoder_input_tokens = encoder_compute_budget # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 @@ -387,6 +387,14 @@ class Scheduler(SchedulerInterface): self.connector.get_num_new_matched_tokens( request, num_new_local_computed_tokens)) + if num_external_computed_tokens is None: + # The request cannot be scheduled because + # the KVConnector couldn't determine + # the number of matched tokens. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + # Total computed tokens (local + external). num_computed_tokens = (num_new_local_computed_tokens + num_external_computed_tokens) @@ -728,18 +736,18 @@ class Scheduler(SchedulerInterface): if num_new_tokens == 0 or not request.has_encoder_inputs: return [], num_new_tokens, encoder_compute_budget encoder_inputs_to_schedule: list[int] = [] - mm_positions = request.mm_positions - assert mm_positions is not None - assert len(mm_positions) > 0 + mm_features = request.mm_features + assert mm_features is not None + assert len(mm_features) > 0 # NOTE: since scheduler operates on the request level (possibly with # multiple encoder inputs per request), we need to create temporary # trackers for accounting at the encoder input level. mm_hashes_to_schedule = set() num_tokens_to_schedule = 0 - for i, pos_info in enumerate(mm_positions): - start_pos = pos_info.offset - num_encoder_tokens = pos_info.length + for i, mm_feature in enumerate(mm_features): + start_pos = mm_feature.mm_position.offset + num_encoder_tokens = mm_feature.mm_position.length # The encoder output is needed if the two ranges overlap: # [num_computed_tokens, num_computed_tokens + num_new_tokens) and @@ -767,15 +775,19 @@ class Scheduler(SchedulerInterface): # in the decoder's KV cache. continue - # The same encoder input has already been scheduled in the current - # step. - if request.mm_hashes[i] in mm_hashes_to_schedule: - continue + if not self.is_encoder_decoder: + # We are not using the encoder cache for encoder-decoder models, + # yet. + if request.mm_features[i].identifier in mm_hashes_to_schedule: + # The same encoder input has already been scheduled in the + # current step. + continue - if self.encoder_cache_manager.check_and_update_cache(request, i): - # The encoder input is already computed and cached from a - # previous step. - continue + if self.encoder_cache_manager.check_and_update_cache( + request, i): + # The encoder input is already computed and cached from a + # previous step. + continue # If no encoder input chunking is allowed, we do not want to # partially schedule a multimodal item. If the scheduled range would @@ -808,7 +820,7 @@ class Scheduler(SchedulerInterface): num_tokens_to_schedule += num_encoder_tokens encoder_compute_budget -= num_encoder_tokens - mm_hashes_to_schedule.add(request.mm_hashes[i]) + mm_hashes_to_schedule.add(request.mm_features[i].identifier) encoder_inputs_to_schedule.append(i) return ( @@ -957,9 +969,9 @@ class Scheduler(SchedulerInterface): stop_reason=request.stop_reason, events=request.take_events(), kv_transfer_params=kv_transfer_params, + trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, )) - else: # Invariant: EngineCore returns no partial prefill outputs. 
assert not prompt_logprobs_tensors @@ -1036,10 +1048,16 @@ class Scheduler(SchedulerInterface): # Here, we use list(set) to avoid modifying the set while iterating # over it. for input_id in list(cached_encoder_input_ids): - mm_positions = request.mm_positions[input_id] - start_pos = mm_positions.offset - num_tokens = mm_positions.length - if start_pos + num_tokens <= request.num_computed_tokens: + mm_feature = request.mm_features[input_id] + start_pos = mm_feature.mm_position.offset + num_tokens = mm_feature.mm_position.length + if self.is_encoder_decoder and request.num_computed_tokens > 0: + # With Whisper, as soon as we've generated a single token, + # we know we're done with the encoder input. Cross Attention + # KVs have been calculated and cached already. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + elif start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. self.encoder_cache_manager.free_encoder_input( @@ -1188,6 +1206,8 @@ class Scheduler(SchedulerInterface): def shutdown(self) -> None: if self.kv_event_publisher: self.kv_event_publisher.shutdown() + if self.connector is not None: + self.connector.shutdown() ######################################################################## # KV Connector Related Methods diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8159349e46758..d27239164b0db 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -559,12 +559,48 @@ class MambaManager(SingleTypeKVCacheManager): num_running_requests: int) -> int: return 0 + def get_num_blocks_to_allocate( + self, request_id: str, num_tokens: int, + new_computed_blocks: list[KVCacheBlock]) -> int: + """ + Get the number of blocks needed to be allocated for the request. + + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + new_computed_blocks: The new computed blocks just hitting the + prefix caching. + + Returns: + The number of blocks + """ + + assert isinstance(self.kv_cache_spec, MambaSpec) + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += (self.kv_cache_spec.block_size * + self.kv_cache_spec.num_speculative_blocks) + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = (num_required_blocks - len(new_computed_blocks) - + len(self.req_to_blocks[request_id])) + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it will be changed from a free block + # to a computed block when the request is allocated, so we also count + # it as needed to be allocated. + num_evictable_computed_blocks = sum( + blk.ref_cnt == 0 and not blk.is_null + for blk in new_computed_blocks) + return num_new_blocks + num_evictable_computed_blocks + def allocate_new_blocks(self, request_id: str, num_tokens: int) -> list[KVCacheBlock]: - new_blocks = super().allocate_new_blocks(request_id, num_tokens) - assert len(self.req_to_blocks[request_id]) == 1, ( - "MambaManager should only allocate 1 block for each request.") - return new_blocks + # Allocate extra `num_speculative_blocks` blocks for + # speculative decoding (MTP/EAGLE) with linear attention. 
+ assert isinstance(self.kv_cache_spec, MambaSpec) + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += (self.kv_cache_spec.block_size * + self.kv_cache_spec.num_speculative_blocks) + return super().allocate_new_blocks(request_id, num_tokens) class CrossAttentionManager(SingleTypeKVCacheManager): diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 5d8959a3cd3fe..dec4abec519bd 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -3,6 +3,7 @@ import enum import time +from collections.abc import Mapping from typing import Any, Optional, Union import msgspec @@ -66,6 +67,8 @@ class EngineCoreRequest( current_wave: int = 0 priority: int = 0 + trace_headers: Optional[Mapping[str, str]] = None + class EngineCoreEventType(enum.IntEnum): """The type of engine core request event.""" @@ -111,6 +114,7 @@ class EngineCoreOutput( events: Optional[list[EngineCoreEvent]] = None kv_transfer_params: Optional[dict[str, Any]] = None + trace_headers: Optional[Mapping[str, str]] = None # The number of tokens with prefix cache hits. num_cached_tokens: int = 0 @@ -144,7 +148,7 @@ class EngineCoreOutputs( omit_defaults=True, # type: ignore[call-arg] gc=False): # type: ignore[call-arg] - #NOTE(Nick): We could consider ways to make this more compact, + # NOTE(Nick): We could consider ways to make this more compact, # e.g. columnwise layout engine_index: int = 0 diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index d23602eaaffa9..a9ced402b974f 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -26,6 +26,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask +from vllm.tracing import init_tracer from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -97,6 +98,7 @@ class AsyncLLM(EngineClient): self.model_config = vllm_config.model_config self.vllm_config = vllm_config + self.observability_config = vllm_config.observability_config self.log_requests = log_requests self.log_stats = log_stats or (stat_loggers is not None) @@ -124,6 +126,11 @@ class AsyncLLM(EngineClient): # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). self.output_processor = OutputProcessor(self.tokenizer, log_stats=self.log_stats) + if self.observability_config.otlp_traces_endpoint is not None: + tracer = init_tracer( + "vllm.llm_engine", + self.observability_config.otlp_traces_endpoint) + self.output_processor.tracer = tracer # EngineCore (starts the engine in background process). self.engine_core = EngineCoreClient.make_async_mp_client( @@ -143,6 +150,7 @@ class AsyncLLM(EngineClient): engine_idxs=self.engine_core.engine_ranks_managed, custom_stat_loggers=stat_loggers, enable_default_loggers=log_stats, + client_count=client_count, ) self.logger_manager.log_engine_initialized() @@ -169,9 +177,6 @@ class AsyncLLM(EngineClient): worker_name=worker_name, use_gzip=True)) else: - logger.info( - "Torch profiler disabled. AsyncLLM CPU traces will not be collected." 
# noqa: E501 - ) self.profiler = None @classmethod @@ -602,7 +607,7 @@ class AsyncLLM(EngineClient): return self.tokenizer.get_lora_tokenizer(lora_request) async def is_tracing_enabled(self) -> bool: - return False + return self.observability_config.otlp_traces_endpoint is not None async def do_log_stats( self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e239e6cbba167..995e70385be89 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -23,7 +23,7 @@ from vllm.logger import init_logger from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.cache import receiver_cache_from_config +from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) @@ -131,7 +131,7 @@ class EngineCore: self.use_spec_decode = vllm_config.speculative_config is not None self.mm_registry = mm_registry = MULTIMODAL_REGISTRY - self.mm_receiver_cache = receiver_cache_from_config( + self.mm_receiver_cache = engine_receiver_cache_from_config( vllm_config, mm_registry) # Setup batch queue for pipeline parallelism. @@ -159,6 +159,9 @@ class EngineCore: self.request_block_hasher = get_request_block_hasher( block_size, caching_hash_fn) + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + def _initialize_kv_caches( self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: start = time.time() @@ -224,7 +227,7 @@ class EngineCore: def add_request(self, request: Request, request_wave: int = 0): """Add request to the scheduler. - + `request_wave`: indicate which wave of requests this is expected to belong to in DP case """ @@ -331,7 +334,8 @@ class EngineCore: model_executed = False if self.scheduler.has_requests(): scheduler_output = self.scheduler.schedule() - future = self.model_executor.execute_model(scheduler_output) + future = self.model_executor.execute_model(scheduler_output, + non_block=True) batch_queue.appendleft( (future, scheduler_output)) # type: ignore[arg-type] @@ -433,7 +437,7 @@ class EngineCore: def preprocess_add_request( self, request: EngineCoreRequest) -> tuple[Request, int]: """Preprocess the request. - + This function could be directly used in input processing thread to allow request initialization running in parallel with Model forward """ @@ -534,9 +538,6 @@ class EngineCoreProc(EngineCore): assert addresses.coordinator_input is not None logger.info("Waiting for READY message from DP Coordinator...") - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) - # Mark the startup heap as static so that it's ignored by GC. # Reduces pause times of oldest generation collections. gc.collect() @@ -697,7 +698,7 @@ class EngineCoreProc(EngineCore): parallel_config: ParallelConfig = kwargs[ "vllm_config"].parallel_config if parallel_config.data_parallel_size > 1 or dp_rank > 0: - set_process_title("DPEngineCore", str(dp_rank)) + set_process_title("EngineCore", f"DP{dp_rank}") decorate_logs() # Set data parallel rank for this engine process. 
parallel_config.data_parallel_rank = dp_rank diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 65f7abc97110c..bb0f37c6e0264 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -245,8 +245,8 @@ class InprocClient(EngineCoreClient): self.engine_core = EngineCore(*args, **kwargs) def get_output(self) -> EngineCoreOutputs: - outputs, _ = self.engine_core.step() - return outputs.get(0) or EngineCoreOutputs() + outputs, _ = self.engine_core.step_fn() + return outputs and outputs.get(0) or EngineCoreOutputs() def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self.engine_core.get_supported_tasks() @@ -347,8 +347,9 @@ class BackgroundResources: if isinstance(self.output_socket, zmq.asyncio.Socket): # Async case. - loop = self.output_socket._get_loop() - asyncio.get_running_loop() + loop = self.output_queue_task._loop \ + if self.output_queue_task else None + sockets = (self.output_socket, self.input_socket, self.first_req_send_socket, self.first_req_rcv_socket, self.stats_update_socket) @@ -359,11 +360,12 @@ class BackgroundResources: close_sockets(sockets) for task in tasks: if task is not None and not task.done(): - task.cancel() + with contextlib.suppress(Exception): + task.cancel() if in_loop(loop): close_sockets_and_tasks() - elif not loop.is_closed(): + elif loop and not loop.is_closed(): loop.call_soon_threadsafe(close_sockets_and_tasks) else: # Loop has been closed, try to clean up directly. diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 38f435f5166e0..cf4b06db843bd 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -121,12 +121,9 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): self.output_token_ids) <= self.min_tokens: stop_check_offset = len(self.output_text) - if stop_terminated: - if skipped_stop_token_id is not None: - # Cleanup after skipping detokenization. - self.token_ids.append(skipped_stop_token_id) - # Stop token triggered; skip stop string check. - return None + if skipped_stop_token_id is not None: + # Cleanup after skipping detokenization. + self.token_ids.append(skipped_stop_token_id) # 2) Evaluate stop strings. stop_string = None diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 7130f666ef19f..fca5a783bc3bf 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -19,6 +19,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask +from vllm.tracing import init_tracer from vllm.transformers_utils.tokenizer_group import ( TokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext @@ -65,6 +66,7 @@ class LLMEngine: "Set VLLM_USE_V1=0 and file and issue on Github.") self.vllm_config = vllm_config + self.observability_config = vllm_config.observability_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -99,6 +101,11 @@ class LLMEngine: # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). 
self.output_processor = OutputProcessor(self.tokenizer, log_stats=self.log_stats) + if self.observability_config.otlp_traces_endpoint is not None: + tracer = init_tracer( + "vllm.llm_engine", + self.observability_config.otlp_traces_endpoint) + self.output_processor.tracer = tracer # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) self.engine_core = EngineCoreClient.make_client( diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2ee55b585da6c..02c8c61cb9093 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -11,6 +11,8 @@ import torch from vllm.outputs import (CompletionOutput, PoolingOutput, PoolingRequestOutput, RequestOutput) from vllm.sampling_params import RequestOutputKind +from vllm.tracing import (SpanAttributes, SpanKind, Tracer, + extract_trace_context) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason @@ -71,7 +73,6 @@ class RequestOutputCollector: @dataclass class OutputProcessorOutput: - request_outputs: list[Union[RequestOutput, PoolingRequestOutput]] reqs_to_abort: list[str] @@ -93,6 +94,9 @@ class RequestState: arrival_time: float, queue: Optional[RequestOutputCollector], log_stats: bool, + top_p: Optional[float] = None, + n: Optional[int] = None, + temperature: Optional[float] = None, ): self.request_id = request_id self.parent_req = parent_req @@ -105,6 +109,9 @@ class RequestState: self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer self.max_tokens_param = max_tokens_param + self.top_p = top_p + self.n = n + self.temperature = temperature self.is_prefilling = True self.queue = queue self.num_cached_tokens = 0 @@ -137,10 +144,16 @@ class RequestState: request=request, ) max_tokens_param = sampling_params.max_tokens + top_p = sampling_params.top_p + n = sampling_params.n + temperature = sampling_params.temperature else: logprobs_processor = None detokenizer = None max_tokens_param = None + top_p = None + n = None + temperature = None assert request.pooling_params is not None output_kind = request.pooling_params.output_kind @@ -156,6 +169,9 @@ class RequestState: logprobs_processor=logprobs_processor, detokenizer=detokenizer, max_tokens_param=max_tokens_param, + top_p=top_p, + n=n, + temperature=temperature, arrival_time=request.arrival_time, queue=queue, log_stats=log_stats, @@ -274,16 +290,13 @@ class RequestState: class OutputProcessor: """Process EngineCoreOutputs into RequestOutputs.""" - def __init__( - self, - tokenizer: TokenizerGroup, - log_stats: bool, - ): + def __init__(self, tokenizer: TokenizerGroup, log_stats: bool): self.log_stats = log_stats self.tokenizer = tokenizer self.request_states: dict[str, RequestState] = {} self.parent_requests: dict[str, ParentRequest] = {} self.lora_states = LoRARequestStates() + self.tracer: Optional[Tracer] = None def get_num_unfinished_requests(self): return len(self.request_states) @@ -441,7 +454,9 @@ class OutputProcessor: # Track per-request stats self._update_stats_from_finished(req_state, finish_reason, iteration_stats) - + if self.tracer: + self.do_tracing(engine_core_output, req_state, + iteration_stats) self.lora_states.update_iteration_stats(iteration_stats) return OutputProcessorOutput( @@ -449,6 +464,63 @@ class OutputProcessor: reqs_to_abort=reqs_to_abort, ) + def do_tracing(self, engine_core_output: EngineCoreOutput, + req_state: RequestState, 
+ iteration_stats: Optional[IterationStats]) -> None: + assert req_state.stats is not None + assert iteration_stats is not None + assert self.tracer is not None + + arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9) + trace_context = extract_trace_context(engine_core_output.trace_headers) + with (self.tracer.start_as_current_span( + "llm_request", + kind=SpanKind.SERVER, + context=trace_context, + start_time=arrival_time_nano_seconds) as span): + metrics = req_state.stats + e2e_time = iteration_stats.iteration_timestamp - \ + metrics.arrival_time + queued_time = metrics.scheduled_ts - metrics.queued_ts + prefill_time = metrics.first_token_ts - metrics.scheduled_ts + decode_time = metrics.last_token_ts - metrics.first_token_ts + inference_time = metrics.last_token_ts - metrics.scheduled_ts + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, + metrics.first_token_latency) + span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time) + span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, + queued_time) + span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, + len(req_state.prompt_token_ids)) + span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, + metrics.num_generation_tokens) + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, + prefill_time) + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, + decode_time) + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, + inference_time) + + # meta + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, + req_state.request_id) + if req_state.top_p: + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, + req_state.top_p) + if req_state.max_tokens_param: + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, + req_state.max_tokens_param) + if req_state.temperature: + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, + req_state.temperature) + if req_state.n: + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, + req_state.n) + def _update_stats_from_output(self, req_state: RequestState, engine_core_output: EngineCoreOutput, engine_core_timestamp: Optional[float], diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 1aa117ded4ed8..f3fad15b750ad 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -12,7 +12,7 @@ from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.cache import processor_cache_from_config -from vllm.multimodal.inputs import MultiModalFeatureSpec +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import PoolingParams @@ -65,19 +65,27 @@ class Processor: ) -> None: max_logprobs = self.model_config.max_logprobs if max_logprobs == -1: - return + max_logprobs = self.model_config.get_vocab_size() + # Validate sample logprobs. 
- if params.logprobs and (params.logprobs == -1 - or params.logprobs > max_logprobs): - raise ValueError( - f"Requested sample logprobs of {params.logprobs}, " - f"which is greater than max allowed: {max_logprobs}") + if params.logprobs: + num_logprobs = params.logprobs + if num_logprobs == -1: + num_logprobs = self.model_config.get_vocab_size() + if num_logprobs > max_logprobs: + raise ValueError( + f"Requested sample logprobs of {num_logprobs}, " + f"which is greater than max allowed: {max_logprobs}") # Validate prompt logprobs. - if params.prompt_logprobs and params.prompt_logprobs > max_logprobs: - raise ValueError( - f"Requested prompt logprobs of {params.prompt_logprobs}, " - f"which is greater than max allowed: {max_logprobs}") + if params.prompt_logprobs: + num_prompt_logprobs = params.prompt_logprobs + if num_prompt_logprobs == -1: + num_prompt_logprobs = self.model_config.get_vocab_size() + if num_prompt_logprobs > max_logprobs: + raise ValueError( + f"Requested prompt logprobs of {num_prompt_logprobs}, " + f"which is greater than max allowed: {max_logprobs}") def _validate_sampling_params( self, @@ -252,10 +260,10 @@ class Processor: else: # NOTE: engine_level_backend must be "auto" here, because we have # checked supported_backends above. - # "auto" is an opt-in to opinionated behavior where we try to - # choose a backend based on request contents. This is not the - # default as it is less predictable and subject to change - # between releases as feature support changes. + # In this mode, we set opinionated defaults based on what we think + # will satisfy the most use cases without having to worry about + # this setting. We include fallback behavior here, but not with any + # other setting where a specific backend was specified. try: validate_xgrammar_grammar(params) params.guided_decoding.backend = "xgrammar" @@ -268,11 +276,11 @@ class Processor: # Remember that this backend was set automatically params.guided_decoding.backend_was_auto = True - def _maybe_build_mm_hash_overrides( + def _maybe_build_mm_uuids( self, request_id: str, prompt: PromptType, - ) -> Optional[dict[str, list[str]]]: + ) -> Optional[MultiModalUUIDDict]: """Build per-item multimodal hash overrides when enabled. In this case, multimodal data items are identified by their request id, modality and index rather than their content. @@ -295,13 +303,13 @@ class Processor: if not mm_data: return None - overrides: dict[str, list[str]] = {} + mm_uuids: MultiModalUUIDDict = {} for modality, data in mm_data.items(): n = len(data) if isinstance(data, list) else 1 - overrides[modality] = [ + mm_uuids[modality] = [ f"{request_id}-{modality}-{i}" for i in range(n) ] - return overrides + return mm_uuids def process_inputs( self, @@ -317,11 +325,8 @@ class Processor: ) -> tuple[Optional[str], EngineCoreRequest]: # TODO(woosuk): Support pooling models. - # TODO(woosuk): Support encoder-decoder models.
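# A minimal sketch (illustrative, not vLLM API) of the validation semantics
# introduced above, assuming params.logprobs == -1 is a sentinel meaning
# "logprobs over the full vocabulary"; resolve_num_logprobs is a
# hypothetical helper name.
def resolve_num_logprobs(requested: int, vocab_size: int) -> int:
    """Map the -1 sentinel to the vocab size before the range check."""
    return vocab_size if requested == -1 else requested

assert resolve_num_logprobs(-1, 32_000) == 32_000  # full-vocab request
assert resolve_num_logprobs(5, 32_000) == 5        # ordinary request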
self._validate_lora(lora_request) self._validate_params(params, lora_request) - if trace_headers is not None: - raise ValueError("V1 does not support tracing yet.") data_parallel_size = self.vllm_config.parallel_config.data_parallel_size if data_parallel_rank is not None and not (0 <= data_parallel_rank < @@ -343,16 +348,15 @@ class Processor: if (self.model_config.multimodal_config and self.model_config.multimodal_config.mm_processor_cache_gb == 0 and not self.cache_config.enable_prefix_caching): - mm_hash_overrides = self._maybe_build_mm_hash_overrides( - request_id, prompt) + mm_uuids = self._maybe_build_mm_uuids(request_id, prompt) else: # Otherwise, use user-provided uuids as multimodal hash overrides # if provided. self._validate_multi_modal_uuids(prompt) if isinstance(prompt, dict): - mm_hash_overrides = prompt.get("multi_modal_uuids") + mm_uuids = prompt.get("multi_modal_uuids") else: - mm_hash_overrides = None + mm_uuids = None # Process inputs, which includes: # 1. Tokenize text prompt, with LoRA request if one exists. @@ -362,7 +366,7 @@ class Processor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) from vllm.platforms import current_platform current_platform.validate_request( @@ -377,10 +381,6 @@ class Processor: encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - # TODO: Impl encoder-decoder - if encoder_inputs is not None: - raise NotImplementedError - sampling_params = None pooling_params = None if isinstance(params, SamplingParams): @@ -433,6 +433,7 @@ class Processor: cache_salt=decoder_inputs.get("cache_salt"), priority=priority, data_parallel_rank=data_parallel_rank, + trace_headers=trace_headers, ) def _validate_model_inputs(self, diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index ed0129fda9474..df2fd8d9df078 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -116,7 +116,7 @@ class CoreEngineProcManager: local_dp_ranks.append(local_index) self.processes.append( context.Process(target=target_fn, - name=f"EngineCore_{global_index}", + name=f"EngineCore_DP{global_index}", kwargs=common_kwargs | { "dp_rank": global_index, "local_dp_rank": local_index, diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 68408a0b8a3d5..625017d52fff0 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from concurrent.futures import Future -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -14,6 +14,7 @@ from vllm.executor.uniproc_executor import ( # noqa from vllm.executor.uniproc_executor import ( # noqa UniProcExecutor as UniProcExecutorV0) from vllm.utils import resolve_obj_by_qualname +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput @@ -86,12 +87,22 @@ class Executor(ExecutorBase): def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: return self.collective_rpc("get_kv_cache_spec") + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False) -> list[Any]: + raise NotImplementedError + def execute_model( self, - scheduler_output, + scheduler_output: SchedulerOutput, 
+ non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: output = self.collective_rpc("execute_model", - args=(scheduler_output, )) + args=(scheduler_output, ), + non_block=non_block) return output[0] def execute_dummy_batch(self) -> None: diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index ef6303495c245..f566c9aee0c54 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import multiprocessing -import os import pickle import queue import signal @@ -12,9 +11,10 @@ import weakref from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from enum import Enum, auto -from functools import partial +from functools import cached_property, partial from multiprocessing.connection import Connection from multiprocessing.process import BaseProcess +from multiprocessing.synchronize import Lock as LockType from threading import Thread from typing import Any, Callable, Optional, Union, cast @@ -27,13 +27,19 @@ from vllm.distributed import (destroy_distributed_environment, from vllm.distributed.device_communicators.shm_broadcast import (Handle, MessageQueue) from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_pp_group, get_tp_group) from vllm.executor.multiproc_worker_utils import ( set_multiprocessing_worker_envs) from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import worker_receiver_cache_from_config from vllm.utils import (decorate_logs, get_distributed_init_method, get_loopback_ip, get_mp_context, get_open_port, set_process_title) +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.executor.abstract import Executor, FailureCallback +from vllm.v1.executor.utils import get_and_update_mm_cache from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput) from vllm.worker.worker_base import WorkerWrapperBase @@ -80,6 +86,8 @@ class MultiprocExecutor(Executor): scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers + context = get_mp_context() + shared_worker_lock = context.Lock() unready_workers: list[UnreadyWorkerProcHandle] = [] success = False try: @@ -91,6 +99,7 @@ class MultiprocExecutor(Executor): rank=rank, distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, + shared_worker_lock=shared_worker_lock, )) # Workers must be created before wait_for_ready to avoid @@ -166,9 +175,9 @@ class MultiprocExecutor(Executor): def execute_model( self, - scheduler_output, + scheduler_output: SchedulerOutput, + non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: - non_block = self.max_concurrent_batches > 1 if not self.has_connector: # get output only from a single worker (output_rank) @@ -320,7 +329,7 @@ class MultiprocExecutor(Executor): self.collective_rpc("check_health", timeout=10) return - @property + @cached_property def max_concurrent_batches(self) -> int: if self.scheduler_config.async_scheduling: return 2 @@ -379,6 +388,7 @@ class WorkerProc: rank: int, distributed_init_method: str, input_shm_handle: Handle, + shared_worker_lock: LockType, ): self.rank = rank wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank) @@ 
-398,17 +408,6 @@ class WorkerProc: wrapper.init_worker(all_kwargs) self.worker = wrapper - pp_size = vllm_config.parallel_config.pipeline_parallel_size - tp_size = vllm_config.parallel_config.tensor_parallel_size - pp_str = f"PP{rank // tp_size}" if pp_size > 1 else "" - tp_str = f"TP{rank % tp_size}" if tp_size > 1 else "" - suffix = f"{pp_str}{'_' if pp_str and tp_str else ''}{tp_str}" - process_name = "VllmWorker" - if suffix: - set_process_title(suffix, append=True) - process_name = f"{process_name} {suffix}" - decorate_logs(process_name) - # Initialize MessageQueue for receiving SchedulerOutput self.rpc_broadcast_mq = MessageQueue.create_from_handle( input_shm_handle, self.worker.rank) @@ -426,17 +425,28 @@ class WorkerProc: name="WorkerAsyncOutputCopy") self.async_output_copy_thread.start() - # Initialize device and loads weights + # Initialize multimodal receiver cache if needed + self.mm_receiver_cache = worker_receiver_cache_from_config( + vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock) + + # Initialize device self.worker.init_device() + + # Set process title and log prefix + self.setup_proc_title_and_log_prefix( + enable_ep=vllm_config.parallel_config.enable_expert_parallel) + + # Load model self.worker.load_model() @staticmethod def make_worker_process( - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - input_shm_handle, # Receive SchedulerOutput + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + input_shm_handle, # Receive SchedulerOutput + shared_worker_lock: LockType, ) -> UnreadyWorkerProcHandle: context = get_mp_context() # (reader, writer) @@ -453,6 +463,7 @@ class WorkerProc: "input_shm_handle": input_shm_handle, "ready_pipe": (reader, writer), "death_pipe": death_reader, + "shared_worker_lock": shared_worker_lock, } # Run EngineCore busy loop in background process. 
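# A small sketch of the cooperative-shutdown pattern these worker changes
# move to (an Event checked by the busy loop instead of os.kill/SIGTERM).
# Names and the polling interval are illustrative assumptions.
import queue
import threading

def busy_loop(q: "queue.Queue[str]", cancel: threading.Event) -> None:
    while not cancel.is_set():
        try:
            item = q.get(timeout=0.05)  # bounded wait so cancel is re-checked
        except queue.Empty:
            continue
        print("processed", item)

q: "queue.Queue[str]" = queue.Queue()
cancel = threading.Event()
t = threading.Thread(target=busy_loop, args=(q, cancel))
t.start()
q.put("work")
cancel.set()  # the death-monitoring thread flips this when the parent exits
t.join()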
proc = context.Process(target=WorkerProc.worker_main, @@ -507,6 +518,7 @@ class WorkerProc: return cast(list[WorkerProcHandle], ready_proc_handles) def shutdown(self): + self.worker.shutdown() self.rpc_broadcast_mq = None self.worker_response_mq = None destroy_model_parallel() @@ -536,7 +548,7 @@ class WorkerProc: # tuple[Connection, Connection] reader, ready_writer = kwargs.pop("ready_pipe") death_pipe = kwargs.pop("death_pipe", None) - + shutdown_event = threading.Event() # Start death monitoring thread if death_pipe is provided if death_pipe is not None: @@ -548,7 +560,7 @@ class WorkerProc: # Parent process has exited, terminate this worker logger.info("Parent process exited, terminating worker") # Send signal to self to trigger clean shutdown - os.kill(os.getpid(), signal.SIGTERM) + shutdown_event.set() except Exception as e: logger.warning("Death monitoring error: %s", e) @@ -576,7 +588,7 @@ class WorkerProc: ready_writer.close() ready_writer = None - worker.worker_busy_loop() + worker.worker_busy_loop(cancel=shutdown_event) except Exception: # NOTE: if an Exception arises in busy_loop, we send @@ -586,6 +598,8 @@ class WorkerProc: if ready_writer is not None: logger.exception("WorkerProc failed to start.") + elif shutdown_event.is_set(): + logger.info("WorkerProc shutting down.") else: logger.exception("WorkerProc failed.") @@ -619,7 +633,8 @@ class WorkerProc: result = (WorkerProc.ResponseStatus.FAILURE, str(output)) else: result = (WorkerProc.ResponseStatus.SUCCESS, output) - self.worker_response_mq.enqueue(result) + if (response_mq := self.worker_response_mq) is not None: + response_mq.enqueue(result) def handle_output(self, output: Any): """Handles output from the worker. If async scheduling is enabled, @@ -637,16 +652,20 @@ class WorkerProc: output = self.async_output_queue.get() self.enqueue_output(output) - def worker_busy_loop(self): + def worker_busy_loop(self, cancel: Optional[threading.Event] = None): """Main busy loop for Multiprocessing Workers""" while True: - method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue() - + method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( + cancel=cancel) try: if isinstance(method, str): func = getattr(self.worker, method) elif isinstance(method, bytes): func = partial(cloudpickle.loads(method), self.worker) + # retrieve from shm cache if available + if self.mm_receiver_cache is not None \ + and func.__name__ == "execute_model": + get_and_update_mm_cache(self.mm_receiver_cache, args) output = func(*args, **kwargs) except Exception as e: # Notes have been introduced in python 3.11 @@ -661,3 +680,24 @@ class WorkerProc: if output_rank is None or self.rank == output_rank: self.handle_output(output) + + @staticmethod + def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: + dp_size = get_dp_group().world_size + dp_rank = get_dp_group().rank_in_group + pp_size = get_pp_group().world_size + pp_rank = get_pp_group().rank_in_group + tp_size = get_tp_group().world_size + tp_rank = get_tp_group().rank_in_group + process_name = "Worker" + if dp_size > 1: + process_name += f"_DP{dp_rank}" + if pp_size > 1: + process_name += f"_PP{pp_rank}" + if tp_size > 1: + process_name += f"_TP{tp_rank}" + if enable_ep: + ep_rank = get_ep_group().rank_in_group + process_name += f"_EP{ep_rank}" + set_process_title(name=process_name) + decorate_logs(process_name) diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index 8394ae788ab01..59c9b56625a95 100644 --- 
a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -66,11 +66,13 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): def execute_model( self, scheduler_output: SchedulerOutput, + non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: """Execute the model on the Ray workers. Args: scheduler_output: The scheduler output to execute. + non_block: If True, the method will return a Future. Returns: The model runner output. @@ -84,7 +86,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): if not self.has_connector: # Get output only from a single worker (output_rank) # When PP is not used, we block here until the result is available. - if self.max_concurrent_batches == 1: + if not non_block: return refs[0].get() # When PP is used, we return a FutureWrapper immediately so that @@ -92,7 +94,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): return FutureWrapper(refs) # Get output from all workers when connector is present - if self.max_concurrent_batches == 1: + if not non_block: # Block and get results from all workers outputs = [ref.get() for ref in refs] return self.kv_output_aggregator.aggregate(outputs) @@ -106,4 +108,3 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): if reconfig_request.new_data_parallel_rank == \ ReconfigureRankType.SHUTDOWN_CURRENT_RANK: self.shutdown() - return \ No newline at end of file diff --git a/vllm/v1/executor/utils.py b/vllm/v1/executor/utils.py new file mode 100644 index 0000000000000..1855bc9963817 --- /dev/null +++ b/vllm/v1/executor/utils.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.multimodal.cache import ShmObjectStoreReceiverCache +from vllm.v1.core.sched.output import SchedulerOutput + + +def get_and_update_mm_cache( + receiver_cache: ShmObjectStoreReceiverCache, + args: tuple[SchedulerOutput], +) -> None: + """ + For each MultiModalKwargsItem in SchedulerOutput, fetch from shared memory + cache as needed. + + Args: + receiver_cache: The receiver cache to update. + args: A tuple containing a single SchedulerOutput element, matching + the collective_rpc call of the execute_model method in the + executor.
+ """ + scheduler_output = args[0] + for request_data in scheduler_output.scheduled_new_reqs: + request_data.mm_features = receiver_cache.get_and_update_features( + request_data.mm_features) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 6467fcfe40aef..6e8f569fff0e3 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -194,6 +194,7 @@ class MambaSpec(KVCacheSpec): dtypes: tuple[torch.dtype] page_size_padded: Optional[int] = None mamba_type: str = "mamba2" + num_speculative_blocks: int = 0 @property def page_size_bytes(self) -> int: diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index f480344c854f7..347185d8341ee 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -652,6 +652,7 @@ class StatLoggerManager: engine_idxs: Optional[list[int]] = None, custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, enable_default_loggers: bool = True, + client_count: int = 1, ): self.engine_idxs = engine_idxs if engine_idxs else [0] @@ -660,7 +661,12 @@ class StatLoggerManager: factories.extend(custom_stat_loggers) if enable_default_loggers and logger.isEnabledFor(logging.INFO): - factories.append(LoggingStatLogger) + if client_count > 1: + logger.warning( + "AsyncLLM created with api_server_count more than 1; " + "disabling stats logging to avoid incomplete stats.") + else: + factories.append(LoggingStatLogger) # engine_idx: StatLogger self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {} diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 45c32aaaaf6c4..e6c344d193df2 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -68,6 +68,9 @@ class RequestStateStats: first_token_ts: float = 0.0 last_token_ts: float = 0.0 + # first token latency + first_token_latency: float = 0.0 + @dataclass class FinishedRequestStats: @@ -116,6 +119,7 @@ class IterationStats: first_token_latency = self._time_since(req_stats.arrival_time) self.time_to_first_tokens_iter.append(first_token_latency) + req_stats.first_token_latency = first_token_latency req_stats.num_generation_tokens += num_new_generation_tokens diff --git a/vllm/v1/request.py b/vllm/v1/request.py index ad7477241ebbd..4e3e581235cce 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,6 +3,7 @@ import enum import time +from collections.abc import Mapping from functools import partial from typing import TYPE_CHECKING, Any, Callable, Optional, Union @@ -35,6 +36,7 @@ class Request: structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, priority: int = 0, + trace_headers: Optional[Mapping[str, str]] = None, block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] = None, ) -> None: @@ -89,18 +91,14 @@ class Request: self.mm_features = mm_features or [] self.num_encoder_inputs = len(self.mm_features) self.has_encoder_inputs = self.num_encoder_inputs > 0 - # TODO(sfeng33): Remove these legacy fields after clearing out all - # references in scheduler and model runner - self.mm_positions = [f.mm_position for f in self.mm_features] - self.mm_kwargs = [f.data for f in self.mm_features] - self.mm_hashes = [f.identifier for f in self.mm_features] # Read-only views # Prevent directly appending to these lists since # they should also be updated simultaneously. 
self.output_token_ids = ConstantList(self._output_token_ids) self.all_token_ids = ConstantList(self._all_token_ids) - + # trace_headers + self.trace_headers = trace_headers # State # The number of tokens with prefix cache hits. self.num_cached_tokens = -1 @@ -136,6 +134,7 @@ class Request: if request.sampling_params else None, cache_salt=request.cache_salt, priority=request.priority, + trace_headers=request.trace_headers, block_hasher=block_hasher, ) @@ -176,8 +175,8 @@ class Request: return RequestStatus.get_finished_reason(self.status) def get_num_encoder_tokens(self, input_id: int) -> int: - assert input_id < len(self.mm_positions) - num_tokens = self.mm_positions[input_id].length + assert input_id < len(self.mm_features) + num_tokens = self.mm_features[input_id].mm_position.length return num_tokens def record_event( diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index a5f1cadd85241..df944873bcaf3 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -195,7 +195,7 @@ class AdapterLogitsProcessor(LogitsProcessor): overridden in general. However, to implement custom constructor behavior - especially any logic which operates on or stores `vllm_config`, `device`, or `is_pin_memory` - `self.__init__(vllm_config, device, is_pin_memory)` - must be overriden and the override must call + must be overridden and the override must call `super().__init__(vllm_config, device, is_pin_memory)` """ diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index bf25c91d8390c..7132d507c722c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -218,7 +218,7 @@ class EagleProposer: hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) - if self.method in ("deepseek_mtp", "ernie_mtp"): + if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp"): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states else: @@ -322,12 +322,18 @@ class EagleProposer: with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size): - last_hidden_states, hidden_states = self.model( + ret_hidden_states = self.model( input_ids=input_ids, positions=self.positions[:input_batch_size], hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) + if self.method in ("deepseek_mtp", "ernie_mtp", + "qwen3_next_mtp"): + last_hidden_states = ret_hidden_states + hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], None) diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index b4bc3058c570a..2aa8962f5739c 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time from dataclasses import dataclass, field from typing import Optional @@ -58,6 +59,7 @@ class SpecDecodingLogging: self.num_draft_tokens: list[int] = [] self.num_accepted_tokens: list[int] = [] self.accepted_tokens_per_pos_lists: list[list[int]] = [] + self.last_log_time = time.monotonic() def observe(self, spec_decoding_stats: SpecDecodingStats): self.num_drafts.append(spec_decoding_stats.num_drafts) @@ -73,6 +75,13 @@ class SpecDecodingLogging: num_drafts = np.sum(self.num_drafts) 
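# A sketch of the interval-throughput pattern the lines below implement:
# counters accumulate between log flushes and each flush divides by the
# monotonic time since the previous one (assuming reset() also refreshes
# last_log_time). Illustrative only.
import time

class IntervalMeter:
    def __init__(self) -> None:
        self.tokens = 0
        self.last_flush = time.monotonic()

    def observe(self, n: int) -> None:
        self.tokens += n  # called once per engine step

    def flush(self) -> float:
        now = time.monotonic()
        elapsed = now - self.last_flush
        rate = self.tokens / elapsed if elapsed > 0 else 0.0
        self.tokens, self.last_flush = 0, now
        return rate  # tokens/s over the logging interval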
num_draft_tokens = np.sum(self.num_draft_tokens) num_accepted_tokens = np.sum(self.num_accepted_tokens) + draft_throughput = 0 + accepted_throughput = 0 + + elapsed_time = time.monotonic() - self.last_log_time + if elapsed_time > 0: + draft_throughput = num_draft_tokens / elapsed_time + accepted_throughput = num_accepted_tokens / elapsed_time draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens * 100 if num_draft_tokens > 0 else float("nan")) @@ -86,16 +95,20 @@ log_fn( "SpecDecoding metrics: " - "Draft acceptance rate: %.1f%%, " "Mean acceptance length: %.2f, " + "Accepted throughput: %.2f tokens/s, " + "Drafted throughput: %.2f tokens/s, " "Accepted: %d tokens, " "Drafted: %d tokens, " - "Per-position acceptance rate: %s", - draft_acceptance_rate, + "Per-position acceptance rate: %s, " + "Avg Draft acceptance rate: %.1f%%", mean_acceptance_length, + accepted_throughput, + draft_throughput, num_accepted_tokens, num_draft_tokens, rates_str, + draft_acceptance_rate, ) self.reset() diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index e0c7d9094aa6d..fd84b4a111f58 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -355,7 +358,8 @@ def report_usage_stats( vllm_config.cache_config.block_size, "gpu_memory_utilization": vllm_config.cache_config.gpu_memory_utilization, - + "kv_cache_memory_bytes": + vllm_config.cache_config.kv_cache_memory_bytes, # Quantization "quantization": vllm_config.model_config.quantization, diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index c5902595a496b..194984bf50536 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -98,7 +98,7 @@ class BlockTable: # here because M (max_model_len) is not necessarily divisible by # block_size. if self.dcp_world_size > 1: - # Note(hc): The DCP implement store kvcache with a interleave + # Note(hc): The DCP implementation stores kvcache in an interleaved # style, the kvcache for the token whose token_idx is i is # always stored on the GPU whose dcp_rank equals i % cp_world_size: @@ -112,9 +112,9 @@ # tokens.
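# A worked example of the interleaving described above, using illustrative
# numbers (dcp_world_size == 2, positions small enough that the virtual
# block offset equals the position):
#   token_idx: 0 1 2 3 4 5 6 7  ->  dcp_rank: 0 1 0 1 0 1 0 1
import numpy as np

_positions = np.arange(8)
_dcp_world_size, _dcp_rank = 2, 0
_mask = _positions % _dcp_world_size == _dcp_rank  # tokens local to rank 0
_local = _positions // _dcp_world_size             # compacted local offsets
assert _local[_mask].tolist() == [0, 1, 2, 3]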
virtual_block_offsets = positions % virtual_block_size mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank - # Calcuate local block_offsets + # Calculate local block_offsets block_offsets = virtual_block_offsets // self.dcp_world_size - # Calcuate slot_mapping + # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets # Write final slots, use -1 for not-local self.slot_mapping_np[:req_indices.shape[0]] = np.where( @@ -156,9 +156,14 @@ class BlockTable: class MultiGroupBlockTable: """The BlockTables for each KV cache group.""" - def __init__(self, max_num_reqs: int, max_model_len: int, - max_num_batched_tokens: int, pin_memory: bool, - device: torch.device, block_sizes: list[int]) -> None: + def __init__(self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + block_sizes: list[int], + num_speculative_tokens: int = 0) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, # so the block_size which used for calc max_num_blocks_per_req @@ -170,10 +175,11 @@ class MultiGroupBlockTable: dcp_world_size = 1 self.block_tables = [ - BlockTable(block_size, max_num_reqs, - cdiv(max_model_len, block_size * dcp_world_size), - max_num_batched_tokens, pin_memory, device) - for block_size in block_sizes + BlockTable( + block_size, max_num_reqs, + max(cdiv(max_model_len, block_size * dcp_world_size), + 1 + num_speculative_tokens), max_num_batched_tokens, + pin_memory, device) for block_size in block_sizes ] def append_row(self, block_ids: tuple[list[int], ...], diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index feb49978d7518..d5ec19b86b061 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -55,11 +55,23 @@ class CPUModelRunner(GPUModelRunner): raise ValueError("Multiple KVCacheGroups is not" "currently supported with CPU model runner.") - assert type(self.attn_groups[0] - [0].metadata_builder) is TorchSDPAMetadataBuilderV1 + # Guard against encoder-only / pooling models where `attn_groups` + # may be empty or lack the expected metadata_builder. + # Without this check, accessing `attn_groups[0][0]` would trigger + # an AssertionError on CPU backend. + if not hasattr(self, "attn_groups") or not self.attn_groups: + return + if not self.attn_groups[0]: + return - self.attn_groups[0][0].metadata_builder.reorder_batch( - self.input_batch, scheduler_output) + mb = getattr(self.attn_groups[0][0], "metadata_builder", None) + if not isinstance(mb, TorchSDPAMetadataBuilderV1): + # Encoder-only / rerank models do not benefit from reordering, + # so we safely skip here. 
+ return + + # Safe path for decoder/attention-heavy models + mb.reorder_batch(self.input_batch, scheduler_output) def _postprocess_tensors(self) -> None: # Note: replace device tensors with cpu tensors diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d8a9c36870f7e..93d69147978d6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -53,19 +53,23 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, cdiv, check_use_alibi, - get_dtype_size, is_pin_memory_available, round_up, - supports_dynamo) + GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, + is_pin_memory_available, round_up, supports_dynamo) +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +# yapf conflicts with isort for this block +# yapf: disable from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, + CrossAttentionSpec, EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, MambaSpec, SlidingWindowSpec) +# yapf: enable from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, LogprobsLists, LogprobsTensors, ModelRunnerOutput, SamplerOutput) @@ -209,6 +213,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( model_config) + if self.model_config.is_encoder_decoder: + # Maximum length of the encoder input, only for encoder-decoder + # models. + self.max_encoder_len = self.mm_registry.\ get_encdec_max_encoder_len(model_config) + else: + self.max_encoder_len = 0 + # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) @@ -256,7 +268,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.max_num_cached_reqs = 2 * self.max_num_reqs self.req_states = RequestState( max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, + # We need to use the encoder length for encoder-decoder + # models because of the KV cache for cross-attention. + max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, max_num_cached_reqs=self.max_num_cached_reqs, device=self.device, @@ -295,6 +309,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.hidden_size, dtype=self.dtype, numpy=False) + self.num_draft_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) + self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int64) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -311,6 +329,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.mrope_positions = self._make_buffer( (3, self.max_num_tokens + 1), dtype=torch.int64) + # CUDA event to synchronize use of reused CPU tensors between steps + # when async scheduling is enabled. + self.prepare_inputs_event: Optional[torch.cuda.Event] = None + if self.use_async_scheduling: + self.prepare_inputs_event = torch.cuda.Event() + # Start in a completed state.
+ self.prepare_inputs_event.record(torch.cuda.default_stream()) + # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -332,11 +358,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) - self.mm_budget = (MultiModalBudget( + self.mm_budget = MultiModalBudget( self.model_config, self.scheduler_config, self.mm_registry, - ) if self.supports_mm_inputs else None) + ) if self.supports_mm_inputs else None # Attention layers that are only in the KVCacheConfig of the runner # (e.g., KV sharing, encoder-only attention), but not in the @@ -511,7 +537,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): second_per_grid_ts = [] audio_feature_lengths = [] use_audio_in_video = False - for mm_item in req_data.mm_kwargs: + for mm_feature in req_state.mm_features: + mm_item = mm_feature.data + if mm_item is None: + continue mm_input = mm_item.get_data() if (t := mm_input.get("image_grid_thw")) is not None: image_grid_thw.append(t.tolist()) @@ -544,7 +573,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_kwargs = list[MultiModalKwargsItem]() for req in scheduler_output.scheduled_new_reqs: - mm_kwargs.extend(req.mm_kwargs) + for feature in req.mm_features: + if feature.data is not None: + mm_kwargs.append(feature.data) # Input all modalities at once mm_kwargs_combined: BatchedTensorInputs = {} @@ -644,6 +675,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. logits_indices = query_start_loc[1:] - 1 + num_draft_tokens = None spec_decode_metadata = None else: # Get the number of draft tokens for each request. @@ -653,6 +685,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): query_start_loc, ) logits_indices = spec_decode_metadata.logits_indices + self.num_draft_tokens.np[:num_reqs] = num_draft_tokens + self.num_draft_tokens.np[num_reqs:].fill(0) + self.num_draft_tokens.copy_to_gpu() logits_indices_padded = None if self.cache_config.kv_sharing_fast_prefill: @@ -666,12 +701,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): idx_mapping_np] num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np) spec_decode_common_attn_metadata = None + if use_spec_decode: + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() attn_metadata: dict[str, Any] = {} # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. 
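# A tiny sketch of the sharing scheme below: one metadata object is built
# per KV-cache group and every layer in that group maps to the same object.
# Group and layer names here are illustrative.
groups = {"full_attn": ["layers.0.attn", "layers.1.attn"]}
attn_metadata: dict[str, object] = {}
for group_name, layer_names in groups.items():
    metadata_i = object()  # stands in for builder.build(...)
    for layer_name in layer_names:
        attn_metadata[layer_name] = metadata_i
assert attn_metadata["layers.0.attn"] is attn_metadata["layers.1.attn"]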
for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): + encoder_seq_lens = self._get_encoder_seq_lens( + scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs) if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): @@ -710,6 +752,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logits_indices_padded=logits_indices_padded, num_logits_indices=logits_indices.size(0), causal=True, + encoder_seq_lens=encoder_seq_lens, ) if self.speculative_config and \ @@ -728,10 +771,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): builder, ) - attn_metadata_i = (builder.build( + extra_attn_metadata_args = {} + if use_spec_decode and isinstance(builder, + GDNAttentionMetadataBuilder): + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens. + gpu[:num_reqs], + num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs], + ) + + attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - )) + **extra_attn_metadata_args) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -949,10 +1001,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) return logits_indices_padded - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + def _batch_mm_kwargs_from_scheduler( + self, + scheduler_output: "SchedulerOutput", + ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: + """Batch multimodal kwargs from scheduled encoder inputs. + + Args: + scheduler_output: The scheduler output containing scheduled encoder + inputs. + + Returns: + A tuple of (mm_kwargs, mm_hashes_pos) where: + - mm_kwargs: List of multimodal kwargs items to be batched + - mm_hashes_pos: List of (mm_hash, position_info) tuples + """ scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: - return + return [], [] # Batch the multi-modal inputs. mm_kwargs = list[MultiModalKwargsItem]() # list of tuple (mm_hash, position_info) @@ -961,10 +1027,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: - mm_hash = req_state.mm_hashes[mm_input_id] - mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - mm_hashes_pos.append( - (mm_hash, req_state.mm_positions[mm_input_id])) + mm_feature = req_state.mm_features[mm_input_id] + mm_hash = mm_feature.identifier + mm_kwargs.append(mm_feature.data) + mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) + + return mm_kwargs, mm_hashes_pos + + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + # Batch the multi-modal inputs using the helper method.
+ mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( + scheduler_output) + + if not mm_kwargs: + return # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -1016,9 +1092,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] num_computed_tokens = \ req_state.num_computed_tokens + shift_computed_tokens - mm_positions = req_state.mm_positions - mm_hashes = req_state.mm_hashes - for i, pos_info in enumerate(mm_positions): + for mm_feature in req_state.mm_features: + pos_info = mm_feature.mm_position start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -1041,7 +1116,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) assert start_idx < end_idx - mm_hash = mm_hashes[i] + mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) assert encoder_output is not None,\ f"Encoder cache miss for {mm_hash}." @@ -1056,6 +1131,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_embeds.append(mm_embeds_item) return mm_embeds + def _extract_encoder_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> dict[str, torch.Tensor]: + """Extract encoder inputs for encoder-decoder models. + + This method extracts multimodal input features from scheduled encoder + inputs and formats them for the encoder-decoder model forward pass. + """ + # Batch the multi-modal inputs using the helper method. + mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output) + + if not mm_kwargs: + return {} + + # Group MM kwargs by modality and extract features + encoder_features = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ): + # Add the grouped features to encoder_features dict + # This allows the model to receive them as kwargs (e.g., + # input_features=...) + encoder_features.update(mm_kwargs_group) + + return encoder_features + def get_model(self) -> nn.Module: # get raw model out of the cudagraph wrapper. if isinstance(self.model, CUDAGraphWrapper): @@ -1320,7 +1424,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) num_input_tokens += num_pad - if self.supports_mm_inputs and get_pp_group().is_first_rank: + # _prepare_inputs may reorder the batch, so we must gather multi + # modal outputs after that to ensure the correct order + if (self.supports_mm_inputs and get_pp_group().is_first_rank + and not self.model_config.is_encoder_decoder): # Run the multimodal encoder if any. 
self._execute_mm_encoder(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output) @@ -1362,6 +1469,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) + if (self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs): + encoder_inputs = self._extract_encoder_inputs(scheduler_output) + model_kwargs.update(encoder_inputs) + return ( num_scheduled_tokens, num_input_tokens, @@ -1411,6 +1523,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sampling_metadata, ) sampler_output.sampled_token_ids = output_token_ids + self._update_states_after_model_execute(output_token_ids) return sampler_output @@ -1568,8 +1681,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): "prompt tokens, tokens, please disable it when the requests" " need prompt logprobs") - # Prepare the decoder inputs. - input_batch = self._prepare_inputs(scheduler_output) + if self.prepare_inputs_event is not None: + # Ensure prior step has finished with reused CPU tensors. + self.prepare_inputs_event.synchronize() + try: + # Prepare the decoder inputs. + (attn_metadata, logits_indices, spec_decode_metadata, + num_scheduled_tokens_np, spec_decode_common_attn_metadata, + max_query_len) = self._prepare_inputs(scheduler_output) + + finally: + if self.prepare_inputs_event is not None: + self.prepare_inputs_event.record() ( num_scheduled_tokens, @@ -1659,8 +1782,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, input_batch) - with record_function_or_nullcontext("Postprocess"): - assert isinstance(hidden_states, torch.Tensor) + with record_function_or_nullcontext("Bookkeep"): ( num_nans_in_logits, logprobs_lists, @@ -2168,6 +2290,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): uniform_decode: bool = False, skip_eplb: bool = False, is_profile: bool = False, + create_mixed_batch: bool = False, remove_lora: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -2186,6 +2309,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): uniform_decode: If True, the batch is a uniform decode batch. skip_eplb: If True, skip EPLB state update. is_profile: If True, this is a profile run. + create_mixed_batch: If True, create a mixed batch with both decode + (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run """ assert cudagraph_runtime_mode in { @@ -2197,7 +2322,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_tokens += num_pad # If cudagraph_mode.decode_mode() == FULL and - # cudagraph_mode.seperate_routine(). This means that we are using + # cudagraph_mode.separate_routine(). This means that we are using # different graphs and/or modes for mixed prefill-decode batches vs. # uniform decode batches. A uniform decode batch means that all # requests have identical query length, except a potential virtual @@ -2217,13 +2342,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # has num_tokens in total. 
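# A quick worked example of the mixed dummy batch constructed below,
# assuming num_tokens == 9: half the tokens become single-token decode
# requests and the remainder forms one prefill request.
_num_tokens = 9
_num_decode_tokens = _num_tokens // 2
_num_prefill_tokens = _num_tokens - _num_decode_tokens
_num_scheduled_tokens_list = [1] * _num_decode_tokens + [_num_prefill_tokens]
assert _num_scheduled_tokens_list == [1, 1, 1, 1, 5]
assert sum(_num_scheduled_tokens_list) == _num_tokens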
assert num_tokens <= self.scheduler_config.max_num_batched_tokens max_num_reqs = self.scheduler_config.max_num_seqs - if uniform_decode: - num_reqs = cdiv(num_tokens, max_query_len) + if create_mixed_batch: + assert not uniform_decode + # Create mixed batch: + # first half decode tokens, second half one prefill + num_decode_tokens = num_tokens // 2 + num_prefill_tokens = num_tokens - num_decode_tokens + num_reqs = num_decode_tokens + 1 + + # Create decode requests (1 token each) followed by prefill request + num_scheduled_tokens_list = [1] * num_decode_tokens + [ + num_prefill_tokens + ] + # Note: Overriding max_query_len to be the prefill tokens + max_query_len = num_prefill_tokens + elif uniform_decode: + num_reqs = num_tokens // max_query_len assert num_reqs <= max_num_reqs, \ "Do not capture num_reqs > max_num_reqs for uniform batch" num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: - num_scheduled_tokens_list[-1] = num_tokens % max_query_len + num_scheduled_tokens_list[-1] += num_tokens % max_query_len else: num_reqs = min(num_tokens, max_num_reqs) min_tokens_per_req = num_tokens // num_reqs @@ -2242,8 +2381,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: attn_metadata = {} - # Make sure max_model_len is used at the graph capture time. - self.seq_lens.np[:num_reqs] = self.max_model_len + if create_mixed_batch: + # In the mixed batch mode (used for FI warmup), we use + # shorter sequence lengths to run faster. + # TODO(luka) better system for describing dummy batches + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] + else: + # Make sure max_model_len is used at the graph capture time. + seq_lens = self.max_model_len + self.seq_lens.np[:num_reqs] = seq_lens self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() @@ -2275,17 +2421,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens, remove_lora): - if self.supports_mm_inputs: + model_kwargs = self._init_model_kwargs(num_tokens) + if (self.supports_mm_inputs + and not self.model_config.is_encoder_decoder): input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens] model_kwargs = { - **self._init_model_kwargs(num_tokens), + **model_kwargs, **self._dummy_mm_kwargs(num_reqs), } else: input_ids = self.input_ids.gpu[:num_tokens] inputs_embeds = None - model_kwargs = self._init_model_kwargs(num_tokens) if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_tokens] @@ -2507,7 +2654,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_budget = self.mm_budget assert mm_budget is not None - # TODO: handle encoder-decoder models once we support them. if (encoder_budget := mm_budget.get_encoder_budget()) > 0: # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when @@ -2560,12 +2706,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.encoder_cache.clear() gc.collect() - def capture_model(self) -> None: + def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. 
To turn on CUDA graph capture, " "ensure `cudagraph_mode` was not manually set to `NONE`") - return + return 0 else: self.initialize_cudagraph_capture() @@ -2588,6 +2734,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): finally: if should_freeze: gc.unfreeze() + gc.collect() # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes @@ -2635,6 +2782,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # This usually takes 5~20 seconds. logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) + return cuda_graph_size def _capture_cudagraphs(self, compilation_cases: list[int], cudagraph_runtime_mode: CUDAGraphMode, @@ -3082,7 +3230,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: - attn_spec = EncoderOnlyAttentionSpec( + attn_spec: AttentionSpec = EncoderOnlyAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, @@ -3124,7 +3272,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.shared_kv_cache_layers[layer_name] = kv_tgt_layer continue - # TODO: Support other attention modules, e.g., cross-attention # TODO(lucas): move the attention specs into the model layers like # the attention backends if attn_module.attn_type == AttentionType.DECODER: @@ -3152,19 +3299,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): head_size=attn_module.head_size, dtype=self.kv_cache_dtype, use_mla=use_mla) + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + kv_cache_spec[layer_name] = CrossAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError else: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) if len(mamba_layers) > 0: - if self.vllm_config.speculative_config is not None: + if (self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"]): raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") if self.vllm_config.cache_config.enable_prefix_caching: @@ -3183,7 +3337,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dtypes=mamba_module.get_state_dtype(), block_size=max_model_len, page_size_padded=page_size_padded, - mamba_type=mamba_module.mamba_type) + mamba_type=mamba_module.mamba_type, + num_speculative_blocks=( + self.speculative_config.num_speculative_tokens + if self.speculative_config else 0), + ) return kv_cache_spec diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6a3bc5d46df27..37dd431fd68f8 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -231,18 +231,40 @@ class Worker(WorkerBase): You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. 
""" + GiB = lambda b: b / GiB_bytes + if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes: + # still need a profile run which compiles the model for + # max_num_batched_tokens + self.model_runner.profile_run() + + msg = ( + f"Initial free memory {GiB(self.init_snapshot.free_memory)} " + f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for " + "KV Cache as specified by kv_cache_memory_bytes config and " + "skipped memory profiling. This does does not respect the " + "gpu_memory_utilization config. Only use kv_cache_memory_bytes " + "config when you want manual control of KV cache memory " + "size. If OOM'ed, check the difference of initial free " + "memory between the current run and the previous run " + "where kv_cache_memory_bytes is suggested and update it " + "correspondingly.") + logger.info(msg) + return kv_cache_memory_bytes + torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() - GiB = lambda b: b / GiB_bytes # Execute a forward pass with dummy inputs to profile the memory usage # of the model. with memory_profiling( self.init_snapshot, - weights_memory=int( - self.model_runner.model_memory_usage)) as profile_result: + weights_memory=int(self.model_runner.model_memory_usage), + ) as profile_result: self.model_runner.profile_run() + self.non_torch_memory = profile_result.non_torch_increase + self.peak_activation_memory = profile_result.torch_peak_increase + free_gpu_memory = profile_result.after_profile.free_memory # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. @@ -254,7 +276,7 @@ class Worker(WorkerBase): "release GPU memory while vLLM is profiling during initialization. " "To fix this, ensure consistent GPU memory allocation or " "isolate vLLM in its own container.") - available_kv_cache_memory = self.requested_memory \ + self.available_kv_cache_memory_bytes = self.requested_memory \ - profile_result.non_kv_cache_memory unrequested_memory = self.init_snapshot.free_memory \ @@ -274,10 +296,10 @@ class Worker(WorkerBase): ) logger.debug(profile_result) logger.info("Available KV cache memory: %.2f GiB", - GiB(available_kv_cache_memory)) + GiB(self.available_kv_cache_memory_bytes)) gc.collect() - return int(available_kv_cache_memory) + return int(self.available_kv_cache_memory_bytes) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() @@ -317,8 +339,56 @@ class Worker(WorkerBase): # cuda graph capture. kernel_warmup(self) + cuda_graph_memory_bytes = 0 if not self.model_config.enforce_eager: - self.model_runner.capture_model() + cuda_graph_memory_bytes = self.model_runner.capture_model() + + if (self.cache_config.kv_cache_memory_bytes is None + and hasattr(self, "peak_activation_memory")): + # Suggests optimal kv cache memory size if we rely on + # memory_profiling to guess the kv cache memory size which + # provides peak_activation_memory and a few other memory + # consumption. `memory_profiling` does not consider + # CUDAGraph memory size and may not utilize all gpu memory. + # Users may want fine-grained control to specify kv cache + # memory size. + GiB = lambda b: round(b / GiB_bytes, 2) + + # empirically observed that the memory profiling may + # slightly underestimate the memory consumption. + # So leave a small buffer (=150MiB) to avoid OOM. 
+ redundancy_buffer_memory = 150 * (1 << 20) + non_kv_cache_memory = (self.model_runner.model_memory_usage + + self.peak_activation_memory + + self.non_torch_memory + + cuda_graph_memory_bytes) + kv_cache_memory_bytes_to_gpu_limit = ( + self.init_snapshot.free_memory - non_kv_cache_memory - + redundancy_buffer_memory) + kv_cache_memory_bytes_to_requested_limit = ( + int(self.requested_memory) - non_kv_cache_memory - + redundancy_buffer_memory) + + msg = ( + f"Free memory on device " + f"({GiB(self.init_snapshot.free_memory)}/" + f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. " + f"Desired GPU memory utilization is " + f"({self.cache_config.gpu_memory_utilization}, " + f"{GiB(self.requested_memory)} GiB). " + f"Actual usage is {GiB(self.model_runner.model_memory_usage)} " + f"GiB for weights, {GiB(self.peak_activation_memory)} GiB " + f"for peak activation, {GiB(self.non_torch_memory)} GiB " + f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} " + f"GiB for CUDAGraph memory. Replace the gpu_memory_utilization " + f"config with `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_requested_limit}` to fit into " + f"requested memory, or `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_gpu_limit}` to fully " + f"utilize GPU memory. Current kv cache memory in use is " + f"{int(self.available_kv_cache_memory_bytes)} bytes.") + + logger.info(msg) # Warm up sampler and preallocate memory buffer for logits and other # sampling related tensors of max possible shape to avoid memory @@ -601,6 +671,9 @@ class Worker(WorkerBase): self.model_runner.save_tensorized_model( tensorizer_config=tensorizer_config, ) + def shutdown(self) -> None: + self.model_runner.ensure_kv_transfer_shutdown() + def init_worker_distributed_environment( vllm_config: VllmConfig, diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index e2ffa2f12fda5..3eb9f26e9f5b6 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -9,7 +9,8 @@ from typing import Generator # noqa: UP035 from typing import TYPE_CHECKING, Optional from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (get_kv_transfer_group, +from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown, + get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.forward_context import get_forward_context, set_forward_context @@ -42,6 +43,12 @@ class KVConnectorModelRunnerMixin: # Do this here to save a collective_rpc. kv_connector.start_load_kv(get_forward_context()) + @staticmethod + def ensure_kv_transfer_shutdown() -> None: + # has_kv_transfer_group can be None during interpreter shutdown.
+ if has_kv_transfer_group and has_kv_transfer_group(): + ensure_kv_transfer_shutdown() + @staticmethod def maybe_wait_for_kv_save() -> None: if has_kv_transfer_group(): diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 4b5f27d27541b..01d5f0525c4e2 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -11,7 +11,8 @@ import numpy as np import torch import torch.nn as nn -from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig +from vllm.config import ModelConfig, SchedulerConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -62,8 +63,7 @@ class LoRAModelRunnerMixin: def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...], token_lora_mapping: tuple[int, ...], lora_requests: set[LoRARequest]) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() # Set is_prefill to True, so we always use the SGMV kernels on # non-cuda platforms. @@ -74,6 +74,11 @@ class LoRAModelRunnerMixin: is_prefill=True) self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + def _ensure_lora_enabled(self) -> None: + if not hasattr(self, "lora_manager"): + raise RuntimeError( + "LoRA is not enabled. Use --enable-lora to enable LoRA.") + def set_active_loras(self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray) -> None: @@ -171,21 +176,17 @@ class LoRAModelRunnerMixin: self.lora_manager.remove_all_adapters() def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.add_adapter(lora_request) def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.remove_adapter(lora_id) def pin_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.pin_adapter(lora_id) def list_loras(self) -> set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.list_adapters() diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 489edf772e1c9..01d1904778278 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -391,7 +391,7 @@ class InputBatch: # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] - # instead, we need to temporiarily copy the data for one of the indices + # instead, we need to temporarily copy the data for one of the indices # TODO(lucas): optimize this by only copying valid indices tmp = self.token_ids_cpu[i1, ...].copy() self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] 
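The corrected comment in tpu_input_batch.py flags a genuine NumPy pitfall: tuple-swapping two rows of the same array assigns through views, so the second write reads data the first one already clobbered. Below is a minimal standalone sketch of the hazard and the copy-based fix the code uses; the array name and shape here are illustrative placeholders, not vLLM's actual buffer.

```python
import numpy as np

# Illustrative stand-in for the runner's token_ids_cpu buffer:
# one row of token ids per request slot.
token_ids_cpu = np.arange(12).reshape(3, 4)
i1, i2 = 0, 2

# Unsafe: the right-hand side produces *views* into the same buffer,
# so the second assignment reads a row the first one already overwrote,
# leaving both rows equal instead of swapped:
#   token_ids_cpu[i1, ...], token_ids_cpu[i2, ...] = \
#       token_ids_cpu[i2, ...], token_ids_cpu[i1, ...]

# Safe: materialize one row with .copy() before overwriting it.
tmp = token_ids_cpu[i1, ...].copy()
token_ids_cpu[i1, ...] = token_ids_cpu[i2, ...]
token_ids_cpu[i2, ...] = tmp

# Rows 0 and 2 have genuinely traded places.
assert token_ids_cpu[0, 0] == 8 and token_ids_cpu[2, 0] == 0
```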
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 5947b54d33ce0..43f12912707f1 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -387,9 +387,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - mm_kwargs=new_req_data.mm_kwargs, - mm_positions=new_req_data.mm_positions, - mm_hashes=new_req_data.mm_hashes, + mm_features=new_req_data.mm_features, sampling_params=sampling_params, pooling_params=None, generator=None, @@ -822,10 +820,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: - mm_hash = req_state.mm_hashes[mm_input_id] - mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - mm_hashes_pos.append( - (mm_hash, req_state.mm_positions[mm_input_id])) + mm_feature = req_state.mm_features[mm_input_id] + mm_hash = mm_feature.identifier + mm_kwargs.append(mm_feature.data) + mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -883,13 +881,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_id] req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens - mm_positions = req_state.mm_positions - mm_hashes = req_state.mm_hashes # TODO unroll loop and assume/enforce --disable_chunked_mm_input # NOTE (NickLucche) here we diverge from logic in other runners, as # we assume to only have whole mm items to process. Hence we avoid # the intrinsic dynamism that `gather_mm_placeholders` introduces. - for i, pos_info in enumerate(mm_positions): + for mm_feature in req_state.mm_features: + pos_info = mm_feature.mm_position start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -904,8 +901,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # The encoder output is already processed and stored # in the decoder's KV cache. continue - - mm_hash = mm_hashes[i] + mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) assert encoder_output is not None,\ f"Encoder cache miss for {mm_hash}." @@ -1769,28 +1765,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.grammar_bitmask_cpu.zero_() self.require_structured_out_cpu.zero_() - # We receive the structured output bitmask from the scheduler, but the - # indices of the requests in the batch may not match the indices of - # the bitmask since the scheduler doesn't know how the tpu runner is - # ordering the requests in the batch. 
We need to match the order of - # bitmask with the order of requests - struct_out_indices: list[int] = [] - mask_indices: list[int] = [] - for req_id in self.input_batch.req_ids: - mask_index = scheduler_output.structured_output_request_ids.get( - req_id) - if mask_index is None: + sorted_struct_requests = sorted( + scheduler_output.structured_output_request_ids.items(), + key=lambda item: item[1]) + cumulative_mask_idx = 0 + for req_id, _ in sorted_struct_requests: + if req_id not in self.input_batch.req_id_to_index: continue batch_index = self.input_batch.req_id_to_index[req_id] - struct_out_indices.append(batch_index) - mask_indices.append(mask_index) - self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy( - grammar_bitmask[mask_indices]) - # It's not guaranteed that all requests in this batch require - # structured output, so create a bool tensor to represent - # the requests that need structured output. - struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long) - self.require_structured_out_cpu[struct_out_indices] = True + self.grammar_bitmask_cpu[batch_index] = torch.from_numpy( + grammar_bitmask[cumulative_mask_idx]) + # It's not guaranteed that all requests in this batch require + # structured output, so create a bool tensor to represent + # the requests that need structured output. + self.require_structured_out_cpu[batch_index] = True + cumulative_mask_idx += 1 + return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ self.structured_decode_arange.to(logits.device) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 3f4e3ecbd4e26..fc72b954df9cf 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -330,6 +330,9 @@ class TPUWorker: ensure_kv_transfer_initialized(vllm_config) + def shutdown(self) -> None: + self.model_runner.ensure_kv_transfer_shutdown() + if USE_TPU_COMMONS: from tpu_commons.worker import TPUWorker as TPUCommonsWorker diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 6767804c71b9f..be05d02ff29fe 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -12,6 +12,7 @@ from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry +from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec @@ -269,7 +270,17 @@ def bind_kv_cache( # One typical case is encoder-decoder model, e.g., bart. # The cross attention and self attention in the same decoder layer # has different layer_name but the same layer_index. - raise NotImplementedError + + # TODO - analyze where runner_kv_caches is used and the right + # way to ensure it properly reflects multiple attention layers + # in the same decoder block. + if current_platform.is_cuda(): + # We know that the GPU runner is not impacted by this + # case. Some test code depends on runner_kv_caches, but + # not in a way that's impacted by ignoring this. 
+ pass + else: + raise NotImplementedError layer_name = layer_names[0] runner_kv_caches.append(kv_caches[layer_name]) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f05401fd01327..88f83c9dd7e6c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1337,8 +1337,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): return self.lora_manager.list_adapters() @torch.inference_mode() - def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: - """Cuda graph capture a model. + def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> int: + """Cuda graph capture a model and return the CUDA graph memory + consumption in bytes. Note that CUDA graph's performance gain is negligible if number of batched tokens are larger than 200. And since CUDA graph @@ -1505,6 +1506,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): # This usually takes < 10 seconds. logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / GiB_bytes) + return cuda_graph_size def _update_inputs_to_capture_for_enc_dec_model(self, capture_inputs: Dict[str, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b4a67e2899d0d..670f256c0bf65 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -78,7 +78,8 @@ class Worker(LocalOrDistributedWorkerBase): "deepseek_mtp", "glm4_moe_mtp", "mimo_mtp", - "ernie_mtp")) \ + "ernie_mtp", + "qwen3_next_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner @@ -228,6 +229,67 @@ class Worker(LocalOrDistributedWorkerBase): self.model_runner.save_tensorized_model( tensorizer_config=tensorizer_config, ) + @torch.inference_mode() + def determine_available_kv_cache_memory(self, + total_gpu_memory: int) -> float: + if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes: + # We still need a profile run, which compiles the model for + # max_num_batched_tokens. + self.model_runner.profile_run() + + GiB = lambda b: b / GiB_bytes + msg = ( + f"Initial free memory " + f"{GiB(self.baseline_snapshot.free_memory):.2f} " + f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for " + "KV Cache as specified by kv_cache_memory_bytes config and " + "skipped memory profiling. This does not respect the " + "gpu_memory_utilization config. Only use kv_cache_memory_bytes " + "config when you want manual control of KV cache memory " + "size. If OOM'ed, check the difference in initial free " + "memory between the current run and the previous run " + "where kv_cache_memory_bytes was suggested, and update it " + "accordingly.") + logger.info(msg) + return self.cache_config.kv_cache_memory_bytes + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model.
+ with memory_profiling( + self.baseline_snapshot, + weights_memory=self.model_runner.model_memory_usage) as result: + self.model_runner.profile_run() + + self.non_torch_memory = result.non_torch_increase + self.peak_activation_memory = result.torch_peak_increase + + self._assert_memory_footprint_increased_during_profiling() + + self.requested_memory = total_gpu_memory * \ + self.cache_config.gpu_memory_utilization + + self.available_kv_cache_memory = (self.requested_memory - + result.non_kv_cache_memory) + + msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n" + "the current vLLM instance can use " + "total_gpu_memory " + f"({(total_gpu_memory / GiB_bytes):.2f}GiB)" + " x gpu_memory_utilization " + f"({self.cache_config.gpu_memory_utilization:.2f})" + f" = {(self.requested_memory / GiB_bytes):.2f}GiB\n" + "model weights take " + f"{(result.weights_memory / GiB_bytes):.2f}GiB;" + " non_torch_memory takes " + f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;" + " PyTorch activation peak memory takes " + f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;" + " the rest of the memory reserved for KV Cache is " + f"{(self.available_kv_cache_memory / GiB_bytes):.2f}GiB.") + + logger.info(msg) + return self.available_kv_cache_memory + @torch.inference_mode() def determine_num_available_blocks(self) -> Tuple[int, int]: """Profiles the peak memory usage of the model to determine how many @@ -247,20 +309,8 @@ class Worker(LocalOrDistributedWorkerBase): torch.cuda.reset_peak_memory_stats() free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() - - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. - with memory_profiling( - self.baseline_snapshot, - weights_memory=self.model_runner.model_memory_usage) as result: - self.model_runner.profile_run() - - self._assert_memory_footprint_increased_during_profiling() - - memory_for_current_instance = total_gpu_memory * \ - self.cache_config.gpu_memory_utilization - available_kv_cache_memory = (memory_for_current_instance - - result.non_kv_cache_memory) + available_kv_cache_memory = self.determine_available_kv_cache_memory( + total_gpu_memory) # Calculate the number of blocks that can be allocated with the # profiled peak memory. 
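The accounting that determine_available_kv_cache_memory now centralizes reduces to simple arithmetic: the KV cache budget is the requested fraction of total GPU memory minus everything the profile run attributes to non-KV-cache use (weights, peak activations, non-torch allocations). A minimal sketch with hypothetical numbers follows; the function name and figures are illustrative assumptions, not vLLM's API.

```python
GiB = 1 << 30

def available_kv_cache_bytes(total_gpu_memory: int,
                             gpu_memory_utilization: float,
                             weights: int,
                             peak_activation: int,
                             non_torch: int) -> int:
    # Budget = requested fraction of the device minus every
    # non-KV-cache consumer observed during the profile run.
    requested = int(total_gpu_memory * gpu_memory_utilization)
    non_kv_cache = weights + peak_activation + non_torch
    return max(requested - non_kv_cache, 0)

# Hypothetical 80 GiB device at 0.9 utilization.
budget = available_kv_cache_bytes(80 * GiB, 0.9, 14 * GiB, 2 * GiB, 1 * GiB)
print(f"KV cache budget: {budget / GiB:.2f} GiB")  # 55.00 GiB

# The block count then follows from the per-block footprint, e.g. a
# hypothetical 2 MiB per block yields budget // (2 << 20) blocks.
num_gpu_blocks = budget // (2 << 20)
```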
@@ -275,23 +325,6 @@ class Worker(LocalOrDistributedWorkerBase): num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) - msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n" - "the current vLLM instance can use " - "total_gpu_memory " - f"({(total_gpu_memory / GiB_bytes):.2f}GiB)" - " x gpu_memory_utilization " - f"({self.cache_config.gpu_memory_utilization:.2f})" - f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n" - "model weights take " - f"{(result.weights_memory / GiB_bytes):.2f}GiB;" - " non_torch_memory takes " - f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;" - " PyTorch activation peak memory takes " - f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;" - " the rest of the memory reserved for KV Cache is " - f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.") - - logger.info(msg) # Final cleanup gc.collect() @@ -381,8 +414,58 @@ class Worker(LocalOrDistributedWorkerBase): for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size) + + cuda_graph_memory_bytes = 0 if not self.model_config.enforce_eager: - self.model_runner.capture_model(self.gpu_cache) + cuda_graph_memory_bytes = self.model_runner.capture_model( + self.gpu_cache) + + if (self.cache_config.kv_cache_memory_bytes is None + and hasattr(self, "peak_activation_memory")): + # Suggest an optimal kv cache memory size when we rely on + # memory_profiling to estimate the kv cache memory size, which + # provides peak_activation_memory and a few other memory + # measurements. `memory_profiling` does not consider + # CUDAGraph memory size and may not utilize all GPU memory. + # Users may want fine-grained control to specify kv cache + # memory size. + GiB = lambda b: round(b / GiB_bytes, 2) + non_kv_cache_memory = (self.model_runner.model_memory_usage + + self.peak_activation_memory + + self.non_torch_memory + + cuda_graph_memory_bytes) + + # Empirically, the memory profiling may slightly + # underestimate the memory consumption, so leave a + # small buffer (=150MiB) to avoid OOM. + redundancy_buffer_memory = 150 * (1 << 20) + kv_cache_memory_bytes_to_gpu_limit = ( + self.baseline_snapshot.free_memory - non_kv_cache_memory - + redundancy_buffer_memory) + kv_cache_memory_bytes_to_requested_limit = ( + int(self.requested_memory) - non_kv_cache_memory - + redundancy_buffer_memory) + + msg = ( + f"Free memory on device " + f"({GiB(self.baseline_snapshot.free_memory)}/" + f"{GiB(self.baseline_snapshot.total_memory)} GiB) on startup. " + f"Desired GPU memory utilization is " + f"({self.cache_config.gpu_memory_utilization}, " + f"{GiB(self.requested_memory)} GiB). " + f"Actual usage is {GiB(self.model_runner.model_memory_usage)} " + f"GiB for weights, {GiB(self.peak_activation_memory)} GiB " + f"for peak activation, {GiB(self.non_torch_memory)} GiB " + f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} " + f"GiB for CUDAGraph memory. Replace the gpu_memory_utilization " + f"config with `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_requested_limit}` to fit into " + f"requested memory, or `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_gpu_limit}` to fully " + f"utilize GPU memory. Current kv cache memory in use is " + f"{int(self.available_kv_cache_memory)} bytes.") + logger.info(msg) + # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling.
set_random_seed(self.model_config.seed) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index a1fa7f2cf7a2e..aa76d21f0fcaa 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -129,6 +129,10 @@ class WorkerBase: """Get vocabulary size from model configuration.""" return self.model_config.get_vocab_size() + def shutdown(self) -> None: + """Clean up resources held by the worker.""" + return + class DelegateWorkerBase(WorkerBase): """ @@ -519,6 +523,10 @@ class WorkerWrapperBase: from vllm.utils import init_cached_hf_modules init_cached_hf_modules() + def shutdown(self) -> None: + if self.worker is not None: + self.worker.shutdown() + def adjust_rank(self, rank_mapping: Dict[int, int]) -> None: """ Adjust the rpc_rank based on the given mapping.
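The shutdown plumbing added across worker_base.py, gpu_worker.py, and tpu_worker.py follows a no-op-hook-plus-guarded-delegation pattern: the base class supplies a safe default, concrete workers override it, and the wrapper forwards only once a worker actually exists. A self-contained sketch of the pattern; the class names are simplified stand-ins for WorkerBase and WorkerWrapperBase, not the real classes.

```python
from typing import Optional


class WorkerBase:
    def shutdown(self) -> None:
        """Clean up resources held by the worker; a no-op by default."""
        return


class GPUWorker(WorkerBase):
    def shutdown(self) -> None:
        # Concrete workers override the hook, e.g. to tear down the
        # KV-transfer group as gpu_worker.py and tpu_worker.py now do.
        print("releasing KV-transfer resources")


class WorkerWrapper:
    """Stand-in for WorkerWrapperBase: the wrapped worker may not be
    initialized yet, so guard before delegating."""

    def __init__(self) -> None:
        self.worker: Optional[WorkerBase] = None

    def shutdown(self) -> None:
        if self.worker is not None:
            self.worker.shutdown()


wrapper = WorkerWrapper()
wrapper.shutdown()            # safe: nothing to shut down yet
wrapper.worker = GPUWorker()
wrapper.shutdown()            # delegates to the concrete worker
```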