diff --git a/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh index 5302f524a0ae4..8106f50f18f66 100644 --- a/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh +++ b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh @@ -17,7 +17,17 @@ wait_for_server() { } MODEL="deepseek-ai/DeepSeek-V2-lite" -BACKENDS=("deepep_high_throughput" "deepep_low_latency") + +# Set BACKENDS based on platform +if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then + # ROCm platform + BACKENDS=("allgather_reducescatter") + # Disable MOE padding for ROCm since it is causing eplb to fail + export VLLM_ROCM_MOE_PADDING=0 +else + # Non-ROCm platform (CUDA/other) + BACKENDS=("deepep_high_throughput" "deepep_low_latency") +fi cleanup() { if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then diff --git a/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh b/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh similarity index 64% rename from .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh rename to .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh index a5135299297e2..6a1bef275d047 100644 --- a/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh +++ b/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh @@ -1,10 +1,12 @@ #!/usr/bin/env bash set -euxo pipefail -# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT] +# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT] [DATA_PARALLEL_SIZE] [TENSOR_PARALLEL_SIZE] THRESHOLD=${1:-0.8} NUM_Q=${2:-1319} PORT=${3:-8020} +DATA_PARALLEL_SIZE=${4:-2} +TENSOR_PARALLEL_SIZE=${5:-2} OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled} mkdir -p "${OUT_DIR}" @@ -17,7 +19,16 @@ wait_for_server() { } MODEL="QWen/Qwen3-30B-A3B-FP8" -BACKENDS=("deepep_high_throughput" "deepep_low_latency") +# Set BACKENDS based on platform +if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then + # ROCm platform + BACKENDS=("allgather_reducescatter") + # Disable MOE padding for ROCm since it is causing eplb to fail + export VLLM_ROCM_MOE_PADDING=0 +else + # Non-ROCm platform (CUDA/other) + BACKENDS=("deepep_high_throughput" "deepep_low_latency") +fi cleanup() { if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then @@ -36,8 +47,10 @@ for BACK in "${BACKENDS[@]}"; do VLLM_ALL2ALL_BACKEND=$BACK \ vllm serve "$MODEL" \ --enforce-eager \ - --tensor-parallel-size 2 \ - --data-parallel-size 2 \ + --enable-eplb \ + --eplb-config '{"window_size":10, "step_interval":100, "num_redundant_experts":0, "log_balancedness":true}' \ + --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \ + --data-parallel-size ${DATA_PARALLEL_SIZE} \ --enable-expert-parallel \ --trust-remote-code \ --max-model-len 2048 \ diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 4e2ff5c5a6bd5..4ddf11c0b268f 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -754,6 +754,7 @@ steps: torch_nightly: true source_file_dependencies: - vllm/model_executor/models/ + - vllm/transformers_utils/ - tests/models/test_initialization.py commands: # Only when vLLM model source is modified - test initialization of a large @@ -1319,7 +1320,10 @@ steps: - pytest -v -s -x lora/test_llama_tp.py - pytest -v -s -x lora/test_llm_with_multi_loras.py - pytest -v -s -x lora/test_olmoe_tp.py - - pytest -v -s -x lora/test_gptoss_tp.py + + # Disabled for now because MXFP4 backend on non-cuda platform + # doesn't support LoRA yet + #- pytest -v -s -x lora/test_gptoss_tp.py - label: Weight Loading Multiple GPU Test # 33min @@ -1482,4 +1486,4 @@ steps: num_gpus: 4 working_dir: "/vllm-workspace" commands: - - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh 0.8 200 8020 + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a5719d438eece..10a19c52c72dc 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -192,6 +192,7 @@ steps: # test with internal dp - python3 ../examples/offline_inference/data_parallel.py --enforce-eager - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py @@ -691,6 +692,7 @@ steps: torch_nightly: true source_file_dependencies: - vllm/model_executor/models/ + - vllm/transformers_utils/ - tests/models/test_initialization.py commands: # Only when vLLM model source is modified - test initialization of a large @@ -901,11 +903,12 @@ steps: - label: Transformers Nightly Models Test working_dir: "/vllm-workspace/" optional: true + soft_fail: true commands: - pip install --upgrade git+https://github.com/huggingface/transformers - - pytest -v -s tests/models/test_initialization.py -k 'not (Ultravox or Phi4Multimodal or MiniCPMO or Lfm2Moe or RobertaForSequenceClassification or Ovis2_5 or DeepseekOCR or KimiVL)' + - pytest -v -s tests/models/test_initialization.py - pytest -v -s tests/models/test_transformers.py - # - pytest -v -s tests/models/multimodal/processing/ + - pytest -v -s tests/models/multimodal/processing/ - pytest -v -s tests/models/multimodal/test_mapping.py - python3 examples/offline_inference/basic/chat.py - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl @@ -969,6 +972,7 @@ steps: - vllm/model_executor/layers/layernorm.py - vllm/model_executor/layers/activation.py - vllm/model_executor/layers/quantization/input_quant_fp8.py + - vllm/model_executor/layers/fused_moe/layer.py - tests/compile/test_fusion_attn.py - tests/compile/test_silu_mul_quant_fusion.py - tests/compile/distributed/test_fusion_all_reduce.py @@ -1115,6 +1119,7 @@ steps: # https://github.com/NVIDIA/nccl/issues/1838 - export NCCL_CUMEM_HOST_ENABLE=0 - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py @@ -1339,11 +1344,20 @@ steps: commands: - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010 -- label: Qwen3-30B-A3B-FP8-block Accuracy +- label: Qwen3-30B-A3B-FP8-block Accuracy (H100) timeout_in_minutes: 60 gpu: h100 optional: true num_gpus: 4 working_dir: "/vllm-workspace" commands: - - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh 0.8 200 8020 + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 + +- label: Qwen3-30B-A3B-FP8-block Accuracy (B200) + timeout_in_minutes: 60 + gpu: b200 + optional: true + num_gpus: 2 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 \ No newline at end of file diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 0e834c057c401..3247408e1163e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,6 +9,7 @@ /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety /vllm/model_executor/layers/mamba @tdoublep /vllm/model_executor/model_loader @22quinn +/vllm/model_executor/layers/batch_invariant.py @yewentao256 /vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa /vllm/vllm_flash_attn @LucasWilkinson /vllm/lora @jeejeelee @@ -59,6 +60,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/v1/kv_connector/nixl_integration @NickLucche /tests/v1/kv_connector @ApostaC /tests/v1/offloading @ApostaC +/tests/v1/determinism @yewentao256 # Transformers modeling backend /vllm/model_executor/models/transformers @hmellor diff --git a/CMakeLists.txt b/CMakeLists.txt index a4cf51d17e982..d88ba3aa66303 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -136,7 +136,7 @@ elseif(HIP_FOUND) # ROCm 5.X and 6.X if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND - NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM}) + Torch_VERSION VERSION_LESS ${TORCH_SUPPORTED_VERSION_ROCM}) message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} " "expected for ROCm build, saw ${Torch_VERSION} instead.") endif() @@ -604,12 +604,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu") + "csrc/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") list(APPEND VLLM_EXT_SRC "${SRCS}") list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1") message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") else() message(STATUS "Not building NVFP4 as no compatible archs were found.") diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md index 41e68e047be82..a28c6956be0e9 100644 --- a/benchmarks/kernels/deepgemm/README.md +++ b/benchmarks/kernels/deepgemm/README.md @@ -2,7 +2,7 @@ This directory includes benchmarks between DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing triton and CUTLASS-based kernels. -Currently this just includes dense GEMMs and only works on Hopper GPUs. +Currently, this just includes dense GEMMs and only works on Hopper GPUs. ## Setup diff --git a/csrc/cache.h b/csrc/cache.h index b162a4a2bc31f..f2a5ec0acf5cd 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -41,11 +41,12 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); void gather_and_maybe_dequant_cache( - torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] - torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] - torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] - torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS] + int64_t num_tokens, const std::string& kv_cache_dtype, torch::Tensor const& scale, std::optional seq_starts = std::nullopt); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 32960cc8073bb..8a5457206c706 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -905,91 +905,79 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, namespace vllm { // grid is launched with dimensions (batch, num_splits) -template +template __global__ void gather_and_maybe_dequant_cache( - const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, - // ENTRIES...] - scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] - const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] - const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] - const int32_t block_size, const int32_t entry_size, + const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, + // ENTRIES...] + scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] + const int32_t* __restrict__ token_to_seq, // [MAX_TOKEN_ACROSS_CHUNK] + const int32_t num_tokens, const int32_t block_size, const int64_t block_table_stride, const int64_t cache_block_stride, const int64_t cache_entry_stride, const int64_t dst_entry_stride, const float* __restrict__ scale, const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per // batch + constexpr int vec_size = sizeof(float4) / sizeof(scalar_t); + using ltype = vllm::vec_n_t; + using stype = vllm::vec_n_t; + // We are adding this for code readability which will be optimized out when + // build in release. + assert(CTA_SIZE == blockDim.x); - const int64_t bid = blockIdx.x; // Batch ID - const int32_t num_splits = gridDim.y; - const int32_t split = blockIdx.y; - const int32_t seq_start = cu_seq_lens[bid]; - const int32_t seq_end = cu_seq_lens[bid + 1]; - const int32_t seq_len = seq_end - seq_start; - const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size); - const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits); +#pragma unroll + for (int token_id = blockIdx.x; token_id < num_tokens; + token_id += gridDim.x) { + int64_t batch_id = token_to_seq[token_id]; + int64_t batch_start = cu_seq_lens[batch_id]; + int64_t batch_end = cu_seq_lens[batch_id + 1]; + int32_t batch_offset = token_id - batch_start; - const int32_t split_start = split * split_blocks; - const int32_t split_end = min((split + 1) * split_blocks, tot_blocks); + if (token_id >= batch_end) return; + int32_t offset = 0; + if (seq_starts != nullptr) { + offset = seq_starts[batch_id]; + } + batch_offset += offset; + int32_t block_table_id = batch_offset / block_size; + int32_t slot_id = batch_offset % block_size; + int32_t block_table_offset = batch_id * block_table_stride + block_table_id; + int32_t block_id = block_table[block_table_offset]; + int64_t cache_offset = + block_id * cache_block_stride + slot_id * cache_entry_stride; + constexpr int32_t vec_iter_cnt = ENTRY_SIZE / vec_size; + scalar_t* dst_ = dst + token_id * dst_entry_stride; + cache_t* src_ = const_cast(src_cache) + cache_offset; - const bool is_active_split = (split_start < tot_blocks); - const bool is_last_split = (split_end == tot_blocks); - - if (!is_active_split) return; - - int32_t full_blocks_end = split_end; - int32_t partial_block_size = 0; - - // Adjust the pointer for the block_table for this batch. - // If seq_starts is provided, compute an offset based on (seq_starts[bid] / - // page_size) - const int32_t batch_offset = bid * block_table_stride; - int32_t offset = 0; - if (seq_starts != nullptr) { - offset = seq_starts[bid] / block_size; - } - const int32_t* batch_block_table = block_table + batch_offset + offset; - - // Adjust dst pointer based on the cumulative sequence lengths. - dst += seq_start * dst_entry_stride; - - if (is_last_split) { - partial_block_size = seq_len % block_size; - if (partial_block_size) full_blocks_end -= 1; - } - - auto copy_entry = [&](const cache_t* __restrict__ _src, - scalar_t* __restrict__ _dst) { - for (int i = threadIdx.x; i < entry_size; i += blockDim.x) { +#pragma unroll + for (int idx = threadIdx.x; idx < vec_iter_cnt; idx += CTA_SIZE) { if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { - _dst[i] = static_cast(_src[i]); + reinterpret_cast(dst_)[idx] = + static_cast(reinterpret_cast(src_)[idx]); } else { - _dst[i] = - fp8::scaled_convert(_src[i], *scale); + ltype loaded_val = reinterpret_cast(src_)[idx]; + stype store_val; +#pragma unroll + for (int j = 0; j < vec_size; ++j) { + store_val.val[j] = fp8::scaled_convert( + loaded_val.val[j], *scale); + } + reinterpret_cast(dst_)[idx] = store_val; } } - }; - - const auto loop_end = - std::min((int64_t)full_blocks_end, block_table_stride - offset); - for (int pid = split_start; pid < loop_end; ++pid) { - auto block_id = batch_block_table[pid]; - auto block_start_ptr = src_cache + block_id * cache_block_stride; - auto block_dst_ptr = dst + pid * block_size * dst_entry_stride; - for (int eid = 0; eid < block_size; ++eid) { - copy_entry(block_start_ptr + eid * cache_entry_stride, - block_dst_ptr + eid * dst_entry_stride); - } - } - - if (partial_block_size) { - if (offset + full_blocks_end < block_table_stride) { - auto block_id = batch_block_table[full_blocks_end]; - auto block_start_ptr = src_cache + block_id * cache_block_stride; - auto block_dst_ptr = - dst + full_blocks_end * block_size * dst_entry_stride; - for (int eid = 0; eid < partial_block_size; ++eid) { - copy_entry(block_start_ptr + eid * cache_entry_stride, - block_dst_ptr + eid * dst_entry_stride); + // process tail + constexpr int32_t tail_cnt = ENTRY_SIZE % vec_size; + dst_ = dst_ + ENTRY_SIZE - tail_cnt; + src_ = src_ + ENTRY_SIZE - tail_cnt; +#pragma unroll + for (int idx = threadIdx.x; idx < tail_cnt; idx += CTA_SIZE) { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + dst_[idx] = static_cast(src_[idx]); + } else { + dst_[idx] = + fp8::scaled_convert(src_[idx], *scale); } } } @@ -1001,34 +989,38 @@ __global__ void gather_and_maybe_dequant_cache( // SCALAR_T is the data type of the destination tensor. // CACHE_T is the stored data type of kv-cache. // KV_DTYPE is the real data type of kv-cache. -#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \ - vllm::gather_and_maybe_dequant_cache \ - <<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst.data_ptr()), \ - block_table.data_ptr(), cu_seq_lens.data_ptr(), \ - block_size, entry_size, block_table_stride, cache_block_stride, \ - cache_entry_stride, dst_entry_stride, \ - reinterpret_cast(scale.data_ptr()), seq_starts_ptr); +#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \ + vllm::gather_and_maybe_dequant_cache \ + <<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + token_to_seq.data_ptr(), num_tokens, block_size, \ + block_table_stride, cache_block_stride, cache_entry_stride, \ + dst_entry_stride, reinterpret_cast(scale.data_ptr()), \ + seq_starts_ptr); // Gather sequences from the cache into the destination tensor. // - cu_seq_lens contains the cumulative sequence lengths for each batch // - block_table contains the cache block indices for each sequence +// - token_to_seq contains the back mapping from token_id to batch_id // - Optionally, seq_starts (if provided) offsets the starting block index by // (seq_starts[bid] / page_size) void gather_and_maybe_dequant_cache( - torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] - torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] - torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] - torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS] + int64_t num_tokens, const std::string& kv_cache_dtype, torch::Tensor const& scale, std::optional seq_starts = std::nullopt) { at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int32_t block_size = src_cache.size(1); - int32_t entry_size = src_cache.flatten(2, -1).size(2); + int32_t head_dim = dst.size(-1); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must be int32"); @@ -1038,6 +1030,9 @@ void gather_and_maybe_dequant_cache( TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, "seq_starts must be int32"); } + TORCH_CHECK(head_dim == 576, + "gather_and_maybe_dequant_cache only support the head_dim to 576 " + "for better performance") TORCH_CHECK(src_cache.device() == dst.device(), "src_cache and dst must be on the same device"); @@ -1055,10 +1050,9 @@ void gather_and_maybe_dequant_cache( int64_t cache_entry_stride = src_cache.stride(1); int64_t dst_entry_stride = dst.stride(0); - // Decide on the number of splits based on the batch size. - int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; - dim3 grid(batch_size, num_splits); - dim3 block(1024); + constexpr int32_t thread_block_size = 64; + dim3 grid(num_tokens); + dim3 block(thread_block_size); const int32_t* seq_starts_ptr = seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp index 50f17c758c148..92f8bee5a47a0 100644 --- a/csrc/cpu/cpu_attn.cpp +++ b/csrc/cpu/cpu_attn.cpp @@ -13,6 +13,18 @@ #define AMX_DISPATCH(...) case cpu_attention::ISA::AMX: #endif +#ifdef __aarch64__ + #include "cpu_attn_neon.hpp" + #define NEON_DISPATCH(...) \ + case cpu_attention::ISA::NEON: { \ + using attn_impl = cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } +#else + #define NEON_DISPATCH(...) case cpu_attention::ISA::NEON: +#endif // #ifdef __aarch64__ + #define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \ case HEAD_DIM: { \ constexpr size_t head_dim = HEAD_DIM; \ @@ -41,6 +53,7 @@ [&] { \ switch (ISA_TYPE) { \ AMX_DISPATCH(__VA_ARGS__) \ + NEON_DISPATCH(__VA_ARGS__) \ case cpu_attention::ISA::VEC: { \ using attn_impl = \ cpu_attention::AttentionImpl class AttentionImpl {}; @@ -143,6 +143,12 @@ struct AttentionMetadata { case ISA::VEC: ss << "VEC, "; break; + case ISA::VEC16: + ss << "VEC16, "; + break; + case ISA::NEON: + ss << "NEON, "; + break; } ss << "workitem_group_num: " << workitem_group_num << ", reduction_item_num: " << reduction_item_num @@ -841,7 +847,7 @@ struct VecTypeTrait { }; #endif -#if !defined(__powerpc__) +#if !defined(__powerpc__) && !defined(__s390x__) template <> struct VecTypeTrait { using vec_t = vec_op::FP16Vec16; diff --git a/csrc/cpu/cpu_attn_neon.hpp b/csrc/cpu/cpu_attn_neon.hpp new file mode 100644 index 0000000000000..827f0cfbc718e --- /dev/null +++ b/csrc/cpu/cpu_attn_neon.hpp @@ -0,0 +1,386 @@ +#ifndef CPU_ATTN_NEON_HPP +#define CPU_ATTN_NEON_HPP + +#include "cpu_attn_impl.hpp" +#include +#include +namespace cpu_attention { + +namespace { + +#define BLOCK_SIZE_ALIGNMENT 32 +#define HEAD_SIZE_ALIGNMENT 32 +#define MAX_Q_HEAD_NUM_PER_ITER 16 + +// These do not use vectorized class for loading / converting +// because csrc/cpu/cpu_types_arm.hpp does not have fallback options +// for vec_op::BF16Vec* / vec_op::BF16Vec* on Arm HW that +// doesn't support BF16. +// We don't use vec_op::FP32Vec* or vec_op::FP16Vec* for consistency. +template +FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, float32x4_t& b0, + float32x4_t& b1); + +template <> +FORCE_INLINE void load_row8_B_as_f32(const float* p, float32x4_t& b0, + float32x4_t& b1) { + b0 = vld1q_f32(p + 0); + b1 = vld1q_f32(p + 4); +} + +template <> +FORCE_INLINE void load_row8_B_as_f32(const c10::Half* p, + float32x4_t& b0, + float32x4_t& b1) { + const float16_t* h = reinterpret_cast(p); + float16x8_t v = vld1q_f16(h); + b0 = vcvt_f32_f16(vget_low_f16(v)); + b1 = vcvt_f32_f16(vget_high_f16(v)); +} + +template <> +FORCE_INLINE void load_row8_B_as_f32(const c10::BFloat16* p, + float32x4_t& b0, + float32x4_t& b1) { + const uint16_t* u = reinterpret_cast(p); +#ifdef ARM_BF16_SUPPORT + uint16x8_t u0 = vld1q_u16(u); + bfloat16x8_t bf0 = vreinterpretq_bf16_u16(u0); + b0 = vcvtq_low_f32_bf16(bf0); + b1 = vcvtq_high_f32_bf16(bf0); +#else + uint16x8_t x0 = vld1q_u16(u); + uint32x4_t lo = vshlq_n_u32(vmovl_u16(vget_low_u16(x0)), 16); + uint32x4_t hi = vshlq_n_u32(vmovl_u16(vget_high_u16(x0)), 16); + b0 = vreinterpretq_f32_u32(lo); + b1 = vreinterpretq_f32_u32(hi); +#endif +} + +// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with NEON FMLAs +// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2) +// #FMLAs = (K // 4) * (4 * 2 * M) +// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads +template +FORCE_INLINE void gemm_micro_neon_fmla_Mx8_Ku4( + const float* __restrict A, // [M x K], + const kv_cache_t* __restrict B, // [K x 8], + float* __restrict C, // [M x 8], + int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) { + // kernel supports max M of 8, as it'd spill for larger M + static_assert(1 <= M && M <= 8, "M must be in [1,8]"); + +// helpers for per-M codegen +#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7) +#define IF_M(i) if constexpr (M > (i)) + + // A row base pointers +#define DECL_A(i) const float* a##i = A + (i) * lda; + ROWS_APPLY(DECL_A) +#undef DECL_A + + // declare 2 accumulators per row of M +#define DECL_ACC(i) float32x4_t acc##i##_0, acc##i##_1; + ROWS_APPLY(DECL_ACC) +#undef DECL_ACC + + // initialize accumulators +#define INIT_ACC(i) \ + IF_M(i) { \ + if (accumulate) { \ + acc##i##_0 = vld1q_f32(C + (i) * ldc + 0); \ + acc##i##_1 = vld1q_f32(C + (i) * ldc + 4); \ + } else { \ + acc##i##_0 = vdupq_n_f32(0.f); \ + acc##i##_1 = vdupq_n_f32(0.f); \ + } \ + } + ROWS_APPLY(INIT_ACC) +#undef INIT_ACC + + int32_t k = 0; + + // K unrolled by 4 + for (; k + 3 < K; k += 4) { + // load A[k..k+3] for each active row (M) +#define LOAD_A4(i) \ + float32x4_t a##i##v; \ + IF_M(i) a##i##v = vld1q_f32(a##i + k); + ROWS_APPLY(LOAD_A4) +#undef LOAD_A4 + + // helper: FMA lane L from aiv +#define FMAS_LANE(i, aiv, L) \ + IF_M(i) { \ + acc##i##_0 = vfmaq_laneq_f32(acc##i##_0, b0, aiv, L); \ + acc##i##_1 = vfmaq_laneq_f32(acc##i##_1, b1, aiv, L); \ + } + + // k + 0 + { + float32x4_t b0, b1; + load_row8_B_as_f32(B + (int64_t)(k + 0) * ldb, b0, b1); +#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0) + ROWS_APPLY(STEP_K0) +#undef STEP_K0 + } + // k + 1 + { + float32x4_t b0, b1; + load_row8_B_as_f32(B + (int64_t)(k + 1) * ldb, b0, b1); +#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1) + ROWS_APPLY(STEP_K1) +#undef STEP_K1 + } + // k + 2 + { + float32x4_t b0, b1; + load_row8_B_as_f32(B + (int64_t)(k + 2) * ldb, b0, b1); +#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2) + ROWS_APPLY(STEP_K2) +#undef STEP_K2 + } + // k + 3 + { + float32x4_t b0, b1; + load_row8_B_as_f32(B + (int64_t)(k + 3) * ldb, b0, b1); +#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3) + ROWS_APPLY(STEP_K3) +#undef STEP_K3 + } +#undef FMAS_LANE + } + + // K tail + for (; k < K; ++k) { + float32x4_t b0, b1; + load_row8_B_as_f32(B + (int64_t)k * ldb, b0, b1); +#define TAIL_ROW(i) \ + IF_M(i) { \ + float32x4_t ai = vdupq_n_f32(*(a##i + k)); \ + acc##i##_0 = vfmaq_f32(acc##i##_0, b0, ai); \ + acc##i##_1 = vfmaq_f32(acc##i##_1, b1, ai); \ + } + ROWS_APPLY(TAIL_ROW) +#undef TAIL_ROW + } + + // store accumulators to C +#define STORE_ROW(i) \ + IF_M(i) { \ + vst1q_f32(C + (i) * ldc + 0, acc##i##_0); \ + vst1q_f32(C + (i) * ldc + 4, acc##i##_1); \ + } + ROWS_APPLY(STORE_ROW) +#undef STORE_ROW + +#undef ROWS_APPLY +#undef IF_M +} + +template +FORCE_INLINE void gemm_macro_neon_fmla_Mx8_Ku4(const float* __restrict A, + const kv_cache_t* __restrict B, + float* __restrict C, int32_t M, + int32_t K, int64_t lda, + int64_t ldb, int64_t ldc, + bool accumulate) { + // micro kernel is Mx8 + static_assert(N % 8 == 0, "N must be a multiple of 8"); + for (int32_t m = 0; m < M;) { + int32_t mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1; + const float* Ab = A + m * lda; + float* Cb = C + m * ldc; + + for (int32_t n = 0; n < N; n += 8) { + const kv_cache_t* Bn = B + n; + float* Cn = Cb + n; + switch (mb) { + case 8: + gemm_micro_neon_fmla_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, + K, accumulate); + break; + case 4: + gemm_micro_neon_fmla_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, + K, accumulate); + break; + case 2: + gemm_micro_neon_fmla_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, + K, accumulate); + break; + default: + gemm_micro_neon_fmla_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, + K, accumulate); + break; + } + } + // no tail loop for N as it's guaranteed to be a multiple of 8 + m += mb; + } +} + +template +class TileGemmNeonFMLA { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + float* __restrict__ a_tile, + kv_cache_t* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + if constexpr (phase == AttentionGemmPhase::QK) { + gemm_macro_neon_fmla_Mx8_Ku4( + a_tile, b_tile, c_tile, m_size, k_size, lda, ldb, ldc, accum_c); + } else { + gemm_macro_neon_fmla_Mx8_Ku4( + a_tile, b_tile, c_tile, m_size, dynamic_k_size, lda, ldb, ldc, + accum_c); + } + } +}; + +} // namespace + +// this is similar to "ISA::VEC" at the moment +template +class AttentionImpl { + public: + using query_t = scalar_t; + using q_buffer_t = float; + using kv_cache_t = scalar_t; + using logits_buffer_t = float; + using partial_output_buffer_t = float; + using prob_buffer_t = float; + + constexpr static int64_t BlockSizeAlignment = + BLOCK_SIZE_ALIGNMENT; // KV token num unit of QK and PV phases + constexpr static int64_t HeadDimAlignment = + HEAD_SIZE_ALIGNMENT; // headdim num unit of PV phase + constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER; + constexpr static int64_t HeadDim = head_dim; + constexpr static ISA ISAType = ISA::NEON; + constexpr static bool scale_on_logits = false; // apply scale on q_buffer + + static_assert(HeadDim % HeadDimAlignment == 0); + // the gemm micro kernel is Mx8 + static_assert(HeadDimAlignment % 8 == 0); + static_assert(BlockSizeAlignment % 8 == 0); + + public: + template