diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml deleted file mode 100644 index 1bce7e7fdf146..0000000000000 --- a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# For vllm script, with -t option (tensor parallel size). -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise -b "auto" -l 1000 -f 5 -t 1 -model_name: "nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.595 - - name: "exact_match,flexible-extract" - value: 0.582 -limit: 1000 -num_fewshot: 5 diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 33b7114666fa2..12f730738b8a5 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -116,24 +116,6 @@ steps: commands: - "bash .buildkite/scripts/annotate-release.sh" - - label: "Build and publish TPU release image" - depends_on: ~ - if: build.env("NIGHTLY") == "1" - agents: - queue: tpu_queue_postmerge - commands: - - "yes | docker system prune -a" - - "git fetch --all" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f docker/Dockerfile.tpu ." - - "docker push vllm/vllm-tpu:nightly" - - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT" - plugins: - - docker-login#v3.0.0: - username: vllmbot - password-env: DOCKERHUB_TOKEN - env: - DOCKER_BUILDKIT: "1" - - input: "Provide Release version here" id: input-release-version fields: diff --git a/.buildkite/scripts/annotate-release.sh b/.buildkite/scripts/annotate-release.sh index fde48603ad3cd..56bb5cedaa0a9 100755 --- a/.buildkite/scripts/annotate-release.sh +++ b/.buildkite/scripts/annotate-release.sh @@ -2,16 +2,23 @@ set -ex -# Get release version and strip leading 'v' if present -RELEASE_VERSION=$(buildkite-agent meta-data get release-version | sed 's/^v//') - -if [ -z "$RELEASE_VERSION" ]; then - echo "Error: RELEASE_VERSION is empty. 'release-version' metadata might not be set or is invalid." - exit 1 +# Get release version, default to 1.0.0.dev for nightly/per-commit builds +RELEASE_VERSION=$(buildkite-agent meta-data get release-version 2>/dev/null | sed 's/^v//') +if [ -z "${RELEASE_VERSION}" ]; then + RELEASE_VERSION="1.0.0.dev" fi buildkite-agent annotate --style 'info' --context 'release-workflow' << EOF -To download the wheel: +To download the wheel (by commit): +\`\`\` +aws s3 cp s3://vllm-wheels/${BUILDKITE_COMMIT}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux1_x86_64.whl . +aws s3 cp s3://vllm-wheels/${BUILDKITE_COMMIT}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux2014_aarch64.whl . + +aws s3 cp s3://vllm-wheels/${BUILDKITE_COMMIT}/vllm-${RELEASE_VERSION}+cu129-cp38-abi3-manylinux1_x86_64.whl . +aws s3 cp s3://vllm-wheels/${BUILDKITE_COMMIT}/vllm-${RELEASE_VERSION}+cu129-cp38-abi3-manylinux1_x86_64.whl . +\`\`\` + +To download the wheel (by version): \`\`\` aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux1_x86_64.whl . aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux2014_aarch64.whl . diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index aa4cc7b35a543..58fd435691f4a 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -173,6 +173,14 @@ fi PARALLEL_JOB_COUNT=8 MYPYTHONPATH=".." +# Test that we're launching on the machine that has +# proper access to GPUs +render_gid=$(getent group render | cut -d: -f3) +if [[ -z "$render_gid" ]]; then + echo "Error: 'render' group not found. This is required for GPU access." >&2 + exit 1 +fi + # check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. if [[ $commands == *"--shard-id="* ]]; then # assign job count as the number of shards used @@ -186,6 +194,7 @@ if [[ $commands == *"--shard-id="* ]]; then --device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \ --network=host \ --shm-size=16gb \ + --group-add "$render_gid" \ --rm \ -e HIP_VISIBLE_DEVICES="${GPU}" \ -e HF_TOKEN \ @@ -217,8 +226,8 @@ else --device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \ --network=host \ --shm-size=16gb \ + --group-add "$render_gid" \ --rm \ - -e HIP_VISIBLE_DEVICES=0 \ -e HF_TOKEN \ -e AWS_ACCESS_KEY_ID \ -e AWS_SECRET_ACCESS_KEY \ diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index c023457fb03e4..bb5ef5d624630 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -48,8 +48,8 @@ steps: commands: - bash standalone_tests/pytorch_nightly_dependency.sh -- label: Async Engine, Inputs, Utils, Worker Test # 36min - timeout_in_minutes: 50 +- label: Async Engine, Inputs, Utils, Worker Test # 10min + timeout_in_minutes: 15 mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking @@ -344,7 +344,7 @@ steps: - pytest -v -s v1/logits_processors - pytest -v -s v1/worker - pytest -v -s v1/spec_decode - - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit + - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_lmcache_integration.py - pytest -v -s -m 'not cpu_test' v1/metrics - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_request.py @@ -616,9 +616,9 @@ steps: - uv pip install --system torchao==0.13.0 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py -- label: LM Eval Small Models # 53min - timeout_in_minutes: 75 - mirror_hardwares: [amdexperimental] +- label: LM Eval Small Models # 15min + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking source_file_dependencies: @@ -627,17 +627,18 @@ steps: commands: - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 -- label: OpenAI API correctness # 22min - timeout_in_minutes: 30 - mirror_hardwares: [amdexperimental] +- label: OpenAI API correctness # 10min + timeout_in_minutes: 15 + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking source_file_dependencies: - csrc/ - vllm/entrypoints/openai/ - vllm/model_executor/models/whisper.py - commands: # LMEval+Transcription WER check - - pytest -s entrypoints/openai/correctness/ + commands: # LMEval + # Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442 + - pytest -s entrypoints/openai/correctness/ --ignore entrypoints/openai/correctness/test_transcription_api_correctness.py - label: OpenAI-Compatible Tool Use # 23 min timeout_in_minutes: 35 @@ -858,10 +859,10 @@ steps: - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work -- label: Multi-Modal Accuracy Eval (Small Models) # 50min - mirror_hardwares: [amdexperimental] +- label: Multi-Modal Accuracy Eval (Small Models) # 10min + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 - timeout_in_minutes: 70 + timeout_in_minutes: 15 working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: - vllm/multimodal/ diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a020b0d276be0..a0d2076199b14 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -232,8 +232,8 @@ steps: commands: - pytest -v -s distributed/test_eplb_algo.py -- label: EPLB Execution Test # 5min - timeout_in_minutes: 15 +- label: EPLB Execution Test # 10min + timeout_in_minutes: 20 working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -241,6 +241,7 @@ steps: - tests/distributed/test_eplb_execute.py commands: - pytest -v -s distributed/test_eplb_execute.py + - pytest -v -s distributed/test_eplb_spec_decode.py - label: Metrics, Tracing Test # 12min timeout_in_minutes: 20 @@ -315,6 +316,7 @@ steps: - vllm/ - tests/v1 commands: + - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt # split the test to avoid interference - pytest -v -s -m 'not cpu_test' v1/core - pytest -v -s v1/executor @@ -347,8 +349,7 @@ steps: - vllm/v1/attention - tests/v1/attention commands: - - export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this - - pytest -v -s v1/attention + - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this - label: V1 Test others (CPU) # 5 mins source_file_dependencies: @@ -459,18 +460,21 @@ steps: - tests/compile commands: - pytest -v -s compile/test_basic_correctness.py + - pytest -v -s compile/test_multimodal_compile.py - pytest -v -s compile/piecewise/ -- label: PyTorch Fullgraph Test # 22min - timeout_in_minutes: 35 +- label: PyTorch Fullgraph Test # 27min + timeout_in_minutes: 40 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ - tests/compile commands: - - pytest -v -s compile/test_full_graph.py - - pytest -v -s compile/test_fusions_e2e.py + - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile' + # Limit to no custom ops to reduce running time + # Wrap with quotes to escape yaml and avoid starting -k string with a - + - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'" - label: Cudagraph test timeout_in_minutes: 20 @@ -544,8 +548,11 @@ steps: - label: Model Executor Test # 23min timeout_in_minutes: 35 + torch_nightly: true mirror_hardwares: [amdexperimental] source_file_dependencies: + - vllm/engine/arg_utils.py + - vllm/config/model.py - vllm/model_executor - tests/model_executor - tests/entrypoints/openai/test_tensorizer_entrypoint.py @@ -924,7 +931,33 @@ steps: - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py # this runner has 2 GPUs available even though num_gpus=2 is not set - pytest -v -s tests/compile/test_fusion_all_reduce.py + # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time + # Wrap with quotes to escape yaml + - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'" + +- label: Blackwell Fusion E2E Tests # 30 min + timeout_in_minutes: 40 + working_dir: "/vllm-workspace/" + gpu: b200 + optional: true + num_gpus: 2 + source_file_dependencies: + - csrc/quantization/fp4/ + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + - tests/compile/test_fusions_e2e.py + - tests/compile/test_full_graph.py + commands: + - nvidia-smi + # Run all e2e fusion tests - pytest -v -s tests/compile/test_fusions_e2e.py + # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) + - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 @@ -1223,6 +1256,7 @@ steps: - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - pytest -v -s tests/distributed/test_context_parallel.py - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + - pytest -v -s tests/v1/distributed/test_dbo.py ##### B200 test ##### - label: Distributed Tests (B200) # optional @@ -1233,6 +1267,7 @@ steps: commands: - pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + - pytest -v -s tests/v1/distributed/test_dbo.py ##### RL Integration Tests ##### - label: Prime-RL Integration Test # 15min diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index ba08a43352154..23def076cf880 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,7 +9,7 @@ /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety /vllm/model_executor/layers/mamba @tdoublep /vllm/model_executor/model_loader @22quinn -/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche +/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa /vllm/vllm_flash_attn @LucasWilkinson /vllm/lora @jeejeelee /vllm/reasoning @aarnphm @chaunceyjiang @@ -105,11 +105,21 @@ mkdocs.yaml @hmellor /vllm/attention/ops/triton_unified_attention.py @tdoublep # ROCm related: specify owner with write access to notify AMD folks for careful code review -/docker/Dockerfile.rocm* @gshtras -/vllm/v1/attention/backends/rocm*.py @gshtras -/vllm/v1/attention/backends/mla/rocm*.py @gshtras -/vllm/attention/ops/rocm*.py @gshtras -/vllm/model_executor/layers/fused_moe/rocm*.py @gshtras +/vllm/**/*rocm* @tjtanaa +/docker/Dockerfile.rocm* @gshtras @tjtanaa +/vllm/v1/attention/backends/rocm*.py @gshtras @tjtanaa +/vllm/v1/attention/backends/mla/rocm*.py @gshtras @tjtanaa +/vllm/attention/ops/rocm*.py @gshtras @tjtanaa +/vllm/model_executor/layers/fused_moe/rocm*.py @gshtras @tjtanaa +/csrc/rocm @gshtras @tjtanaa +/requirements/*rocm* @tjtanaa +/tests/**/*rocm* @tjtanaa +/docs/**/*rocm* @tjtanaa +/vllm/**/*quark* @tjtanaa +/tests/**/*quark* @tjtanaa +/docs/**/*quark* @tjtanaa +/vllm/**/*aiter* @tjtanaa +/tests/**/*aiter* @tjtanaa # TPU /vllm/v1/worker/tpu* @NickLucche @@ -127,3 +137,8 @@ mkdocs.yaml @hmellor /vllm/config/pooler.py @noooop /vllm/pooling_params.py @noooop /vllm/model_executor/layers/pooler.py @noooop + +# Security guide and policies +/docs/usage/security.md @russellb +/SECURITY.md @russellb +/docs/contributing/vulnerability_management.md @russellb diff --git a/.gitignore b/.gitignore index ffa36dee1ab9d..50070d7898fe6 100644 --- a/.gitignore +++ b/.gitignore @@ -221,3 +221,6 @@ csrc/moe/marlin_moe_wna16/kernel_* # Ignore ep_kernels_workspace folder ep_kernels_workspace/ + +# Allow tracked library source folders under submodules (e.g., benchmarks/lib) +!vllm/benchmarks/lib/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bcd40e7f8ab39..e034f75a9d322 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: rev: 0.9.1 hooks: - id: pip-compile - args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28] + args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"] files: ^requirements/test\.(in|txt)$ - repo: local hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index 7cb94f919f123..0e9fa63b178ea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -241,7 +241,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Enabling cumem allocator extension.") # link against cuda driver library list(APPEND CUMEM_LIBS CUDA::cuda_driver) - define_gpu_extension_target( + define_extension_target( cumem_allocator DESTINATION vllm LANGUAGE CXX @@ -858,7 +858,7 @@ if (VLLM_GPU_LANG STREQUAL "HIP") endif() message(STATUS "Enabling C extension.") -define_gpu_extension_target( +define_extension_target( _C DESTINATION vllm LANGUAGE ${VLLM_GPU_LANG} @@ -973,7 +973,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() message(STATUS "Enabling moe extension.") -define_gpu_extension_target( +define_extension_target( _moe_C DESTINATION vllm LANGUAGE ${VLLM_GPU_LANG} @@ -994,7 +994,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") "csrc/rocm/skinny_gemms.cu" "csrc/rocm/attention.cu") - define_gpu_extension_target( + define_extension_target( _rocm_C DESTINATION vllm LANGUAGE ${VLLM_GPU_LANG} diff --git a/README.md b/README.md index 2e750ef8fc894..b5e230e4b9b07 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio *Latest News* 🔥 +- [2025/11] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w) focusing on distributed inference and diverse accelerator support with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link). - [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6). - [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing). - [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA). @@ -83,7 +84,7 @@ vLLM is flexible and easy to use with: - Tensor, pipeline, data and expert parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend. +- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend. - Prefix caching support - Multi-LoRA support diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py new file mode 100644 index 0000000000000..38e7fdcf55426 --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -0,0 +1,1129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark for FlashInfer fused collective operations vs standard operations. + +This benchmark compares: +1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant) +2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations + +Usage with torchrun: + torchrun --nproc_per_node=2 benchmark_fused_collective.py + +""" + +import argparse +import itertools +import os +import time + +import pandas as pd +import torch # type: ignore +import torch.distributed as dist # type: ignore + +from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config +from vllm.distributed import ( + get_tp_group, + tensor_model_parallel_all_reduce, +) +from vllm.distributed.parallel_state import ( + graph_capture, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm # noqa +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 # noqa +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape # noqa +from vllm.platforms import current_platform # noqa + +RMS_NORM_OP = torch.ops._C.rms_norm +FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm +RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant +FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = ( + torch.ops._C.fused_add_rms_norm_static_fp8_quant +) +SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant + +logger = init_logger(__name__) + +# Try to import FlashInfer +try: + import flashinfer.comm as flashinfer_comm # type: ignore + + if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"): + flashinfer_comm = None + logger.warning( + "FlashInfer comm module found but missing trtllm_allreduce_fusion" + ) +except ImportError: + flashinfer_comm = None + logger.warning("FlashInfer not found, only benchmarking standard operations") + +# Constants +FP8_DTYPE = current_platform.fp8_dtype() +MiB = 1024 * 1024 + +# FlashInfer max sizes per world size +# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes +# use --disable-oneshot to disable oneshot mode for very large input sizes +_FI_MAX_SIZES = { + 2: 64 * MiB, # 64MB + 4: 64 * MiB, # 64MB + 8: 64 * MiB, # 64MB +} + +# Global workspace tensor for FlashInfer +_FI_WORKSPACE_TENSOR = None + + +def setup_flashinfer_workspace( + world_size: int, + rank: int, + hidden_dim: int, + max_token_num: int, + use_fp32_lamport: bool = False, +): + """Setup FlashInfer workspace for fused allreduce operations.""" + global _FI_WORKSPACE_TENSOR + + if flashinfer_comm is None: + return None, None + + if world_size not in _FI_MAX_SIZES: + logger.warning("FlashInfer not supported for world size %s", world_size) + return None, None + + try: + # Create IPC workspace + ipc_handles, workspace_tensor = ( + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_rank=rank, + tp_size=world_size, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + group=get_tp_group().device_group, + use_fp32_lamport=use_fp32_lamport, + ) + ) + + _FI_WORKSPACE_TENSOR = workspace_tensor + return ipc_handles, workspace_tensor + except Exception as e: + logger.error("Failed to setup FlashInfer workspace: %s", e) + return None, None + + +def cleanup_flashinfer_workspace(ipc_handles): + """Cleanup FlashInfer workspace.""" + if flashinfer_comm is None or ipc_handles is None: + return + + try: + group = get_tp_group().device_group + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group) + except Exception as e: + logger.error("Failed to cleanup FlashInfer workspace: %s", e) + + +class FlashInferFusedAllReduceParams: + """Parameters for FlashInfer fused allreduce operations.""" + + def __init__( + self, + rank: int, + world_size: int, + use_fp32_lamport: bool = False, + max_token_num: int = 1024, + ): + self.rank = rank + self.world_size = world_size + self.use_fp32_lamport = use_fp32_lamport + self.trigger_completion_at_end = True + self.launch_with_pdl = True + self.fp32_acc = True + self.max_token_num = max_token_num + + def get_trtllm_fused_allreduce_kwargs(self): + return { + "world_rank": self.rank, + "world_size": self.world_size, + "launch_with_pdl": self.launch_with_pdl, + "trigger_completion_at_end": self.trigger_completion_at_end, + "fp32_acc": self.fp32_acc, + } + + +def flashinfer_fused_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + allreduce_params: "FlashInferFusedAllReduceParams", + use_oneshot: bool, + norm_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm operation.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + allreduce_out=None, + quant_out=None, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=None, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + use_oneshot: bool = True, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + quant_out: torch.Tensor, + use_oneshot: bool, + output_scale: torch.Tensor, + norm_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=output_scale, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=input_global_scale, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +class VllmFusedAllreduce: + def __init__(self, hidden_dim, dtype): + self.rms_eps = 1e-6 + self.rms_norm = RMSNorm(hidden_dim, eps=self.rms_eps, dtype=dtype) + self.fp8_quant = QuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, + ) + + def allreduce_rmsnorm( + self, input_tensor: torch.Tensor, residual: torch.Tensor | None + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + return self.rms_norm(allreduce_out, residual) + + def allreduce_rmsnorm_fp8_quant( + self, + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + scale_factor: torch.Tensor, + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + rms_out = self.rms_norm(allreduce_out, residual) + if residual is None: + quant_out = self.fp8_quant(rms_out, scale_factor) + return quant_out + else: + rms_out, residual_out = rms_out + quant_out = self.fp8_quant(rms_out, scale_factor) + return quant_out, residual_out + + def allreduce_rmsnorm_fp4_quant( + self, + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + rms_out = self.rms_norm(allreduce_out, residual) + if residual is None: + SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) + return quant_out, output_scale + else: + rms_out, residual_out = rms_out + SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) + return quant_out, residual_out, output_scale + + +def create_test_tensors( + num_tokens: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True +): + """Create test tensors for benchmarking.""" + input_tensor = torch.randn(num_tokens, hidden_dim, dtype=dtype) + residual = ( + torch.randn_like(input_tensor) + if use_residual + else torch.zeros_like(input_tensor) + ) + rms_gamma = torch.ones(hidden_dim, dtype=dtype) + norm_out = None if use_residual else torch.empty_like(input_tensor) + + # Quantization scales + scale_fp8 = torch.tensor(1.0, dtype=torch.float32) + scale_fp4 = torch.tensor(1.0, dtype=torch.float32) + quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) + fp4_quant_out = torch.empty((num_tokens, hidden_dim // 2), dtype=torch.uint8) + fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) + + return ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) + + +def benchmark_operation( + operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs +): + """Benchmark a single operation using CUDA graphs.""" + # Warmup before graph capture + for _ in range(warmup): + operation_func(*args, **kwargs) + torch.cuda.synchronize() + + # Create CUDA graph + graph = torch.cuda.CUDAGraph() + num_op_per_cudagraph = 10 + + # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe + device = torch.device(f"cuda:{torch.cuda.current_device()}") + with graph_capture(device=device), torch.cuda.graph(graph): + for _ in range(num_op_per_cudagraph): + operation_func(*args, **kwargs) + + # Graph warmup + torch.cuda.synchronize() + for _ in range(warmup): + graph.replay() + + # Benchmark with CUDA graph + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(trials // num_op_per_cudagraph): + # operation_func(*args, **kwargs) + graph.replay() + + torch.cuda.synchronize() + end_time = time.perf_counter() + + avg_time_ms = ((end_time - start_time) / trials) * 1000 + return avg_time_ms + + +def run_benchmarks( + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + use_residual: bool, + allreduce_params: FlashInferFusedAllReduceParams | None, + quant_modes: set[str], + no_oneshot: bool, +): + """Run all benchmarks for given configuration. + + Args: + quant_mode: "none", "fp8_only", "fp4_only", or "all" + """ + ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) = create_test_tensors(num_tokens, hidden_dim, dtype, use_residual) + + rms_eps = 1e-6 + results = {} + vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + use_oneshot_options = [False] if no_oneshot else [True, False] + + # Create RMSNorm and QuantFP8 layers once for native benchmarks + + if "none" in quant_modes: + # Standard AllReduce + RMSNorm + for custom_op in ["-rms_norm", "+rms_norm"]: + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op])) + ): + try: + suffix = ( + "_custom_rms_norm" if "+" in custom_op else "_native_rms_norm" + ) + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm, + input_tensor, + residual=residual, + ) + results[f"standard_allreduce_{suffix}"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm failed: %s", e) + results[f"standard_allreduce_{suffix}"] = float("inf") + + # Standard AllReduce + RMSNorm Native Compiled + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) + ): + try: + standard_allreduce_rmsnorm_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_native_compiled, + input_tensor, + residual=residual, + ) + results["standard_allreduce_rmsnorm_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot + if flashinfer_comm is not None and allreduce_params is not None: + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=use_oneshot, + ) + results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms + except Exception as e: + logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e) + results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float( + "inf" + ) + + if "fp8" in quant_modes: + # Standard AllReduce + RMSNorm + FP8 Quant + for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]: + suffix = ( + "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm" + ) + for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]: + suffix += ( + "_custom_quant_fp8" + if "+" in quant_fp8_custom_op + else "_native_quant_fp8" + ) + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=[rms_norm_custom_op, quant_fp8_custom_op] + ) + ) + ): + try: + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, + input_tensor, + residual=residual, + scale_factor=scale_fp8, + ) + results[f"standard_allreduce{suffix}"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) + results[f"standard_allreduce{suffix}"] = float("inf") + + # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=["-rms_norm", "-quant_fp8"] + ) + ) + ): + try: + standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant_native_compiled, + input_tensor, + residual=residual, + scale_factor=scale_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = ( + time_ms + ) + except Exception as e: + logger.error( + "Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=use_oneshot, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", + e, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( + float("inf") + ) + + if "fp4" in quant_modes and current_platform.has_device_capability(100): + # Standard AllReduce + RMSNorm + FP4 Quant + for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]: + suffix = ( + "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm" + ) + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=[rms_norm_custom_op] + ) + ) + ): + try: + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + ) + results[f"standard_allreduce_{suffix}_fp4_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) + results[f"standard_allreduce_{suffix}_fp4_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) + ): + try: + standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant_native_compiled, + input_tensor, + residual=residual, + quant_out=fp4_quant_out, + input_global_scale=scale_fp4, + output_scale=fp4_output_scale, + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = ( + time_ms + ) + except Exception as e: + logger.error( + "Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=use_oneshot, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", + e, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( + float("inf") + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot + if flashinfer_comm is not None and allreduce_params is not None: + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float( + "inf" + ) + + return results + + +def prepare_results_with_speedups(results_dict): + """Prepare results with speedup calculations based on dynamic baseline selection.""" + prepared_results = [] + + # Determine the fastest baseline for each operation type + def get_fastest_baseline(op_name, results_dict): + """Get the fastest baseline between standard and native_compiled versions.""" + if "fp8_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp8_quant", + "standard_allreduce_rmsnorm_fp8_quant_native_compiled", + ] + elif "fp4_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp4_quant", + "standard_allreduce_rmsnorm_fp4_quant_native_compiled", + ] + else: + candidates = [ + "standard_allreduce_rmsnorm", + "standard_allreduce_rmsnorm_native_compiled", + ] + + # Find the fastest among available candidates + fastest_time = float("inf") + fastest_baseline = None + + for candidate in candidates: + if ( + candidate in results_dict + and results_dict[candidate] != float("inf") + and results_dict[candidate] < fastest_time + ): + fastest_time = results_dict[candidate] + fastest_baseline = candidate + + return fastest_baseline + + # Create dynamic baseline mapping + dynamic_baseline_mapping = {} + for op_name in results_dict: + if ( + op_name.startswith("flashinfer_") + or op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + dynamic_baseline_mapping[op_name] = get_fastest_baseline( + op_name, results_dict + ) + + for op_name, time_ms in results_dict.items(): + if time_ms == float("inf"): + speedup_str = "FAILED" + time_str = "FAILED" + else: + time_str = f"{time_ms:.3f}" + # Find the appropriate baseline for this operation + baseline_op = dynamic_baseline_mapping.get(op_name) + if baseline_op and baseline_op in results_dict: + baseline_time = results_dict[baseline_op] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + # For baseline operations, determine if this is the fastest baseline + if op_name.endswith("_native_compiled") or ( + op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + fastest_baseline = get_fastest_baseline(op_name, results_dict) + if fastest_baseline == op_name: + speedup_str = "baseline" + else: + if fastest_baseline and fastest_baseline in results_dict: + baseline_time = results_dict[fastest_baseline] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + + prepared_results.append( + { + "operation": op_name, + "time_ms": time_ms, + "time_str": time_str, + "speedup_str": speedup_str, + } + ) + + return prepared_results + + +def print_results( + results_dict, + num_tokens, + hidden_dim, + dtype, + use_residual, + quant_modes, + input_size_mb, +): + """Print benchmark results in a formatted table.""" + print(f"\n{'=' * 80}") + print( + f"Results: num_tokens={num_tokens}, hidden_dim={hidden_dim} " + f"(input size: {input_size_mb:.2f} MB)" + ) + print( + f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " + f"quant_modes={','.join(sorted(list(quant_modes)))}" + ) + print(f"{'=' * 80}") + print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}") + print(f"{'-' * 80}") + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + if result["time_ms"] == float("inf"): + time_display = result["time_str"] + else: + time_display = f"{result['time_ms']:.3f}" + + print( + f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}" + ) + + +def format_results_markdown( + all_results: list[dict], world_size: int, args: argparse.Namespace +) -> str: + """Format all benchmark results as markdown.""" + lines: list[str] = [] + lines.append("# FlashInfer Fused Collective Operations Benchmark Results") + lines.append("") + lines.append(f"**World Size:** {world_size} ") + lines.append(f"**Hidden Dimension:** {args.hidden_dim} ") + lines.append(f"**Warmup Iterations:** {args.warmup} ") + lines.append(f"**Benchmark Trials:** {args.trials} ") + modes = ",".join(all_results[0]["quant_modes"]) if all_results else "N/A" + lines.append(f"**Quantization Modes:** {modes} ") + lines.append("") + lines.append("---") + lines.append("") + + for entry in all_results: + num_tokens = entry["num_tokens"] + dtype = entry["dtype"] + use_residual = entry["use_residual"] + results_dict = entry["results"] + input_size_mb = entry["input_size_mb"] + residual_str = "with residual" if use_residual else "no residual" + + lines.append( + f"## Configuration: num_tokens={num_tokens}, dtype={dtype}, {residual_str}" + ) + lines.append(f"**Input Size:** {input_size_mb:.2f} MB") + lines.append("") + + prepared = prepare_results_with_speedups(results_dict) + # Build DataFrame for markdown export + rows = [ + { + "Operation": r["operation"].replace("_", " ").title(), + "Time (ms)": r["time_str"], + "Speedup": r["speedup_str"], + } + for r in prepared + ] + df = pd.DataFrame(rows) + if df.empty: + lines.append("No results.") + else: + lines.append(df.to_markdown(index=False)) + lines.append("") + + return "\n".join(lines) + + +def save_results_to_file( + all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int +): + """Save benchmark results to markdown file (only on rank 0).""" + if rank != 0: + return + + if not all_results: + logger.warning("No results to save") + return + + output_path = args.output_file + + try: + markdown_content = format_results_markdown(all_results, world_size, args) + + with open(output_path, "a") as f: + f.write(markdown_content) + + except Exception as e: + logger.error("Failed to save results to file: %s", e) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark fused collective operations" + ) + parser.add_argument( + "--num-tokens", + type=int, + nargs="+", + default=[128, 512, 1024, 2048], + help="Numbers of tokens to test", + ) + parser.add_argument( + "--hidden-dim", type=int, default=8192, help="Hidden dimension size" + ) + parser.add_argument( + "--dtypes", + type=str, + nargs="+", + default=["bfloat16"], + choices=["float16", "bfloat16", "float32"], + help="Data types to test", + ) + parser.add_argument( + "--no-residual", + action="store_true", + help="Skip residual connection tests", + ) + + parser.add_argument( + "--quant-modes", + type=str, + default="none,fp8,fp4", + help=( + "Comma-separated quantization modes to run: none, fp8, fp4. " + "Default: none,fp8,fp4" + ), + ) + + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--trials", type=int, default=20, help="Number of benchmark trials" + ) + parser.add_argument( + "--output-file", + type=str, + help="""Output file path for markdown results + (default: benchmark_results_.md) + """, + ) + + parser.add_argument( + "--no-oneshot", + action="store_true", + help="Skip oneshot benchmarks", + ) + + args = parser.parse_args() + + # Check if running with torchrun (required for collective operations) + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + raise RuntimeError( + "Must run with torchrun for distributed benchmarking. " + "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py" + ) + + # Initialize distributed environment + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Validate world size (must be > 1 for collective operations) + if world_size <= 1: + raise ValueError( + "World size must be > 1 for collective operations benchmarking. " + f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1." + ) + + # Parse quantization modes + valid_quant_modes = {"none", "fp8", "fp4"} + raw_modes = [ + m.strip().lower() for m in (args.quant_modes or "").split(",") if m.strip() + ] + quant_modes = set(raw_modes) if raw_modes else {"none", "fp8", "fp4"} + invalid = sorted(list(quant_modes - valid_quant_modes)) + if invalid: + raise ValueError( + f"Invalid --quant-modes entries: {','.join(invalid)}. " + f"Valid options are: {','.join(sorted(valid_quant_modes))}." + ) + + if rank == 0: + logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank) + logger.info("Quantization modes: %s", ",".join(sorted(list(quant_modes)))) + if flashinfer_comm is not None: + logger.info( + "FlashInfer available - will benchmark fused operations", + ) + else: + logger.info( + "FlashInfer not available - only benchmarking standard operations" + ) + + # Convert dtype strings to torch dtypes + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + dtypes = [dtype_map[dt] for dt in args.dtypes] + + # Test configurations + residual_options = [True] if not args.no_residual else [False] + + configs = list(itertools.product(args.num_tokens, dtypes, residual_options)) + + # Setup FlashInfer workspace if available + ipc_handles = None + allreduce_params = None + + if flashinfer_comm is not None: + # Use the largest hidden dimension for workspace setup + max_num_token = _FI_MAX_SIZES.get(world_size) // ( + args.hidden_dim * world_size * 2 + ) + + ipc_handles, workspace_tensor = setup_flashinfer_workspace( + world_size, rank, args.hidden_dim, max_num_token + ) + + if workspace_tensor is not None: + allreduce_params = FlashInferFusedAllReduceParams( + rank=rank, + world_size=world_size, + max_token_num=max_num_token, + ) + + # Collect all results for markdown export + all_results = [] + + try: + # Run benchmarks + for num_tokens, dtype, use_residual in configs: + if rank == 0: + logger.info( + "\nTesting: num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s", + num_tokens, + args.hidden_dim, + dtype, + use_residual, + ) + + results = run_benchmarks( + num_tokens, + args.hidden_dim, + dtype, + use_residual, + allreduce_params, + quant_modes=quant_modes, + no_oneshot=args.no_oneshot, + ) + + # Store results for markdown export + if rank == 0: + # Calculate input size in MB + input_size_mb = ( + num_tokens * args.hidden_dim * torch.finfo(dtype).bits + ) / (8 * 1024 * 1024) + all_results.append( + { + "num_tokens": num_tokens, + "hidden_dim": args.hidden_dim, + "dtype": str(dtype).replace("torch.", ""), + "use_residual": use_residual, + "quant_modes": sorted(list(quant_modes)), + "input_size_mb": input_size_mb, + "results": results, + } + ) + + print_results( + results, + num_tokens, + args.hidden_dim, + dtype, + use_residual, + quant_modes, + input_size_mb, + ) + + # Save results to markdown file + if args.output_file and rank == 0: + save_results_to_file(all_results, world_size, args, rank) + + finally: + # Cleanup + if ipc_handles is not None: + cleanup_flashinfer_workspace(ipc_handles) + + dist.barrier() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index d525bd5faacf6..9b426d8d5f778 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -16,8 +16,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.utils.argparse_utils import FlexibleArgumentParser DEFAULT_MODELS = [ - "nm-testing/Mixtral-8x7B-Instruct-v0.1", - "nm-testing/deepseekv2-lite", + "mistralai/Mixtral-8x7B-Instruct-v0.1", + "deepseek-ai/DeepSeek-V2-Lite", "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m", ] diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index bf1512268fe0b..6715c9b548aa1 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -19,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement from utils import ArgPool, Bench, CudaGraphBenchParams from weight_shapes import WEIGHT_SHAPES -from vllm.triton_utils import HAS_TRITON +from vllm.lora.ops.triton_ops.utils import get_lora_op_configs +from vllm.triton_utils import HAS_TRITON, triton if HAS_TRITON: - from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink + from vllm.lora.ops.triton_ops import ( ## added fused_moe_lora + LoRAKernelMeta, + fused_moe_lora_expand, + fused_moe_lora_shrink, + lora_expand, + lora_shrink, + ) + from vllm.lora.ops.triton_ops.fused_moe_lora_op import ( + _LORA_PTR_DICT, ## added _LORA_PTR_DICT for fused_moe_lora + ) from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT - +from vllm import _custom_ops as ops from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.math_utils import round_up DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_TP_SIZES = [1] @@ -59,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4] DEFAULT_SORT_BY_LORA_IDS = [False, True] DEFAULT_SEQ_LENGTHS = [1] DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False] +DEFAULT_TOP_K_NUMS = [1] # Added for MoE LoRA top_k +DEFAULT_NUM_EXPERTS = [8] # Added for MoE LoRA num_experts # Utilities @@ -191,6 +204,11 @@ class OpType(Enum): LORA_SHRINK = auto() LORA_EXPAND = auto() + ## Adding support for fused moe lora + FUSED_MOE_LORA_GATE_UP_SHRINK = auto() ## Gate/Up projection variant with shrink + FUSED_MOE_LORA_GATE_UP_EXPAND = auto() ## Gate/Up projection variant with expand + FUSED_MOE_LORA_DOWN_SHRINK = auto() ## Down projection variant with shrink + FUSED_MOE_LORA_DOWN_EXPAND = auto() ## Down projection variant with expand @staticmethod def from_str(s: str) -> "OpType": @@ -198,6 +216,15 @@ class OpType(Enum): return OpType.LORA_SHRINK if s.lower() == "lora_expand": return OpType.LORA_EXPAND + # Adding support for fused moe lora, both in gate_up and down + if s.lower() == "fused_moe_lora_gate_up_shrink": ## Gate/Up variant with shrink + return OpType.FUSED_MOE_LORA_GATE_UP_SHRINK + if s.lower() == "fused_moe_lora_gate_up_expand": ## Gate/Up variant with expand + return OpType.FUSED_MOE_LORA_GATE_UP_EXPAND + if s.lower() == "fused_moe_lora_down_shrink": ## Down variant with shrink + return OpType.FUSED_MOE_LORA_DOWN_SHRINK + if s.lower() == "fused_moe_lora_down_expand": ## Down variant with expand + return OpType.FUSED_MOE_LORA_DOWN_EXPAND raise ValueError(f"Unrecognized str {s} to convert to OpType") def is_shrink_fn(self) -> bool: @@ -206,19 +233,56 @@ class OpType(Enum): def is_expand_fn(self) -> bool: return self in [OpType.LORA_EXPAND] + def is_fused_moe_lora_fn(self) -> bool: ## adding for fused MoE LoRA + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + + def is_fused_moe_lora_gate_up_fn( + self, + ) -> bool: ## adding for fused MoE LoRA Gate/Up + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + ] + + def is_fused_moe_lora_down_fn(self) -> bool: ## adding for fused MoE LoRA Down + return self in [ + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + + def is_fused_moe_lora_shrink_fn(self) -> bool: + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + ] + + def is_fused_moe_lora_expand_fn(self) -> bool: + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + def num_slices(self) -> list[int]: + if self.is_fused_moe_lora_gate_up_fn(): + return [2] + elif self.is_fused_moe_lora_down_fn(): + return [1] return [1, 2, 3] def mkn( self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int ) -> tuple[int, int, int]: num_tokens = batch_size * seq_length - if self.is_shrink_fn(): + if self.is_shrink_fn() or self.is_fused_moe_lora_fn(): m = num_tokens k = hidden_size n = lora_rank - else: - assert self.is_expand_fn() + elif self.is_expand_fn(): m = num_tokens k = lora_rank n = hidden_size @@ -232,9 +296,36 @@ class OpType(Enum): """ if self.is_shrink_fn(): return op_dtype, op_dtype, torch.float32 - else: - assert self.is_expand_fn() + elif self.is_expand_fn(): return torch.float32, op_dtype, op_dtype + else: + assert self.is_fused_moe_lora_fn() + return op_dtype, op_dtype, op_dtype + + def matmul_shapes_fused_moe_lora( + self, + m: int, + n: int, + k: int, + num_loras: int, + num_slices: int, + top_k_num: int, + num_experts: int, + ) -> tuple[tuple[int], tuple[int], tuple[int], tuple[int]]: + if self.is_fused_moe_lora_shrink_fn(): + input_shape = ( + (m * top_k_num, n) + if self in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else (m, n) + ) + output_shape = (num_slices, m, top_k_num, k) + weight_shape = (num_loras, num_experts, k, n) + else: + assert self.is_fused_moe_lora_expand_fn() + input_shape = (num_slices, m, top_k_num, k) + output_shape = (m, top_k_num, n * num_slices) + weight_shape = (num_loras, num_experts, n, k) + return (input_shape, weight_shape, output_shape) def matmul_shapes( self, @@ -244,6 +335,8 @@ class OpType(Enum): lora_rank: int, num_loras: int, num_slices: int, + top_k_num: int | None = None, + num_experts: int | None = None, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Given num_slices, return the shapes of the A, B, and C matrices @@ -258,6 +351,16 @@ class OpType(Enum): if self in [OpType.LORA_EXPAND]: # LoRA expand kernels support num_slices inherently in the kernel return ((num_slices, m, k), b_shape, (m, n * num_slices)) + if self.is_fused_moe_lora_fn(): + return self.matmul_shapes_fused_moe_lora( + m, + k, + n, + num_loras, + num_slices, + top_k_num, + num_experts, + ) raise ValueError(f"Unrecognized op_type {self}") def bench_fn(self) -> Callable: @@ -265,6 +368,16 @@ class OpType(Enum): return lora_shrink if self == OpType.LORA_EXPAND: return lora_expand + if self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + ]: + return fused_moe_lora_shrink + if self in [ + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ]: + return fused_moe_lora_expand raise ValueError(f"Unrecognized optype {self}") @@ -318,6 +431,8 @@ class BenchmarkContext: sort_by_lora_id: bool dtype: torch.dtype seq_length: int | None = None + num_experts: int | None = None # num_experts for MoE based ops + top_k_num: int | None = None # top_k for MoE based ops num_slices: int | None = None # num_slices for slice based ops def with_seq_length(self, seq_length: int) -> "BenchmarkContext": @@ -373,6 +488,11 @@ class BenchmarkTensors: f"{dtype_to_str(self.output.dtype)}" ) + def get_num_tokens(self, size: int, top_k_num: int, op_type: OpType): + return ( + size * top_k_num if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else size + ) + @staticmethod def make( ctx: BenchmarkContext, op_type: OpType, device: str = "cuda" @@ -385,6 +505,8 @@ class BenchmarkTensors: ctx.lora_rank, ctx.num_loras, ctx.num_slices, + ctx.top_k_num, + ctx.num_experts, ) a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) input_tensor, lora_weights, output_tensor = make_rand_tensors( @@ -432,17 +554,27 @@ class BenchmarkTensors: prompt_lora_indices_tensor, ) - def sanity_check(self) -> None: + def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None: """ Fails asserts when non-conformality is detected. """ - num_tokens = self.input.shape[-2] + num_tokens = ( + self.input.shape[1] + if op_type.is_fused_moe_lora_expand_fn() + else self.input.shape[-2] + ) # check metadata tensors - assert torch.sum(self.seq_lens) == num_tokens + ## In down shrink case, each token is repeated top_k_num times + assert num_tokens == self.get_num_tokens( + torch.sum(self.seq_lens), ctx.top_k_num, op_type + ), f"Expected {num_tokens} tokens, but got {torch.sum(self.seq_lens)}" num_seqs = self.seq_lens.shape[0] # assert self.seq_start_loc.shape[0] == num_seqs + ## In down shrink case, each prompt corresponds to top_k_num sequences assert self.prompt_lora_mapping.shape[0] == num_seqs - assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens + assert self.get_num_tokens( + self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type + ) def to_device(self, device: str): """ @@ -471,21 +603,111 @@ class BenchmarkTensors: to_device(field) if field_name != "no_lora_flag_cpu" else field, ) - def metadata(self) -> tuple[int, int, int]: + def metadata(self, ctx: BenchmarkContext, op_type: OpType) -> tuple[int, int, int]: """ Return num_seqs, num_tokens and max_seq_len """ num_seqs = self.seq_lens.shape[0] - num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0] + num_tokens = self.get_num_tokens( + self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type + ) max_seq_len = torch.max(self.seq_lens).item() num_slices = len(self.lora_weights_lst) return num_seqs, num_tokens, max_seq_len, num_slices - def as_lora_shrink_kwargs(self) -> dict[str, Any]: - self.sanity_check() + def fused_moe_lora_data_prepare( + self, + block_size: int, + token_lora_mapping: torch.Tensor, + ctx: BenchmarkContext, + ): + def moe_lora_align_block_size( + topk_ids: torch.Tensor, + token_lora_mapping: torch.Tensor, + block_size: int, + num_experts: int, + max_loras: int, + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns tokens and experts into block-sized chunks for LoRA-based + mixture-of-experts (MoE) execution. + """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=topk_ids.device, + ) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be set default to -1 to prevent a blank block + expert_ids = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=topk_ids.device, + ) + num_tokens_post_pad = torch.empty( + (max_loras), dtype=torch.int32, device=topk_ids.device + ) + + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad + + num_tokens = ctx.batch_size + curr_topk_ids = torch.randint( + 0, + ctx.num_experts, + (num_tokens, ctx.top_k_num), + device="cuda", + dtype=torch.int32, + ) + topk_weights = torch.randint( + 0, + ctx.num_experts, + (num_tokens, ctx.top_k_num), + device="cuda", + dtype=torch.int32, + ) + + (sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora) = ( + moe_lora_align_block_size( + topk_ids=curr_topk_ids, + token_lora_mapping=token_lora_mapping, + block_size=block_size, + num_experts=ctx.num_experts, + max_loras=ctx.num_loras, + ) + ) + + sorted_token_ids = sorted_token_ids_lora.view(ctx.num_loras, -1) + expert_ids = expert_ids_lora.view(ctx.num_loras, -1) + num_tokens_post_padded = num_tokens_post_padded_lora + return (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) + + def as_lora_shrink_kwargs( + self, ctx: BenchmarkContext, op_type: OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) self.to_device(self.input.device) - _, num_tokens, _, num_slices = self.metadata() + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) # Sanity check matrix shapes. i_shape, lw_shape, o_shape = ( @@ -520,11 +742,13 @@ class BenchmarkTensors: "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } - def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: - self.sanity_check() + def as_lora_expand_kwargs( + self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) self.to_device(self.input.device) - _, num_tokens, _, num_slices = self.metadata() + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) # Sanity check matrix shapes. i_shape, lw_shape, o_shape = ( @@ -561,18 +785,173 @@ class BenchmarkTensors: "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } - def bench_fn_kwargs( - self, op_type: OpType, add_inputs: bool | None = None + def as_fused_moe_lora_shrink_kwargs( + self, ctx: BenchmarkContext, op_type: OpType ) -> dict[str, Any]: - if op_type.is_shrink_fn(): + self.sanity_check(ctx, op_type) + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) + + # Sanity check matrix shapes. + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) + # Expected input shape : [num_tokens, hidden_size] for gate_up + # Expected input shape : [top_k_num * num_tokens, hidden_size] for down + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + hidden_size = i_shape[1] + # Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size] + assert len(lw_shape) == 4 + assert lw_shape[-1] == hidden_size + lora_rank = lw_shape[-2] + # Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank] + assert len(o_shape) == 4 + assert ( + o_shape + == (num_slices, num_tokens // ctx.top_k_num, ctx.top_k_num, lora_rank) + if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank) + ) + kernel_config = get_lora_op_configs( + op_type.name.lower(), + max_loras=lw_shape[0], + batch=num_tokens, + hidden_size=hidden_size, + rank=lora_rank, + num_slices=num_slices, + add_inputs=False, + ) + + (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = ( + self.fused_moe_lora_data_prepare( + block_size=kernel_config["BLOCK_SIZE_M"], + token_lora_mapping=self.lora_kernel_meta.token_lora_mapping, + ctx=ctx, + ) + ) + + return { + "qcurr_hidden_states": self.input, + "lora_a_stacked": self.lora_weights_lst, + "a_intermediate_cache1": self.output, + "topk_weights": topk_weights, + "sorted_token_ids": sorted_token_ids, + "expert_ids": expert_ids, + "num_tokens_post_padded": num_tokens_post_padded, + "top_k_num": ctx.top_k_num, + "device": self.input.device, + "N": lora_rank, + "M": topk_weights.shape[0], + "EM": sorted_token_ids.shape[1], + "K": self.input.shape[1], + "num_tokens": num_tokens, + "num_experts": ctx.num_experts, + "num_slices": num_slices, + "shrink_block_size_m": kernel_config["BLOCK_SIZE_M"], + "shrink_block_size_n": kernel_config["BLOCK_SIZE_N"], + "shrink_block_size_k": kernel_config["BLOCK_SIZE_K"], + "shrink_group_size_m": kernel_config["GROUP_SIZE_M"], + "shrink_num_warps": kernel_config["NUM_WARPS"], + "shrink_num_stages": kernel_config["NUM_STAGES"], + "shrink_split_k": kernel_config.get("SPLIT_K", 1), + "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), + } + + def as_fused_moe_lora_expand_kwargs( + self, ctx: BenchmarkContext, op_type: OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) + + # Sanity check matrix shapes. + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) + + # Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank] + assert len(i_shape) == 4 + assert i_shape[0] == num_slices + assert i_shape[1] == num_tokens + lora_rank = i_shape[-1] + # Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank] + assert len(lw_shape) == 4 + assert lw_shape[-1] == lora_rank + hidden_size = lw_shape[-2] + # Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices] + assert len(o_shape) == 3 + assert o_shape == (num_tokens, ctx.top_k_num, hidden_size * num_slices) + + kernel_config = get_lora_op_configs( + op_type.name.lower(), + max_loras=lw_shape[0], + batch=num_tokens, + hidden_size=hidden_size, + rank=lora_rank, + num_slices=num_slices, + add_inputs=False, + ) + + (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = ( + self.fused_moe_lora_data_prepare( + block_size=kernel_config["BLOCK_SIZE_M"], + token_lora_mapping=self.lora_kernel_meta.token_lora_mapping, + ctx=ctx, + ) + ) + + return { + "a_intermediate_cache1": self.input, + "lora_b_stacked": self.lora_weights_lst, + "output": self.output, + "topk_weights": topk_weights, + "sorted_token_ids": sorted_token_ids, + "expert_ids": expert_ids, + "num_tokens_post_padded": num_tokens_post_padded, + "top_k_num": ctx.top_k_num, + "device": self.input.device, + "N": lora_rank, + "M": topk_weights.shape[0], + "EM": sorted_token_ids.shape[1], + "K": self.input.shape[1], + "num_tokens": num_tokens, + "num_experts": ctx.num_experts, + "num_slices": num_slices, + "max_lora_rank": lora_rank, + "w1_output_dim_size": lw_shape[2], + "expand_block_size_m": kernel_config["BLOCK_SIZE_M"], + "expand_block_size_n": kernel_config["BLOCK_SIZE_N"], + "expand_block_size_k": kernel_config["BLOCK_SIZE_K"], + "expand_group_size_m": kernel_config["GROUP_SIZE_M"], + "expand_num_warps": kernel_config["NUM_WARPS"], + "expand_num_stages": kernel_config["NUM_STAGES"], + "expand_split_k": kernel_config.get("SPLIT_K", 1), + "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), + } + + def bench_fn_kwargs( + self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool | None = None + ) -> dict[str, Any]: + if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn(): assert add_inputs is None else: assert add_inputs is not None if op_type == OpType.LORA_SHRINK: - return self.as_lora_shrink_kwargs() + return self.as_lora_shrink_kwargs(ctx, op_type) if op_type == OpType.LORA_EXPAND: - return self.as_lora_expand_kwargs(add_inputs) + return self.as_lora_expand_kwargs(ctx, op_type, add_inputs) + if op_type.is_fused_moe_lora_shrink_fn(): + return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type) + if op_type.is_fused_moe_lora_expand_fn(): + return self.as_fused_moe_lora_expand_kwargs(ctx, op_type) raise ValueError(f"Unrecognized optype {self}") def test_correctness( @@ -617,7 +996,7 @@ def bench_optype( test_correctness: bool = False, ) -> TMeasurement: assert arg_pool_size >= 1 - if op_type.is_shrink_fn(): + if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn(): assert expand_fn_add_inputs is None else: assert expand_fn_add_inputs is not None @@ -627,23 +1006,30 @@ def bench_optype( BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size) ] for bt in bench_tensors: - bt.sanity_check() + bt.sanity_check(ctx, op_type) # Test correctness of our implementation. if test_correctness: + assert op_type in [OpType.LORA_SHRINK, OpType.LORA_EXPAND], ( + f"Correctness testing is not supported for {op_type.name}." + ) assert all( - [bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors] + [ + bt.test_correctness(ctx, op_type, expand_fn_add_inputs) + for bt in bench_tensors + ] ) # BenchmarkTensors -> dict (kwargs) kwargs_list = [ - bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) + bt.bench_fn_kwargs(ctx, op_type, add_inputs=expand_fn_add_inputs) for bt in bench_tensors ] # Clear LoRA optimization hash-maps. _LORA_A_PTR_DICT.clear() _LORA_B_PTR_DICT.clear() + _LORA_PTR_DICT.clear() # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up for kwargs in kwargs_list: op_type.bench_fn()(**kwargs) @@ -793,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): # Benchmark bench_op expand_fn_add_inputs = ( - [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + [None] + if bench_op.is_shrink_fn() or bench_op.is_fused_moe_lora_fn() + else args.expand_fn_add_inputs ) for add_input_arg in expand_fn_add_inputs: seq_len_timers.append( @@ -831,12 +1219,22 @@ def as_benchmark_contexts( hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace ) -> list[BenchmarkContext]: ctxs: list[BenchmarkContext] = [] - for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa + for ( + batch_size, + hidden_size, + lora_rank, + num_loras, + sort_by_lora_id, + top_k_num, + num_experts, + ) in product( # noqa args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras, args.sort_by_lora_id, + args.top_k_nums, + args.num_experts, ): ctxs.append( BenchmarkContext( @@ -851,6 +1249,8 @@ def as_benchmark_contexts( seq_length=None, sort_by_lora_id=sort_by_lora_id, dtype=args.dtype, + top_k_num=top_k_num, + num_experts=num_experts, # To be filled based on the OpType to benchmark num_slices=None, ) @@ -1012,6 +1412,22 @@ if __name__ == "__main__": ), ) + p.add_argument( + "--top-k-nums", + nargs="+", + type=int, + default=DEFAULT_TOP_K_NUMS, + help="Top-K values for MoE LoRA operations", + ) + + p.add_argument( + "--num-experts", + nargs="+", + type=int, + default=DEFAULT_NUM_EXPERTS, + help="Number of experts for MoE LoRA operations", + ) + parser = FlexibleArgumentParser( description=f""" Benchmark LoRA kernels: diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index bc6cf83bc21fd..c99951aa27826 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -211,7 +211,7 @@ def get_rocm_tuning_space(use_fp16): num_warps_range = [1, 2, 4, 8] group_m_range = [1, 4, 8, 16, 32] num_stage_range = [2] - waves_per_eu_range = [0] + waves_per_eu_range = [0, 1, 2, 4] matrix_instr_nonkdim_range = [16, 32] if use_fp16 else [] kpack_range = [1, 2] if use_fp16 else [] @@ -590,6 +590,7 @@ def main(args: argparse.Namespace): "DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM", "Glm4MoeForCausalLM", + "NemotronHForCausalLM", ): E = config.n_routed_experts topk = config.num_experts_per_tok @@ -615,6 +616,11 @@ def main(args: argparse.Namespace): topk = config.moe_topk[0] intermediate_size = config.moe_intermediate_size[0] hidden_size = config.hidden_size + elif config.architectures[0] in ["Qwen3OmniMoeForConditionalGeneration"]: + E = config.thinker_config.text_config.num_experts + topk = config.thinker_config.text_config.num_experts_per_tok + intermediate_size = config.thinker_config.text_config.moe_intermediate_size + hidden_size = config.thinker_config.text_config.hidden_size else: # Support for llama4 config = config.get_text_config() diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 29ef6409bb166..074b7a440b612 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -1,97 +1,76 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from itertools import accumulate +import itertools -import nvtx import torch -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope -from vllm.platforms import current_platform +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.triton_utils import triton from vllm.utils.argparse_utils import FlexibleArgumentParser +batch_size_range = [2**i for i in range(0, 8, 2)] +seq_len_range = [2**i for i in range(6, 10, 1)] +num_heads_range = [32, 48] +configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range)) -def benchmark_rope_kernels_multi_lora( - is_neox_style: bool, - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: int | None, - dtype: torch.dtype, - seed: int, - device: str, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - # silulating serving 4 LoRAs - scaling_factors = [1, 2, 4, 8] - # batched RoPE can take multiple scaling factors - batched_rope = get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - {"rope_type": "linear", "factor": tuple(scaling_factors)}, - ) - # non-batched RoPE takes only one scaling factor, we create multiple - # instances to simulate the same behavior - non_batched_ropes: list[RotaryEmbedding] = [] - for scaling_factor in scaling_factors: - non_batched_ropes.append( - get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - {"rope_type": "linear", "factor": (scaling_factor,)}, - ) - ) - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype) - key = torch.randn_like(query) - - # create query offsets for batched RoPE, we concat multiple kv cache - # together and each query needs to find the right kv cache of its type - offset_map = torch.tensor( - list( - accumulate( - [0] - + [ - max_position * scaling_factor * 2 - for scaling_factor in scaling_factors[:-1] - ] - ) +def get_benchmark(head_size, rotary_dim, is_neox_style, device): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "num_heads"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["torch", "flashinfer", "vllm"], + line_names=["PyTorch", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}", + args={}, ) ) - query_types = torch.randint( - 0, len(scaling_factors), (batch_size, seq_len), device=device - ) - # map query types to offsets - query_offsets = offset_map[query_types] - # the kernel takes flattened offsets - flatten_offsets = query_offsets.flatten() + def benchmark(batch_size, seq_len, num_heads, provider): + dtype = torch.bfloat16 + max_position = 8192 + base = 10000 + rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) + rope = rope.to(dtype=dtype, device=device) + cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device) - # batched queries of the same type together for non-batched RoPE - queries = [query[query_types == i] for i in range(len(scaling_factors))] - keys = [key[query_types == i] for i in range(len(scaling_factors))] - packed_qkr = zip(queries, keys, non_batched_ropes) - # synchronize before start timing - torch.cuda.synchronize() - with nvtx.annotate("non-batched", color="yellow"): - for q, k, r in packed_qkr: - r.forward(positions, q, k) - torch.cuda.synchronize() - with nvtx.annotate("batched", color="green"): - batched_rope.forward(positions, query, key, flatten_offsets) - torch.cuda.synchronize() + positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) + query = torch.randn( + (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device + ) + key = torch.randn_like(query) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rope.forward_native(positions, query.clone(), key.clone()), + quantiles=quantiles, + ) + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.ops.vllm.flashinfer_rotary_embedding( + positions, + query.clone(), + key.clone(), + head_size, + cos_sin_cache, + is_neox_style, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rope.forward_cuda(positions, query.clone(), key.clone()), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark if __name__ == "__main__": @@ -116,17 +95,12 @@ if __name__ == "__main__": parser.add_argument( "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" ) + parser.add_argument("--save-path", type=str, default="./configs/rope/") args = parser.parse_args() - print(args) - benchmark_rope_kernels_multi_lora( - is_neox_style=args.is_neox_style, - batch_size=args.batch_size, - seq_len=args.seq_len, - num_heads=args.num_heads, - head_size=args.head_size, - rotary_dim=args.rotary_dim, - dtype=getattr(torch, args.dtype), - seed=args.seed, - device=args.device, + # Get the benchmark function + benchmark = get_benchmark( + args.head_size, args.rotary_dim, args.is_neox_style, args.device ) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index 18c459c31d3f8..3e23c4cac059c 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -78,11 +78,11 @@ WEIGHT_SHAPES = { } WEIGHT_SHAPES_MOE = { - "nm-testing/Mixtral-8x7B-Instruct-v0.1": [ + "mistralai/Mixtral-8x7B-Instruct-v0.1": [ [8, 2, 4096, 28672], [8, 2, 14336, 4096], ], - "nm-testing/deepseekv2-lite": [ + "deepseek-ai/DeepSeek-V2-Lite": [ [64, 6, 2048, 1408], ], "ibm-granite/granite-3.0-1b-a400m": [ diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 192d349b30099..dbda19fbcbf20 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -343,7 +343,7 @@ message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") # Define extension targets # -define_gpu_extension_target( +define_extension_target( _C DESTINATION vllm LANGUAGE CXX @@ -354,4 +354,4 @@ define_gpu_extension_target( WITH_SOABI ) -message(STATUS "Enabling C extension.") \ No newline at end of file +message(STATUS "Enabling C extension.") diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index f661084ec48ae..2cf3c1a755d3c 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -92,7 +92,7 @@ if(FLASH_MLA_ARCHS) SRCS "${FlashMLA_Extension_SOURCES}" CUDA_ARCHS "${FLASH_MLA_ARCHS}") - define_gpu_extension_target( + define_extension_target( _flashmla_C DESTINATION vllm LANGUAGE ${VLLM_GPU_LANG} @@ -109,7 +109,7 @@ if(FLASH_MLA_ARCHS) $<$:-UPy_LIMITED_API> $<$:-UPy_LIMITED_API>) - define_gpu_extension_target( + define_extension_target( _flashmla_extension_C DESTINATION vllm LANGUAGE ${VLLM_GPU_LANG} diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 931090db50e92..29db9fa273a41 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG a893712401d70362fbb299cd9c4b3476e8e9ed54 + GIT_TAG 8e1b01d56210dc72030a2d0d41c2d8d266ba6309 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/cmake/utils.cmake b/cmake/utils.cmake index c2181d4549236..ca0062ba4fabe 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -453,21 +453,20 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) endmacro() # -# Define a target named `GPU_MOD_NAME` for a single extension. The +# Define a target named `MOD_NAME` for a single extension. The # arguments are: # # DESTINATION - Module destination directory. -# LANGUAGE - The GPU language for this module, e.g CUDA, HIP, -# etc. +# LANGUAGE - The language for this module, e.g. CUDA, HIP, +# CXX, etc. # SOURCES - List of source files relative to CMakeLists.txt # directory. # # Optional arguments: # -# ARCHITECTURES - A list of target GPU architectures in cmake -# format. -# Refer `CMAKE_CUDA_ARCHITECTURES` documentation -# and `CMAKE_HIP_ARCHITECTURES` for more info. +# ARCHITECTURES - A list of target architectures in cmake format. +# For GPU, refer to CMAKE_CUDA_ARCHITECTURES and +# CMAKE_HIP_ARCHITECTURES for more info. # ARCHITECTURES will use cmake's defaults if # not provided. # COMPILE_FLAGS - Extra compiler flags passed to NVCC/hip. @@ -478,63 +477,61 @@ endmacro() # # Note: optimization level/debug info is set via cmake build type. # -function (define_gpu_extension_target GPU_MOD_NAME) +function (define_extension_target MOD_NAME) cmake_parse_arguments(PARSE_ARGV 1 - GPU + ARG "WITH_SOABI" "DESTINATION;LANGUAGE;USE_SABI" "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") # Add hipify preprocessing step when building with HIP/ROCm. - if (GPU_LANGUAGE STREQUAL "HIP") - hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}") + if (ARG_LANGUAGE STREQUAL "HIP") + hipify_sources_target(ARG_SOURCES ${MOD_NAME} "${ARG_SOURCES}") endif() - if (GPU_WITH_SOABI) - set(GPU_WITH_SOABI WITH_SOABI) + if (ARG_WITH_SOABI) + set(SOABI_KEYWORD WITH_SOABI) else() - set(GPU_WITH_SOABI) + set(SOABI_KEYWORD "") endif() - if (GPU_USE_SABI) - Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}") + if (ARG_USE_SABI) + Python_add_library(${MOD_NAME} MODULE USE_SABI ${ARG_USE_SABI} ${SOABI_KEYWORD} "${ARG_SOURCES}") else() - Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}") + Python_add_library(${MOD_NAME} MODULE ${SOABI_KEYWORD} "${ARG_SOURCES}") endif() - if (GPU_LANGUAGE STREQUAL "HIP") + if (ARG_LANGUAGE STREQUAL "HIP") # Make this target dependent on the hipify preprocessor step. - add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME}) + add_dependencies(${MOD_NAME} hipify${MOD_NAME}) # Make sure we include the hipified versions of the headers, and avoid conflicts with the ones in the original source folder - target_include_directories(${GPU_MOD_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/csrc - ${GPU_INCLUDE_DIRECTORIES}) + target_include_directories(${MOD_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/csrc + ${ARG_INCLUDE_DIRECTORIES}) else() - target_include_directories(${GPU_MOD_NAME} PRIVATE csrc - ${GPU_INCLUDE_DIRECTORIES}) + target_include_directories(${MOD_NAME} PRIVATE csrc + ${ARG_INCLUDE_DIRECTORIES}) endif() - if (GPU_ARCHITECTURES) - set_target_properties(${GPU_MOD_NAME} PROPERTIES - ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}") + if (ARG_ARCHITECTURES) + set_target_properties(${MOD_NAME} PROPERTIES + ${ARG_LANGUAGE}_ARCHITECTURES "${ARG_ARCHITECTURES}") endif() + target_compile_options(${MOD_NAME} PRIVATE + $<$:${ARG_COMPILE_FLAGS}>) - target_compile_options(${GPU_MOD_NAME} PRIVATE - $<$:${GPU_COMPILE_FLAGS}>) + target_compile_definitions(${MOD_NAME} PRIVATE + "-DTORCH_EXTENSION_NAME=${MOD_NAME}") - target_compile_definitions(${GPU_MOD_NAME} PRIVATE - "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}") - - - target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES}) + target_link_libraries(${MOD_NAME} PRIVATE torch ${ARG_LIBRARIES}) # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of # dependencies that are not necessary and may not be installed. - if (GPU_LANGUAGE STREQUAL "CUDA") - target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver) + if (ARG_LANGUAGE STREQUAL "CUDA") + target_link_libraries(${MOD_NAME} PRIVATE torch CUDA::cudart CUDA::cuda_driver ${ARG_LIBRARIES}) else() - target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES}) + target_link_libraries(${MOD_NAME} PRIVATE torch ${TORCH_LIBRARIES} ${ARG_LIBRARIES}) endif() - install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME}) + install(TARGETS ${MOD_NAME} LIBRARY DESTINATION ${ARG_DESTINATION} COMPONENT ${MOD_NAME}) endfunction() diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu index 6bee9e4ce1166..229d9862fb670 100644 --- a/csrc/attention/merge_attn_states.cu +++ b/csrc/attention/merge_attn_states.cu @@ -46,6 +46,32 @@ __global__ void merge_attn_states_kernel( s_lse = std::isinf(s_lse) ? -std::numeric_limits::infinity() : s_lse; const float max_lse = fmaxf(p_lse, s_lse); + + /* In certain edge cases, MLA can produce p_lse = s_lse = -inf; + continuing the pipeline then yields NaN. Root cause: with chunked prefill + a batch may be split into two chunks; if a request in that batch has no + prefix hit, every LSE entry for that request’s position is -inf, and at + this moment we merge cross-attention at first. For now we simply emit + prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix + this problem. + */ + if (std::isinf(max_lse)) { + if (pack_offset < head_size) { + // Pack 128b load + pack_128b_t p_out_pack = reinterpret_cast( + prefix_head_ptr)[pack_offset / pack_size]; + + // Pack 128b storage + reinterpret_cast(output_head_ptr)[pack_offset / pack_size] = + p_out_pack; + } + // We only need to write to output_lse once per head. + if (output_lse != nullptr && pack_idx == 0) { + output_lse[head_idx * num_tokens + token_idx] = max_lse; + } + return; + } + p_lse = p_lse - max_lse; s_lse = s_lse - max_lse; const float p_se = expf(p_lse); diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 13c6178941cf8..7d22dd8b84a39 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -24,6 +24,8 @@ struct SSMParamsBase { int64_t pad_slot_id; bool delta_softplus; + bool cache_enabled; + int block_size; index_t A_d_stride; index_t A_dstate_stride; @@ -46,8 +48,9 @@ struct SSMParamsBase { index_t out_z_batch_stride; index_t out_z_d_stride; index_t ssm_states_batch_stride; - index_t ssm_states_dim_stride; + index_t ssm_states_dim_stride; index_t ssm_states_dstate_stride; + index_t cache_indices_stride; // Common data pointers. void *__restrict__ A_ptr; @@ -66,6 +69,9 @@ struct SSMParamsBase { void *__restrict__ cache_indices_ptr; void *__restrict__ has_initial_state_ptr; + void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write + void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write + void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use }; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index d534e138d26d6..fb2a2e5789999 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -119,7 +119,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); - const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; + const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; // cache_index == params.pad_slot_id is defined as padding, so we exit early if (cache_index == params.pad_slot_id){ return; @@ -133,9 +133,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { input_t *Bvar = reinterpret_cast(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; - typename Ktraits::state_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + - cache_index * params.ssm_states_batch_stride + - dim_id * kNRows * params.ssm_states_dim_stride; + + typename Ktraits::state_t *ssm_states; + if (params.cache_enabled) { + // APC mode: ssm_states points to the base, we'll use absolute cache slots later + ssm_states = reinterpret_cast(params.ssm_states_ptr) + + dim_id * kNRows * params.ssm_states_dim_stride; + } else { + // Non-APC mode: offset by cache_index as before + ssm_states = reinterpret_cast(params.ssm_states_ptr) + + cache_index * params.ssm_states_batch_stride + + dim_id * kNRows * params.ssm_states_dim_stride; + } float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { @@ -159,7 +168,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // } constexpr int kChunkSize = kNThreads * kNItems; - const int n_chunks = (seqlen + 2048 - 1) / 2048; + + // Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility + const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048; + const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size; + + const int* batch_cache_indices = cache_indices != nullptr ? + cache_indices + batch_id * params.cache_indices_stride : nullptr; + const int* block_idx_first_scheduled = params.block_idx_first_scheduled_token_ptr != nullptr ? + reinterpret_cast(params.block_idx_first_scheduled_token_ptr) : nullptr; + const int* block_idx_last_scheduled = params.block_idx_last_scheduled_token_ptr != nullptr ? + reinterpret_cast(params.block_idx_last_scheduled_token_ptr) : nullptr; + const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ? + reinterpret_cast(params.initial_state_idx_ptr) : nullptr; + + const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index; + for (int chunk = 0; chunk < n_chunks; ++chunk) { input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; @@ -219,7 +243,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (kIsVariableC) { auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 )); + smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1)); if constexpr (!kIsVariableB) { #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -242,7 +266,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { for (int i = 0; i < kNItems; ++i) { thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); - if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) { thread_data[i] = make_float2(1.f, 0.f); @@ -250,8 +273,24 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } // Initialize running total - - scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0); + scan_t running_prefix; + if (chunk > 0) { + running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE]; + } else { + // Load initial state + if (params.cache_enabled && has_initial_state && batch_cache_indices != nullptr) { + size_t state_offset = load_cache_slot * params.ssm_states_batch_stride + + r * params.ssm_states_dim_stride + + state_idx * params.ssm_states_dstate_stride; + running_prefix = make_float2(1.0, float(ssm_states[state_offset])); + } else if (has_initial_state) { + // Non-APC mode: load from current batch position + running_prefix = make_float2(1.0, float(ssm_states[state_idx * params.ssm_states_dstate_stride])); + } else { + // No initial state + running_prefix = make_float2(1.0, 0.0); + } + } SSMScanPrefixCallbackOp prefix_op(running_prefix); typename Ktraits::BlockScanT(smem_scan).InclusiveScan( @@ -260,8 +299,25 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // There's a syncthreads in the scan op, so we don't need to sync here. // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. if (threadIdx.x == 0) { - smem_running_prefix[state_idx] = prefix_op.running_prefix; - if (chunk == n_chunks - 1) { + smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix; + + // Store state at the end of each chunk when cache is enabled + if (params.cache_enabled && batch_cache_indices != nullptr) { + + size_t cache_slot; + if (chunk == n_chunks - 1) { + cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]]; + } else { + cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk]; + } + + size_t state_offset = cache_slot * params.ssm_states_batch_stride + + r * params.ssm_states_dim_stride + + state_idx * params.ssm_states_dstate_stride; + + ssm_states[state_offset] = typename Ktraits::state_t(prefix_op.running_prefix.y); + } else if (!params.cache_enabled && chunk == n_chunks - 1) { + // Non-APC mode: store only final state at current batch position ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y); } } @@ -274,7 +330,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } } - input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; __syncthreads(); @@ -346,7 +401,9 @@ template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { #ifndef USE_ROCM - if (params.seqlen <= 128) { + if (params.cache_enabled && params.block_size == 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream); + } else if (params.seqlen <= 128) { selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 256) { selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream); @@ -358,7 +415,9 @@ void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream); } #else - if (params.seqlen <= 256) { + if (params.cache_enabled && params.block_size == 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream); + } else if (params.seqlen <= 256) { selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 512) { selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream); @@ -437,13 +496,17 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const std::optional& D, const std::optional& delta_bias, const torch::Tensor ssm_states, - bool has_z, + bool has_z, bool delta_softplus, const std::optional& query_start_loc, const std::optional& cache_indices, const std::optional& has_initial_state, bool varlen, - int64_t pad_slot_id) { + int64_t pad_slot_id, + int64_t block_size, + const std::optional &block_idx_first_scheduled_token, + const std::optional &block_idx_last_scheduled_token, + const std::optional &initial_state_idx) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -477,6 +540,14 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; + // Set cache parameters - cache is enabled if we have direct cache writing params + params.cache_enabled = block_idx_first_scheduled_token.has_value(); + params.block_size = static_cast(block_size); + + // Set direct cache writing pointers + params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr; + params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr; + params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr; // All stride are in elements, not bytes. params.A_d_stride = A.stride(0); @@ -504,9 +575,11 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.out_d_stride = out.stride(0); params.ssm_states_batch_stride = ssm_states.stride(0); - params.ssm_states_dim_stride = ssm_states.stride(1); + params.ssm_states_dim_stride = ssm_states.stride(1); params.ssm_states_dstate_stride = ssm_states.stride(2); + params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0; + } else{ if (!is_variable_B) { @@ -537,8 +610,10 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.out_d_stride = out.stride(1); params.ssm_states_batch_stride = ssm_states.stride(0); - params.ssm_states_dim_stride = ssm_states.stride(1); + params.ssm_states_dim_stride = ssm_states.stride(1); params.ssm_states_dstate_stride = ssm_states.stride(2); + + params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0; } } @@ -554,7 +629,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const torch::Tensor &ssm_states, // used to identify padding entries if cache_indices provided // in case of padding, the kernel will return early - int64_t pad_slot_id) { + int64_t pad_slot_id, + int64_t block_size, + const std::optional &block_idx_first_scheduled_token, + const std::optional &block_idx_last_scheduled_token, + const std::optional &initial_state_idx) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -646,7 +725,16 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, auto cache_indices_ = cache_indices.value(); TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); TORCH_CHECK(cache_indices_.is_cuda()); - CHECK_SHAPE(cache_indices_, batch_size); + + // cache_indices can be either 1D (batch_size,) for non-APC mode + // or 2D (batch_size, max_positions) for APC mode + const bool is_apc_mode = block_idx_first_scheduled_token.has_value(); + if (is_apc_mode) { + TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode"); + TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size"); + } else { + CHECK_SHAPE(cache_indices_, batch_size); + } } @@ -686,7 +774,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, cache_indices, has_initial_state, varlen, - pad_slot_id + pad_slot_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx ); diff --git a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp index 1d06fc6b5b0a0..df47bb8dd1d7d 100644 --- a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp +++ b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp @@ -87,30 +87,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu( const int64_t g_eff_13 = (group_size != -1) ? group_size : H; const int64_t g_eff_2 = (group_size != -1) ? group_size : I; - // Per-expert outputs filled in parallel - std::vector y_list(E); - y_list.resize(E); + auto X_all = x_c.index_select(/*dim=*/0, expert_tokens); + if (apply_router_weight_on_input) { + X_all = X_all.mul(expert_gates.unsqueeze(1)); + } + auto Y_all = at::empty({offsets[E], H}, x_c.options()); at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) { + c10::InferenceMode guard; for (int64_t e = e_begin; e < e_end; ++e) { const int64_t te = counts[e]; if (te == 0) { - y_list[e] = at::empty({0, H}, x_c.options()); continue; } const int64_t start = offsets[e]; - auto sel_tokens = - expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); - auto gates_e = - expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); - - auto x_e = x_c.index_select(/*dim=*/0, sel_tokens); - - if (apply_router_weight_on_input) { - x_e = x_e.mul(gates_e.unsqueeze(1)); - } + auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); auto w13_e = w13_packed.select(/*dim=*/0, e); auto w2_e = w2_packed.select(/*dim=*/0, e); @@ -137,17 +130,15 @@ torch::Tensor dynamic_4bit_int_moe_cpu( // W2 auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H); - if (!apply_router_weight_on_input) { - y = y.mul(gates_e.unsqueeze(1)); - } - // Store per-expert result - y_list[e] = y; + Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y); } }); - // Concatenate all expert outputs to match expert_tokens order - auto Y_all = at::cat(y_list, /*dim=*/0); + if (!apply_router_weight_on_input) { + Y_all = Y_all.mul(expert_gates.unsqueeze(1)); + } + auto out = at::zeros({T, H}, x.options()); out = at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all); diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index c93f9d54d780c..69b4c1fb11d1a 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -427,11 +427,29 @@ __device__ inline bool is_finite(const T val) { #endif } +// Scoring function enums +enum ScoringFunc { + SCORING_NONE = 0, // no activation function + SCORING_SIGMOID = 1 // apply sigmoid +}; + +// Efficient sigmoid approximation from TensorRT-LLM +__device__ inline float sigmoid_accurate(float x) { + return 0.5f * tanhf(0.5f * x) + 0.5f; +} + template -__device__ void topk_with_k2(T* output, T const* input, +__device__ inline T apply_sigmoid(T val) { + float f = cuda_cast(val); + return cuda_cast(sigmoid_accurate(f)); +} + +template +__device__ void topk_with_k2(T* output, T const* input, T const* bias, cg::thread_block_tile<32> const& tile, int32_t const lane_id, - int const num_experts_per_group) { + int const num_experts_per_group, + int const scoring_func) { // Get the top2 per thread T largest = neg_inf(); T second_largest = neg_inf(); @@ -439,6 +457,12 @@ __device__ void topk_with_k2(T* output, T const* input, if (num_experts_per_group > WARP_SIZE) { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { T value = input[i]; + // Apply scoring function if needed + if (scoring_func == SCORING_SIGMOID) { + value = apply_sigmoid(value); + } + value = value + bias[i]; + if (value > largest) { second_largest = largest; largest = value; @@ -448,7 +472,13 @@ __device__ void topk_with_k2(T* output, T const* input, } } else { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { - largest = input[i]; + T value = input[i]; + // Apply scoring function if needed + if (scoring_func == SCORING_SIGMOID) { + value = apply_sigmoid(value); + } + value = value + bias[i]; + largest = value; } } @@ -472,17 +502,21 @@ __device__ void topk_with_k2(T* output, T const* input, } template -__global__ void topk_with_k2_kernel(T* output, T* input, +__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, int64_t const num_tokens, int64_t const num_cases, int64_t const n_group, - int64_t const num_experts_per_group) { + int64_t const num_experts_per_group, + int const scoring_func) { int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE; int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; if (case_id < num_cases) { input += case_id * num_experts_per_group; + // bias is per expert group, offset to current group + int32_t group_id = case_id % n_group; + T const* group_bias = bias + group_id * num_experts_per_group; output += case_id; cg::thread_block block = cg::this_thread_block(); @@ -491,7 +525,8 @@ __global__ void topk_with_k2_kernel(T* output, T* input, #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif - topk_with_k2(output, input, tile, lane_id, num_experts_per_group); + topk_with_k2(output, input, group_bias, tile, lane_id, + num_experts_per_group, scoring_func); } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); @@ -500,16 +535,15 @@ __global__ void topk_with_k2_kernel(T* output, T* input, template __global__ void group_idx_and_topk_idx_kernel( - T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices, - T* scores_with_bias, int64_t const num_tokens, int64_t const n_group, + T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices, + T const* bias, int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, int64_t const topk, int64_t const num_experts, int64_t const num_experts_per_group, bool renormalize, - double routed_scaling_factor) { + double routed_scaling_factor, int scoring_func) { int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE; int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token - scores_with_bias += case_id * num_experts; scores += case_id * num_experts; group_scores += case_id * n_group; topk_values += case_id * topk; @@ -577,10 +611,16 @@ __global__ void group_idx_and_topk_idx_kernel( int32_t offset = i_group * num_experts_per_group; for (int32_t i = lane_id; i < align_num_experts_per_group; i += WARP_SIZE) { - T candidates = (i < num_experts_per_group) && - is_finite(scores_with_bias[offset + i]) - ? scores_with_bias[offset + i] - : neg_inf(); + T candidates = neg_inf(); + if (i < num_experts_per_group) { + // Apply scoring function (if any) and add bias + T input = scores[offset + i]; + if (is_finite(input)) { + T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) + : input; + candidates = score + bias[offset + i]; + } + } queue.add(candidates, offset + i); } if (group_scores[i_group] == topk_group_value) { @@ -602,11 +642,12 @@ __global__ void group_idx_and_topk_idx_kernel( for (int i = lane_id; i < warp_topk::round_up_to_multiple_of(topk); i += WARP_SIZE) { - T value = - i < topk - ? scores[s_topk_idx[i]] - : cuda_cast(0.0f); // Load the valid value of expert + T value = cuda_cast(0.0f); if (i < topk) { + // Load the score value (without bias) for normalization + T input = scores[s_topk_idx[i]]; + value = + (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input; s_topk_value[i] = value; } topk_sum += @@ -627,12 +668,12 @@ __global__ void group_idx_and_topk_idx_kernel( value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; } topk_indices[i] = s_topk_idx[i]; - topk_values[i] = cuda_cast(value); + topk_values[i] = value; } } else { for (int i = lane_id; i < topk; i += WARP_SIZE) { topk_indices[i] = i; - topk_values[i] = cuda_cast(1.0f / topk); + topk_values[i] = 1.0f / topk; } } // Note: when if_proceed_next_topk==false, choose the first 8 experts as the @@ -644,12 +685,12 @@ __global__ void group_idx_and_topk_idx_kernel( } template -void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values, - IdxT* topk_indices, T* scores_with_bias, - int64_t const num_tokens, int64_t const num_experts, - int64_t const n_group, int64_t const topk_group, - int64_t const topk, bool const renormalize, - double const routed_scaling_factor, bool enable_pdl = false, +void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, + IdxT* topk_indices, T const* bias, int64_t const num_tokens, + int64_t const num_experts, int64_t const n_group, + int64_t const topk_group, int64_t const topk, + bool const renormalize, double const routed_scaling_factor, + int const scoring_func, bool enable_pdl = false, cudaStream_t const stream = 0) { int64_t num_cases = num_tokens * n_group; int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; @@ -664,8 +705,9 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values, attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; - cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias, - num_tokens, num_cases, n_group, num_experts / n_group); + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, + num_tokens, num_cases, n_group, num_experts / n_group, + scoring_func); int64_t topk_with_k_group_num_blocks = (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; @@ -682,19 +724,18 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values, config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, scores_with_bias, num_tokens, - n_group, topk_group, topk, num_experts, - num_experts / n_group, renormalize, routed_scaling_factor); + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts / n_group, + renormalize, routed_scaling_factor, scoring_func); } #define INSTANTIATE_NOAUX_TC(T, IdxT) \ template void invokeNoAuxTc( \ - T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \ - T * scores_with_bias, int64_t const num_tokens, \ - int64_t const num_experts, int64_t const n_group, \ - int64_t const topk_group, int64_t const topk, bool const renormalize, \ - double const routed_scaling_factor, bool enable_pdl, \ - cudaStream_t const stream); + T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \ + T const* bias, int64_t const num_tokens, int64_t const num_experts, \ + int64_t const n_group, int64_t const topk_group, int64_t const topk, \ + bool const renormalize, double const routed_scaling_factor, \ + int const scoring_func, bool enable_pdl, cudaStream_t const stream); INSTANTIATE_NOAUX_TC(float, int32_t); INSTANTIATE_NOAUX_TC(half, int32_t); @@ -703,28 +744,32 @@ INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t); } // namespace vllm std::tuple grouped_topk( - torch::Tensor const& scores, torch::Tensor const& scores_with_bias, - int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, - double routed_scaling_factor) { - auto data_type = scores_with_bias.scalar_type(); - auto input_size = scores_with_bias.sizes(); + torch::Tensor const& scores, int64_t n_group, int64_t topk_group, + int64_t topk, bool renormalize, double routed_scaling_factor, + torch::Tensor const& bias, int64_t scoring_func = 0) { + auto data_type = scores.scalar_type(); + auto input_size = scores.sizes(); int64_t num_tokens = input_size[0]; int64_t num_experts = input_size[1]; - TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor"); + TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor"); TORCH_CHECK(num_experts % n_group == 0, "num_experts should be divisible by n_group"); TORCH_CHECK(n_group <= 32, "n_group should be smaller than or equal to 32 for now"); TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); + TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE || + scoring_func == vllm::moe::SCORING_SIGMOID, + "scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)"); torch::Tensor group_scores = torch::empty( {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA)); + // Always output float32 for topk_values (eliminates Python-side conversion) torch::Tensor topk_values = torch::empty( - {num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA)); + {num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); torch::Tensor topk_indices = torch::empty( {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA)); - auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device()); + auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device()); switch (data_type) { case torch::kFloat16: @@ -732,11 +777,11 @@ std::tuple grouped_topk( vllm::moe::invokeNoAuxTc( reinterpret_cast(scores.mutable_data_ptr()), reinterpret_cast(group_scores.mutable_data_ptr()), - reinterpret_cast(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_values.mutable_data_ptr()), reinterpret_cast(topk_indices.mutable_data_ptr()), - reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + reinterpret_cast(bias.data_ptr()), num_tokens, num_experts, n_group, topk_group, topk, renormalize, - routed_scaling_factor, false, stream); + routed_scaling_factor, static_cast(scoring_func), false, stream); break; case torch::kFloat32: // Handle Float32 @@ -745,20 +790,20 @@ std::tuple grouped_topk( reinterpret_cast(group_scores.mutable_data_ptr()), reinterpret_cast(topk_values.mutable_data_ptr()), reinterpret_cast(topk_indices.mutable_data_ptr()), - reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + reinterpret_cast(bias.data_ptr()), num_tokens, num_experts, n_group, topk_group, topk, renormalize, - routed_scaling_factor, false, stream); + routed_scaling_factor, static_cast(scoring_func), false, stream); break; case torch::kBFloat16: // Handle BFloat16 vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>( reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()), reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()), - reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_values.mutable_data_ptr()), reinterpret_cast(topk_indices.mutable_data_ptr()), - reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()), - num_tokens, num_experts, n_group, topk_group, topk, renormalize, - routed_scaling_factor, false, stream); + reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens, + num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, static_cast(scoring_func), false, stream); break; default: // Handle other data types diff --git a/csrc/moe/moe_lora_align_sum_kernels.cu b/csrc/moe/moe_lora_align_sum_kernels.cu index e76d1c3667853..360f1312cf579 100644 --- a/csrc/moe/moe_lora_align_sum_kernels.cu +++ b/csrc/moe/moe_lora_align_sum_kernels.cu @@ -28,11 +28,16 @@ __global__ void moe_lora_align_sum_kernel( int64_t block_size, int num_experts, int max_loras, size_t numel, int max_num_tokens_padded, int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, - int topk_num, int32_t* total_tokens_post_pad) { + int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled, + int32_t* lora_ids) { const size_t tokens_per_thread = div_ceil(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread; - int lora_id = blockIdx.x; + int lora_idx = blockIdx.x; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1 || adapter_enabled[lora_id] == 0) { + return; + } extern __shared__ int32_t shared_mem[]; int32_t* cumsum = shared_mem; token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1); @@ -121,14 +126,13 @@ __global__ void moe_lora_align_sum_kernel( } } -void moe_lora_align_block_size(torch::Tensor topk_ids, - torch::Tensor token_lora_mapping, - int64_t num_experts, int64_t block_size, - int64_t max_loras, int64_t max_num_tokens_padded, - int64_t max_num_m_blocks, - torch::Tensor sorted_token_ids, - torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad) { +void moe_lora_align_block_size( + torch::Tensor topk_ids, torch::Tensor token_lora_mapping, + int64_t num_experts, int64_t block_size, int64_t max_loras, + int64_t max_num_tokens_padded, int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, + torch::Tensor lora_ids) { const int topk_num = topk_ids.size(1); TORCH_CHECK(block_size > 0, "block_size should be greater than 0. "); @@ -164,6 +168,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids, max_loras, topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), topk_num, - num_tokens_post_pad.data_ptr()); + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), lora_ids.data_ptr()); }); } \ No newline at end of file diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index e4bf0aa99421b..11c6875f7f1d0 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -20,14 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad); -void moe_lora_align_block_size(torch::Tensor topk_ids, - torch::Tensor token_lora_mapping, - int64_t num_experts, int64_t block_size, - int64_t max_loras, int64_t max_num_tokens_padded, - int64_t max_num_m_blocks, - torch::Tensor sorted_token_ids, - torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad); +void moe_lora_align_block_size( + torch::Tensor topk_ids, torch::Tensor token_lora_mapping, + int64_t num_experts, int64_t block_size, int64_t max_loras, + int64_t max_num_tokens_padded, int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, + torch::Tensor lora_ids); #ifndef USE_ROCM torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, @@ -40,9 +39,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, int64_t BLOCK_SIZE_K, int64_t bit); std::tuple grouped_topk( - torch::Tensor const& scores, torch::Tensor const& scores_with_bias, - int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, - double routed_scaling_factor); + torch::Tensor const& scores, int64_t n_group, int64_t topk_group, + int64_t topk, bool renormalize, double routed_scaling_factor, + torch::Tensor const& bias, int64_t scoring_func); #endif bool moe_permute_unpermute_supported(); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index c08a543908ef0..bd95ade40a083 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -44,7 +44,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " int max_num_m_blocks, " " Tensor !sorted_token_ids," " Tensor !experts_ids," - " Tensor !num_tokens_post_pad) -> () "); + " Tensor !num_tokens_post_pad," + " Tensor !adapter_enabled," + " Tensor !lora_ids) -> () "); m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size); #ifndef USE_ROCM @@ -105,9 +107,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply grouped topk routing to select experts. m.def( - "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int " + "grouped_topk(Tensor scores, int n_group, int " "topk_group, int topk, bool renormalize, float " - "routed_scaling_factor) -> (Tensor, Tensor)"); + "routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, " + "Tensor)"); m.impl("grouped_topk", torch::kCUDA, &grouped_topk); #endif } diff --git a/csrc/ops.h b/csrc/ops.h index 0bed7492f6616..3f5cb799b774c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -321,17 +321,19 @@ void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, std::optional const& scale_ub); -void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, - const torch::Tensor& A, const torch::Tensor& B, - const torch::Tensor& C, - const std::optional& D_, - const std::optional& z_, - const std::optional& delta_bias_, - bool delta_softplus, - const std::optional& query_start_loc, - const std::optional& cache_indices, - const std::optional& has_initial_state, - const torch::Tensor& ssm_states, int64_t pad_slot_id); +void selective_scan_fwd( + const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, + const torch::Tensor& B, const torch::Tensor& C, + const std::optional& D_, + const std::optional& z_, + const std::optional& delta_bias_, bool delta_softplus, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, + const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size, + const std::optional& block_idx_first_scheduled_token, + const std::optional& block_idx_last_scheduled_token, + const std::optional& initial_state_idx); torch::Tensor dynamic_4bit_int_moe_cpu( torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights, diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 6fcd246f63c50..2521b2797e2c2 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant( // This kernel currently only supports H % 128 == 0 and assumes a // fixed GROUP_SIZE of 128. + static constexpr int GROUP_SIZE = 128; + TORCH_CHECK(input.dtype() == torch::kBFloat16); TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn || y_q.dtype() == torch::kFloat8_e4m3fnuz); TORCH_CHECK(y_s.dtype() == torch::kFloat32); - TORCH_CHECK(input.size(-1) % 256 == 0); + TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0); using Idx_t = int64_t; @@ -601,8 +603,6 @@ void persistent_masked_m_silu_mul_quant( Idx_t stride_counts_e = tokens_per_expert.stride(0); - static constexpr int GROUP_SIZE = 128; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \ @@ -628,21 +628,26 @@ void persistent_masked_m_silu_mul_quant( static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32; + int const NUM_GROUPS = H / GROUP_SIZE; if (!use_ue8m0) { - if (H >= 4096) { + if (H >= 4096 && (NUM_GROUPS % 8 == 0)) { + /* 8 warps config */ static constexpr int NUM_STAGES = 4; static constexpr int THREAD_COUNT = 256; KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES); } else { + /* 1 warp config */ static constexpr int THREAD_COUNT = 32; KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2); } } else { - if (H >= 4096) { + if (H >= 4096 && (NUM_GROUPS % 8 == 0)) { + /* 8 warps config */ static constexpr int NUM_STAGES = 4; static constexpr int THREAD_COUNT = 256; KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES); } else { + /* 1 warp config */ static constexpr int THREAD_COUNT = 32; KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2); } diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index 5575ee8e4197e..6d69852bb4e4f 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -31,6 +31,13 @@ namespace vllm { +template +__host__ __device__ inline Int round_up(Int x, Int y) { + static_assert(std::is_integral_v, + "round_up argument must be integral type"); + return (x + y - 1) / y * y; +} + // Use UE4M3 by default. template __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) @@ -42,10 +49,21 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + int sf_m = round_up(numRows, 128); + int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE; + int sf_n_int = round_up(sf_n_unpadded, 4) / 4; + for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) { + // Each thread writes 4 uint32_t elements. + for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int; + col += blockDim.x * 4) { + SFout[row * sf_n_int + col] = 0x00; + } + } + // Get the global scaling factor, which will be applied to the SF. // Note SFScale is the same as next GEMM's alpha, which is // (448.f / (Alpha_A / 6.f)). - float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; + float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0]; // Input tensor row/col loops. for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { @@ -64,7 +82,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) rowIdx, colIdx, numCols, SFout); out_pos = - cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + cvt_warp_fp16_to_fp4(in_vec, global_scale, sf_out); } } } diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu index cf2cccc913f62..62aeb927ccdcb 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu @@ -1,6 +1,5 @@ #include "scaled_mm_kernels.hpp" #include "scaled_mm_sm100_fp8_dispatch.cuh" -#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" namespace vllm { @@ -13,11 +12,11 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm100_fp8_epilogue( - out, a, b, a_scales, b_scales, *bias); + return cutlass_scaled_mm_sm100_fp8_epilogue(out, a, b, a_scales, + b_scales, *bias); } else { - return cutlass_scaled_mm_sm100_fp8_epilogue( - out, a, b, a_scales, b_scales); + return cutlass_scaled_mm_sm100_fp8_epilogue(out, a, b, a_scales, + b_scales); } } diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh index f876b7d9acd87..c950008b4139a 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -2,6 +2,7 @@ #include "scaled_mm.cuh" #include "cutlass_gemm_caller.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" /** * This file defines Gemm kernel configurations for SM100 (fp8) based on the @@ -12,8 +13,88 @@ namespace vllm { using c3x::cutlass_gemm_caller; -template typename Epilogue> +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule, bool swap_ab_ = false> +struct cutlass_3x_gemm_sm100_fp8 { + using ElementAB = ElementAB_; + using ElementC = ElementD_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using Epilogue = Epilogue_; + + using EVTCompute = typename Epilogue::EVTCompute; + + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentCD = + 128 / cutlass::sizeof_bits::value; + + // Compile-time swap_ab flag + static constexpr bool swap_ab = swap_ab_; + + // ----------------------------------------------------------- + // Layout definitions + // ----------------------------------------------------------- + using LayoutA = cutlass::layout::RowMajor; + using LayoutA_T = typename cutlass::layout::LayoutTranspose::type; + + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutB_T = typename cutlass::layout::LayoutTranspose::type; + + using LayoutD = cutlass::layout::RowMajor; + using LayoutD_Transpose = + typename cutlass::layout::LayoutTranspose::type; + + using LayoutC = LayoutD; + using LayoutC_Transpose = LayoutD_Transpose; + + // ----------------------------------------------------------- + // Collective epilogue (conditionally swap operands and layouts) + // ----------------------------------------------------------- + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, + conditional_t, AlignmentCD, + ElementD, conditional_t, + AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // ----------------------------------------------------------- + // Collective mainloop (conditionally swap operands and layouts) + // ----------------------------------------------------------- + using CollectiveMainloop = conditional_t< + swap_ab, + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB, + LayoutB_T, AlignmentAB, // Swapped B (as A) + ElementAB, LayoutA_T, AlignmentAB, // Swapped A (as B) + ElementAcc, TileShape, ClusterShape, Stages, + KernelSchedule>::CollectiveOp, + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB, + LayoutA, AlignmentAB, ElementAB, LayoutB, AlignmentAB, ElementAcc, + TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp>; + + // ----------------------------------------------------------- + // Kernel definition + // ----------------------------------------------------------- + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, CollectiveMainloop, CollectiveEpilogue, void>; +}; + +template struct sm100_fp8_config_default { // M in (256, inf) static_assert(std::is_same()); @@ -22,12 +103,16 @@ struct sm100_fp8_config_default { using TileShape = Shape<_256, _128, _128>; using ClusterShape = Shape<_2, _2, _1>; using Cutlass3xGemm = - cutlass_3x_gemm_sm100; + conditional_t, + cutlass_3x_gemm_sm100_fp8< + InType, OutType, c3x::ScaledEpilogue, TileShape, + ClusterShape, KernelSchedule, EpilogueSchedule>>; }; -template typename Epilogue> +template struct sm100_fp8_config_M256 { // M in (64, 256] static_assert(std::is_same()); @@ -36,44 +121,127 @@ struct sm100_fp8_config_M256 { using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = - cutlass_3x_gemm_sm100; + conditional_t, + cutlass_3x_gemm_sm100_fp8< + InType, OutType, c3x::ScaledEpilogue, TileShape, + ClusterShape, KernelSchedule, EpilogueSchedule>>; }; -template typename Epilogue> +template +struct sm100_fp8_config_M64_swap_ab { + // This config is for M in (16, 64] and K >= 4096 + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_128, _64, _256>; + using ClusterShape = Shape<_4, _1, _1>; + + // Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing swap + // AB + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm100_fp8, + cutlass_3x_gemm_sm100_fp8>; +}; + +template struct sm100_fp8_config_M64 { - // M in (16, 64] + // This config is for M = 64 and K < 4096 (do not enable swap AB in such case) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using TileShape = Shape<_64, _64, _128>; using ClusterShape = Shape<_1, _1, _1>; + using Cutlass3xGemm = - cutlass_3x_gemm_sm100; + conditional_t, + cutlass_3x_gemm_sm100_fp8< + InType, OutType, c3x::ScaledEpilogue, TileShape, + ClusterShape, KernelSchedule, EpilogueSchedule>>; }; -template typename Epilogue> -struct sm100_fp8_config_M16 { +template +struct sm100_fp8_config_M16_swap_ab { // M in [1, 16] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _4, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm_sm100; + using TileShape = Shape<_128, _32, _128>; + using ClusterShape = Shape<_4, _1, _1>; + + // Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing swap + // AB + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm100_fp8, + cutlass_3x_gemm_sm100_fp8>; }; -template typename Epilogue, +template +void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... epilogue_params) { + static constexpr bool swap_ab = Gemm::swap_ab; + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + using GemmKernel = typename Gemm::GemmKernel; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + auto prob_shape = + swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1); + + StrideA a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + StrideB b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + StrideC c_stride = cutlass::make_cute_packed_stride( + StrideC{}, + swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1)); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); + + typename GemmKernel::MainloopArguments mainloop_args = + swap_ab ? typename GemmKernel::MainloopArguments{b_ptr, b_stride, a_ptr, + a_stride} + : typename GemmKernel::MainloopArguments{a_ptr, a_stride, b_ptr, + b_stride}; + + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, c_stride}; + + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, EpilogueArgs&&... args) { static_assert(std::is_same()); TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); @@ -81,55 +249,69 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, using Cutlass3xGemmDefault = typename sm100_fp8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM16 = - typename sm100_fp8_config_M16::Cutlass3xGemm; + EnableBias>::Cutlass3xGemm; + using Cutlass3xGemmM16SwapAB = + typename sm100_fp8_config_M16_swap_ab::Cutlass3xGemm; + using Cutlass3xGemmM64SwapAB = + typename sm100_fp8_config_M64_swap_ab::Cutlass3xGemm; using Cutlass3xGemmM64 = - typename sm100_fp8_config_M64::Cutlass3xGemm; + typename sm100_fp8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM256 = - typename sm100_fp8_config_M256::Cutlass3xGemm; + typename sm100_fp8_config_M256::Cutlass3xGemm; uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(16), next_pow_2(m)); // next power of 2 + uint32_t const k = a.size(1); - if (mp2 <= 16) { + if (m <= 16) { // m in [1, 16] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 64) { + return cutlass_gemm_caller_sm100_fp8( + out, a, b, b_scales, a_scales, std::forward(args)...); + } else if (m <= 64) { // m in (16, 64] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 256) { + if (m == 64 && k < 4096) { + // do not enable swap AB + return cutlass_gemm_caller_sm100_fp8( + out, a, b, a_scales, b_scales, std::forward(args)...); + } + return cutlass_gemm_caller_sm100_fp8( + out, a, b, b_scales, a_scales, std::forward(args)...); + + } else if (m <= 256) { // m in (64, 256] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); + return cutlass_gemm_caller_sm100_fp8( + out, a, b, a_scales, b_scales, std::forward(args)...); } else { // m in (256, inf) - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); + return cutlass_gemm_caller_sm100_fp8( + out, a, b, a_scales, b_scales, std::forward(args)...); } } -template