diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9d0b3fdd3a02c..8e6d32f71f220 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -654,7 +654,7 @@ steps: - vllm/model_executor/layers/quantization autorun_on_main: true commands: - - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt - label: OpenAI API correctness # 22min timeout_in_minutes: 30 @@ -1064,7 +1064,7 @@ steps: - csrc/ - vllm/model_executor/layers/quantization commands: - - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1 + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt ##### 1 GPU test ##### ##### multi gpus test ##### diff --git a/.github/mergify.yml b/.github/mergify.yml index 3ad79f93bc7ad..3e4e21efe39df 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -235,6 +235,20 @@ pull_request_rules: add: - rocm +- name: label-cpu + description: Automatically apply cpu label + conditions: + - label != stale + - files~=^(?!.*kv_offload)(?!.*cpu_offload).*\bcpu.* + actions: + label: + add: + - cpu + assign: + users: + - "fadara01" + - "aditew01" + - name: label-structured-output description: Automatically apply structured-output label conditions: diff --git a/CMakeLists.txt b/CMakeLists.txt index cd52df86e0346..5ca71f6ba4df0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -357,6 +357,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # marlin arches for fp16 output cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + # marlin has limited support for turing + cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}") # marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX) cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") # marlin arches for fp8 input @@ -364,8 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + # marlin arches for other files + cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") - if (MARLIN_ARCHS) + if (MARLIN_OTHER_ARCHS) # # For the Marlin kernels we automatically generate sources for various @@ -406,25 +410,39 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Marlin generation script has not changed, skipping generation.") endif() - file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu") - set_gencode_flags_for_srcs( - SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" - CUDA_ARCHS "${MARLIN_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} - PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") - endif() - list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) + if (MARLIN_ARCHS) + file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) - file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu") - set_gencode_flags_for_srcs( - SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}" - CUDA_ARCHS "${MARLIN_BF16_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC} - PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_BF16_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC}) + endif() + + if (MARLIN_SM75_ARCHS) + file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_SM75_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_SM75_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_SM75_KERNEL_SRC}) endif() - list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC}) if (MARLIN_FP8_ARCHS) file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu") @@ -446,14 +464,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_SRCS}" - CUDA_ARCHS "${MARLIN_ARCHS}") + CUDA_ARCHS "${MARLIN_OTHER_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu" + set_source_files_properties(${MARLIN_SRCS} PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") endif() list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") - message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") + message(STATUS "Building Marlin kernels for archs: ${MARLIN_OTHER_ARCHS}") else() message(STATUS "Not building Marlin kernels as no compatible archs found" " in CUDA target architectures") @@ -980,12 +998,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # note that we always set `use_atomic_add=False` for moe marlin now, # so we don't need 9.0 for bf16 atomicAdd PTX cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + # moe marlin has limited support for turing + cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}") # moe marlin arches for fp8 input # - sm80 doesn't support fp8 computation # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") - if (MARLIN_MOE_ARCHS) + # moe marlin arches for other files + cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") + if (MARLIN_MOE_OTHER_ARCHS) # # For the Marlin MOE kernels we automatically generate sources for various @@ -1026,16 +1048,29 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Marlin MOE generation script has not changed, skipping generation.") endif() - file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu") - list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu") - set_gencode_flags_for_srcs( - SRCS "${MARLIN_MOE_SRC}" - CUDA_ARCHS "${MARLIN_MOE_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties(${MARLIN_MOE_SRC} - PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + if (MARLIN_MOE_ARCHS) + file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_SRC}" + CUDA_ARCHS "${MARLIN_MOE_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC}) + endif() + + if (MARLIN_MOE_SM75_ARCHS) + file(GLOB MARLIN_MOE_SM75_SRC "csrc/moe/marlin_moe_wna16/sm75_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_SM75_SRC}" + CUDA_ARCHS "${MARLIN_MOE_SM75_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_SM75_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SM75_SRC}) endif() - list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC}) if (MARLIN_MOE_FP8_ARCHS) file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu") @@ -1049,7 +1084,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC}) endif() - message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") + set(MARLIN_MOE_OTHER_SRC "csrc/moe/marlin_moe_wna16/ops.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_OTHER_SRC}" + CUDA_ARCHS "${MARLIN_MOE_OTHER_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_OTHER_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_OTHER_SRC}") + + message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_OTHER_ARCHS}") else() message(STATUS "Not building Marlin MOE kernels as no compatible archs found" " in CUDA target architectures") diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py index 66268b71b3de6..d31e67057d8f6 100644 --- a/benchmarks/kernels/benchmark_activation.py +++ b/benchmarks/kernels/benchmark_activation.py @@ -13,8 +13,8 @@ from vllm.triton_utils import triton from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE -batch_size_range = [1, 16, 32, 64, 128] -seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] +batch_size_range = [1, 16, 128] +seq_len_range = [1, 16, 64, 1024, 4096] intermediate_size = [3072, 9728, 12288] configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size)) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index a4a880f13cf7e..8268065ef02c8 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x, const scalar_t& y) { return act_first ? ACT_FN(x) * y : x * ACT_FN(y); } -// Activation and gating kernel template. +// Check if all pointers are 16-byte aligned for int4 vectorized access +__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) { + return (reinterpret_cast(ptr) & 15) == 0; +} + +// Activation and gating kernel template. template __global__ void act_and_mul_kernel( scalar_t* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., 2, d] const int d) { + constexpr int VEC_SIZE = 16 / sizeof(scalar_t); const int64_t token_idx = blockIdx.x; - for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); - const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); - out[token_idx * d + idx] = compute(x, y); + const scalar_t* x_ptr = input + token_idx * 2 * d; + const scalar_t* y_ptr = x_ptr + d; + scalar_t* out_ptr = out + token_idx * d; + + // Check alignment for 128-bit vectorized access. + // All three pointers must be 16-byte aligned for safe int4 operations. + const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) && + is_16byte_aligned(out_ptr); + + if (aligned && d >= VEC_SIZE) { + // Fast path: 128-bit vectorized loop + const int4* x_vec = reinterpret_cast(x_ptr); + const int4* y_vec = reinterpret_cast(y_ptr); + int4* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / VEC_SIZE; + const int vec_end = num_vecs * VEC_SIZE; + + for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { + int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r; + auto* xp = reinterpret_cast(&x); + auto* yp = reinterpret_cast(&y); + auto* rp = reinterpret_cast(&r); +#pragma unroll + for (int j = 0; j < VEC_SIZE; j++) { + rp[j] = compute(xp[j], yp[j]); + } + out_vec[i] = r; + } + // Scalar cleanup for remaining elements + for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) { + out_ptr[i] = compute(VLLM_LDG(&x_ptr[i]), + VLLM_LDG(&y_ptr[i])); + } + } else { + // Scalar fallback for unaligned data or small d + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&x_ptr[idx]); + const scalar_t y = VLLM_LDG(&y_ptr[idx]); + out_ptr[idx] = compute(x, y); + } } } @@ -120,50 +162,115 @@ template __global__ void act_and_mul_kernel_with_param( scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d, const float param) { + constexpr int VEC_SIZE = 16 / sizeof(scalar_t); const int64_t token_idx = blockIdx.x; - for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); - const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); - out[token_idx * d + idx] = ACT_FN(x, param) * y; + const scalar_t* x_ptr = input + token_idx * 2 * d; + const scalar_t* y_ptr = x_ptr + d; + scalar_t* out_ptr = out + token_idx * d; + + // Check alignment for 128-bit vectorized access + const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) && + is_16byte_aligned(out_ptr); + + if (aligned && d >= VEC_SIZE) { + // Fast path: 128-bit vectorized loop + const int4* x_vec = reinterpret_cast(x_ptr); + const int4* y_vec = reinterpret_cast(y_ptr); + int4* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / VEC_SIZE; + const int vec_end = num_vecs * VEC_SIZE; + + for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { + int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r; + auto* xp = reinterpret_cast(&x); + auto* yp = reinterpret_cast(&y); + auto* rp = reinterpret_cast(&r); +#pragma unroll + for (int j = 0; j < VEC_SIZE; j++) { + rp[j] = ACT_FN(xp[j], param) * yp[j]; + } + out_vec[i] = r; + } + // Scalar cleanup for remaining elements + for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) { + out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]); + } + } else { + // Scalar fallback for unaligned data or small d + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&x_ptr[idx]); + const scalar_t y = VLLM_LDG(&y_ptr[idx]); + out_ptr[idx] = ACT_FN(x, param) * y; + } } } template __device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up, float alpha, float limit) { - // clamp gate: min=None, max=limit - const float gate_f = (float)gate; - const float clamped_gate = gate_f > limit ? limit : gate_f; - - // clamp up: min=-limit, max=limit - const float up_f = (float)up; - const float clamped_up = - up_f > limit ? limit : (up_f < -limit ? -limit : up_f); - - // glu = gate * sigmoid(gate * alpha) - const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha)); - const float glu = clamped_gate * sigmoid_val; - - // (up + 1) * glu - return (T)((clamped_up + 1.0f) * glu); + // Clamp gate to (-inf, limit] and up to [-limit, limit] + const float g = fminf((float)gate, limit); + const float u = fmaxf(fminf((float)up, limit), -limit); + // glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu + return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha))); } +// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...]. template __global__ void swigluoai_and_mul_kernel( scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., 2, d] + const scalar_t* __restrict__ input, // [..., 2 * d] (interleaved) const int d, const float alpha, const float limit) { + // For interleaved data: input has 2*d elements per token (gate/up pairs) + // output has d elements per token + constexpr int VEC_SIZE = 16 / sizeof(scalar_t); + constexpr int PAIRS = VEC_SIZE / 2; // Number of gate/up pairs per int4 load const int64_t token_idx = blockIdx.x; - // TODO: Vectorize loads and stores. - for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - // gate = x[..., ::2] (even indices) - const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]); - // up = x[..., 1::2] (odd indices) - const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]); + const scalar_t* in_ptr = input + token_idx * 2 * d; + scalar_t* out_ptr = out + token_idx * d; - out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit); + // Check alignment for 128-bit vectorized access on input. + // For output we use int2 (64-bit) which has 8-byte alignment requirement. + const bool in_aligned = is_16byte_aligned(in_ptr); + const bool out_aligned = + (reinterpret_cast(out_ptr) & 7) == 0; // 8-byte for int2 + + if (in_aligned && out_aligned && d >= PAIRS) { + // Fast path: vectorized loop + // Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs + // Each int2 store writes PAIRS output elements + const int4* in_vec = reinterpret_cast(in_ptr); + int2* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / PAIRS; + const int vec_end = num_vecs * PAIRS; + + for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { + int4 v = VLLM_LDG(&in_vec[i]); + int2 r; + auto* vp = reinterpret_cast(&v); + auto* rp = reinterpret_cast(&r); +#pragma unroll + for (int j = 0; j < PAIRS; j++) { + rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit); + } + out_vec[i] = r; + } + // Scalar cleanup for remaining elements + for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) { + out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]), + VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit); + } + } else { + // Scalar fallback for unaligned data or small d + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + // gate = x[..., ::2] (even indices) + const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]); + // up = x[..., 1::2] (odd indices) + const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]); + out_ptr[idx] = ACT_FN(gate, up, alpha, limit); + } } } @@ -217,10 +324,41 @@ __global__ void activation_kernel( scalar_t* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., d] const int d) { + constexpr int VEC_SIZE = 16 / sizeof(scalar_t); const int64_t token_idx = blockIdx.x; - for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); - out[token_idx * d + idx] = ACT_FN(x); + const scalar_t* in_ptr = input + token_idx * d; + scalar_t* out_ptr = out + token_idx * d; + + // Check alignment for 128-bit vectorized access + const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr); + + if (aligned && d >= VEC_SIZE) { + // Fast path: 128-bit vectorized loop + const int4* in_vec = reinterpret_cast(in_ptr); + int4* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / VEC_SIZE; + const int vec_end = num_vecs * VEC_SIZE; + + for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { + int4 v = VLLM_LDG(&in_vec[i]), r; + auto* vp = reinterpret_cast(&v); + auto* rp = reinterpret_cast(&r); +#pragma unroll + for (int j = 0; j < VEC_SIZE; j++) { + rp[j] = ACT_FN(vp[j]); + } + out_vec[i] = r; + } + // Scalar cleanup for remaining elements + for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) { + out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i])); + } + } else { + // Scalar fallback for unaligned data or small d + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&in_ptr[idx]); + out_ptr[idx] = ACT_FN(x); + } } } diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 5fa367abd96f5..7229e420d3fe4 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) { template __device__ inline T apply_scoring(T val) { - if constexpr (SF == SCORING_SIGMOID) { + if constexpr (SF == SCORING_NONE) { + return val; + } else if constexpr (SF == SCORING_SIGMOID) { return apply_sigmoid(val); } else { + static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID, + "Unsupported ScoringFunc in apply_scoring"); return val; } } @@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel( if (case_id < num_tokens) { if (if_proceed_next_topk) { + float scale = routed_scaling_factor; + if (renormalize) { + scale /= topk_sum; + } for (int i = lane_id; i < topk; i += WARP_SIZE) { float base = cuda_cast(s_topk_value[i]); - float value = renormalize ? (base / topk_sum * routed_scaling_factor) - : (base * routed_scaling_factor); + float value = base * scale; topk_indices[i] = s_topk_idx[i]; topk_values[i] = value; } diff --git a/csrc/moe/marlin_moe_wna16/.gitignore b/csrc/moe/marlin_moe_wna16/.gitignore index ba805f9250ece..7dc482a894660 100644 --- a/csrc/moe/marlin_moe_wna16/.gitignore +++ b/csrc/moe/marlin_moe_wna16/.gitignore @@ -1,2 +1,3 @@ sm*_kernel_*.cu kernel_selector.h +kernel_*.cu diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 88f1055337fd5..9db03ea149d0c 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -10,6 +10,8 @@ import jinja2 ARCHS = [] SUPPORT_FP8 = False +SUPPORT_SM75 = False +SUPPORT_SM80 = False for arch in sys.argv[1].split(","): arch = arch[: arch.index(".") + 2].replace(".", "") arch = int(arch) @@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","): # with FP16 MMA, so it cannot achieve any acceleration. if arch in [89, 120]: SUPPORT_FP8 = True + if arch >= 80: + SUPPORT_SM80 = True + if arch == 75: + SUPPORT_SM75 = True FILE_HEAD_COMMENT = """ // auto generated by generate_kernels.py @@ -157,6 +163,7 @@ def remove_old_kernels(): def generate_new_kernels(): result_dict = {} + sm_75_result_dict = {} for quant_config in QUANT_CONFIGS: c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) @@ -174,6 +181,8 @@ def generate_new_kernels(): s_type = quant_config.get("s_type", c_type) if (a_type, b_type, c_type) not in result_dict: result_dict[(a_type, b_type, c_type)] = [] + if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16": + sm_75_result_dict[(a_type, b_type, c_type)] = [] for group_blocks, m_blocks, thread_configs in itertools.product( all_group_blocks, all_m_blocks, all_thread_configs @@ -197,78 +206,89 @@ def generate_new_kernels(): "thread_k_blocks": thread_k // 16, "thread_n_blocks": thread_n // 16, "m_block_size_8": "true" if m_blocks == 0.5 else "false", - "stages": "pipe_stages", + "stages": 4, "group_blocks": group_blocks, "is_zp_float": "false", } - result_dict[(a_type, b_type, c_type)].append(config) + if SUPPORT_SM80: + result_dict[(a_type, b_type, c_type)].append(config) + if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75: + config_sm75 = config.copy() + config_sm75["stages"] = 2 + sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75) kernel_selector_str = FILE_HEAD_COMMENT - for (a_type, b_type, c_type), config_list in result_dict.items(): - all_template_str_list = [] - for config in config_list: - s_type = config["s_type"] - template_str = jinja2.Template(TEMPLATE).render( - a_type_id=f"vllm::{a_type}.id()", - b_type_id=f"vllm::{b_type}.id()", - c_type_id=f"vllm::{c_type}.id()", - s_type_id=f"vllm::{s_type}.id()", - **config, - ) - all_template_str_list.append(template_str) - - conditions = [ - f"a_type == vllm::{a_type}", - f"b_type == vllm::{b_type}", - f"c_type == vllm::{c_type}", - f"s_type == vllm::{s_type}", - f"threads == {config['threads']}", - f"thread_m_blocks == {config['thread_m_blocks']}", - f"thread_n_blocks == {config['thread_n_blocks']}", - f"thread_k_blocks == {config['thread_k_blocks']}", - f"m_block_size_8 == {config['m_block_size_8']}", - f"group_blocks == {config['group_blocks']}", - f"is_zp_float == {config['is_zp_float']}", - ] - conditions = " && ".join(conditions) - - if kernel_selector_str == FILE_HEAD_COMMENT: - kernel_selector_str += f"if ({conditions})\n kernel = " - else: - kernel_selector_str += f"else if ({conditions})\n kernel = " - - kernel_template2 = ( - "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " - "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " - "{{thread_n_blocks}}, {{thread_k_blocks}}, " - "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " - "{{is_zp_float}}>;" - ) - - kernel_selector_str += ( - jinja2.Template(kernel_template2).render( + for result_dict_tmp in [result_dict, sm_75_result_dict]: + for (a_type, b_type, c_type), config_list in result_dict_tmp.items(): + all_template_str_list = [] + if not config_list: + continue + for config in config_list: + s_type = config["s_type"] + template_str = jinja2.Template(TEMPLATE).render( a_type_id=f"vllm::{a_type}.id()", b_type_id=f"vllm::{b_type}.id()", c_type_id=f"vllm::{c_type}.id()", s_type_id=f"vllm::{s_type}.id()", **config, ) - + "\n" - ) + all_template_str_list.append(template_str) - file_content = FILE_HEAD + "\n\n" - file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" - if a_type == "kFE4M3fn": - filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" - else: - filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + conditions = [ + f"a_type == vllm::{a_type}", + f"b_type == vllm::{b_type}", + f"c_type == vllm::{c_type}", + f"s_type == vllm::{s_type}", + f"threads == {config['threads']}", + f"thread_m_blocks == {config['thread_m_blocks']}", + f"thread_n_blocks == {config['thread_n_blocks']}", + f"thread_k_blocks == {config['thread_k_blocks']}", + f"m_block_size_8 == {config['m_block_size_8']}", + f"stages == {config['stages']}", + f"group_blocks == {config['group_blocks']}", + f"is_zp_float == {config['is_zp_float']}", + ] + conditions = " && ".join(conditions) - filename = filename.lower() + if kernel_selector_str == FILE_HEAD_COMMENT: + kernel_selector_str += f"if ({conditions})\n kernel = " + else: + kernel_selector_str += f"else if ({conditions})\n kernel = " - with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: - f.write(file_content) + kernel_template2 = ( + "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " + "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " + "{{thread_n_blocks}}, {{thread_k_blocks}}, " + "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " + "{{is_zp_float}}>;" + ) + + kernel_selector_str += ( + jinja2.Template(kernel_template2).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + + "\n" + ) + + file_content = FILE_HEAD + "\n\n" + file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" + if a_type == "kFE4M3fn": + filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + elif result_dict_tmp is sm_75_result_dict: + filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + else: + filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + + filename = filename.lower() + + with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: + f.write(file_content) if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: kernel_selector_str += ( diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index 5b6b2456b4111..138197b76f026 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -26,6 +26,7 @@ #include "quantization/gptq_marlin/marlin.cuh" #include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "quantization/gptq_marlin/dequant.h" +#include "quantization/gptq_marlin/marlin_mma.h" #include "core/scalar_type.hpp" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ @@ -35,7 +36,7 @@ namespace MARLIN_NAMESPACE_NAME { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 template -__device__ inline void mma( - const typename MarlinScalarType::FragA& a_frag, - const typename MarlinScalarType::FragB& frag_b, - typename MarlinScalarType::FragC& frag_c, int idx = 0) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - using scalar_t = typename MarlinScalarType::scalar_t; - if constexpr (k_size == 16) { - if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]), - "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - int32_t* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]), - "r"(c[1]), "r"(c[2]), "r"(c[3])); - } - } else if (k_size == 32) { - if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - int32_t* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); - } - } -} - -template -__device__ inline void mma_trans( - const typename MarlinScalarType::FragA& a_frag, - const typename MarlinScalarType::FragB& frag_b, - const typename MarlinScalarType::FragB& frag_b2, - typename MarlinScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - const uint32_t* b2 = reinterpret_cast(&frag_b2); - float* c = reinterpret_cast(&frag_c); - using scalar_t = typename MarlinScalarType::scalar_t; - if constexpr (k_size == 16) { - if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), - "f"(c[3])); - } else if constexpr (std::is_same::value) { - int32_t* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), - "r"(c[3])); - } - } else { - if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200 - asm volatile( - "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - #else - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - #endif - } else if constexpr (std::is_same::value) { - int32_t* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); - } - } -} - // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. template @@ -439,9 +300,20 @@ __global__ void Marlin( if constexpr (a_type_id == vllm::kFE4M3fn.id()) return; #endif + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + // Turing TensorCore only supports fp16 and int8 + if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id()) + return; + #endif + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id(); + #else + constexpr bool use_fp16_accum = false; + #endif using Adtype = MarlinScalarType; using Cdtype = MarlinScalarType; @@ -618,7 +490,22 @@ __global__ void Marlin( } } + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + + if constexpr (moe_block_size >= 16) + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16); + if constexpr (moe_block_size >= 8) + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8); + if constexpr (moe_block_size >= 4) + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4); + if constexpr (moe_block_size >= 2) + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2); + + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1); + block_num_valid_tokens = local_count; + #else block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count); + #endif if (lane_id == 0) reinterpret_cast(sh_new)[0] = block_num_valid_tokens; @@ -1018,10 +905,6 @@ __global__ void Marlin( constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); int4* sh_s = sh_zp + (stages * zp_sh_stage); - // shared memory reused by reduction should be smaller than - // shared memory used by weight. - static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= - stages * b_sh_stage); int4* sh_a = sh_s + sh_s_size; // Register storage for double buffer of shared memory reads. @@ -1545,11 +1428,13 @@ __global__ void Marlin( #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { if constexpr (m_block_size_8) { - mma_trans(frag_a[k2][i], frag_b0, frag_b1, - frag_c[i][j][0]); + mma_trans(frag_a[k2][i], frag_b0, frag_b1, + frag_c[i][j][0]); } else { - mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + mma(frag_a[k2][i], frag_b0, + frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, + frag_c[i][j][1]); } } } @@ -1583,10 +1468,12 @@ __global__ void Marlin( #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k2][i], frag_b[0], - (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); - mma(frag_a[k2][i], frag_b[1], - (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]); + mma( + frag_a[k2][i], frag_b[0], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); + mma( + frag_a[k2][i], frag_b[1], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]); } if constexpr (group_blocks != -1) { @@ -2132,6 +2019,21 @@ __global__ void Marlin( // While this pattern may not be the most readable, other ways of writing // the loop seemed to noticeably worse performance after compilation. if (slice_iters == 0) { + // convert fp16 accum to fp32 for reduction + if constexpr (use_fp16_accum) { + #pragma unroll + for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) { + float* frag_c_part_float = reinterpret_cast(frag_c) + i * 4; + scalar_t* frag_c_part_half = + reinterpret_cast(frag_c_part_float); + + #pragma unroll + for (int i = 3; i >= 0; i--) { + frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]); + } + } + } + if constexpr (is_a_8bit) { float frag_a_s[2 * thread_m_blocks]; diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 4fd8fc5c54202..8ac1691220a6b 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -142,7 +142,7 @@ typedef struct { int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, - bool has_act_order, bool is_k_full) { + bool has_act_order, bool is_k_full, int stages) { bool cache_scales_chunk = has_act_order && !is_k_full; int tb_n = th_config.thread_n; @@ -160,13 +160,13 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, if (cache_scales_chunk) { int load_groups = - tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; } else { int tb_scales = tb_groups * tb_n * 2; - return tb_scales * pipe_stages; + return tb_scales * stages; } } @@ -174,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int has_zp, - int is_zp_float, bool is_a_8bit) { + int is_zp_float, bool is_a_8bit, int stages) { int pack_factor = 32 / num_bits; // Get B size @@ -185,8 +185,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) int sh_block_meta_size = tb_m * 16; - int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2); - int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2); + int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4; int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_bias_size = tb_n * 2; int tmp_size = @@ -195,8 +195,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); - int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + group_size, has_act_order, is_k_full, stages); + int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0; int sh_zp_size = 0; if (has_zp) { if (is_zp_float) @@ -217,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int has_zp, int is_zp_float, - int max_shared_mem, bool is_a_8bit) { + bool is_a_8bit, int stages, int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -243,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, int cache_size = get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, is_a_8bit); + is_k_full, has_zp, is_zp_float, is_a_8bit, stages); return cache_size <= max_shared_mem; } @@ -252,7 +252,7 @@ MarlinFuncPtr get_marlin_kernel( const vllm::ScalarType c_type, const vllm::ScalarType s_type, int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, - int threads, bool is_zp_float) { + int threads, bool is_zp_float, int stages) { int num_bits = b_type.size_bits(); auto kernel = MarlinDefault; @@ -266,8 +266,8 @@ exec_config_t determine_exec_config( const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m, int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks, bool m_block_size_8, int num_bits, int group_size, bool has_act_order, - bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms, - bool is_a_8bit) { + bool is_k_full, bool has_zp, bool is_zp_float, bool is_a_8bit, int stages, + int max_shared_mem, int sms) { exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs @@ -284,15 +284,15 @@ exec_config_t determine_exec_config( if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, max_shared_mem - 512, - is_a_8bit)) { + is_k_full, has_zp, is_zp_float, is_a_8bit, stages, + max_shared_mem - 512)) { continue; } int cache_size = get_kernel_cache_size( th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, - is_a_8bit); + is_a_8bit, stages); int group_blocks = 0; if (!has_act_order) { @@ -303,7 +303,7 @@ exec_config_t determine_exec_config( get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks, th_config.thread_n / 16, th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, group_blocks, - th_config.num_threads, is_zp_float); + th_config.num_threads, is_zp_float, stages); if (kernel == MarlinDefault) continue; @@ -433,8 +433,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, dev); cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, dev); - TORCH_CHECK(major_capability * 10 + minor_capability >= 80, - "marlin kernel only support Ampere or newer GPUs."); + TORCH_CHECK(major_capability * 10 + minor_capability >= 75, + "marlin kernel only support Turing or newer GPUs."); + int stages = 4; + if (major_capability == 7 && minor_capability == 5) { + stages = 2; + TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8, + "Turing only support FP16 or INT8 activation."); + } if (a_type == vllm::kFE4M3fn) { TORCH_CHECK(major_capability * 10 + minor_capability >= 89, "FP8 only support Ada Lovelace or newer GPUs."); @@ -461,8 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, exec_cfg = determine_exec_config( a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts, top_k, thread_m_blocks, m_block_size_8, num_bits, group_size, - has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms, - is_a_8bit); + has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages, + max_shared_mem, sms); thread_tfg = exec_cfg.tb_cfg; } @@ -479,7 +485,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, - max_shared_mem, is_a_8bit), + is_a_8bit, stages, max_shared_mem), "Invalid thread config: thread_m_blocks = ", thread_m_blocks, ", thread_k = ", thread_tfg.thread_k, ", thread_n = ", thread_tfg.thread_n, @@ -493,12 +499,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, int sh_cache_size = get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, is_a_8bit); + is_k_full, has_zp, is_zp_float, is_a_8bit, stages); auto kernel = get_marlin_kernel( a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, - num_threads, is_zp_float); + num_threads, is_zp_float, stages); if (kernel == MarlinDefault) { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, diff --git a/csrc/quantization/gptq_marlin/.gitignore b/csrc/quantization/gptq_marlin/.gitignore index ba805f9250ece..7dc482a894660 100644 --- a/csrc/quantization/gptq_marlin/.gitignore +++ b/csrc/quantization/gptq_marlin/.gitignore @@ -1,2 +1,3 @@ sm*_kernel_*.cu kernel_selector.h +kernel_*.cu diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index 26b8d40368aa9..edd97dbfcd8e5 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -67,7 +67,7 @@ where `scale_factor * multiplier` can be computed at weight loading. namespace MARLIN_NAMESPACE_NAME { -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750 // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 27ef7271ba41c..24866fc5cd546 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -10,6 +10,8 @@ import jinja2 ARCHS = [] SUPPORT_FP8 = False +SUPPORT_SM75 = False +SUPPORT_SM80 = False for arch in sys.argv[1].split(","): arch = arch[: arch.index(".") + 2].replace(".", "") arch = int(arch) @@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","): # with FP16 MMA, so it cannot achieve any acceleration. if arch in [89, 120]: SUPPORT_FP8 = True + if arch >= 80: + SUPPORT_SM80 = True + if arch == 75: + SUPPORT_SM75 = True FILE_HEAD_COMMENT = """ // auto generated by generate_kernels.py @@ -166,6 +172,7 @@ def remove_old_kernels(): def generate_new_kernels(): result_dict = {} + sm_75_result_dict = {} for quant_config in QUANT_CONFIGS: c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) @@ -184,6 +191,8 @@ def generate_new_kernels(): s_type = quant_config.get("s_type", c_type) if (a_type, b_type, c_type) not in result_dict: result_dict[(a_type, b_type, c_type)] = [] + if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16": + sm_75_result_dict[(a_type, b_type, c_type)] = [] for group_blocks, m_blocks, thread_configs in itertools.product( all_group_blocks, all_m_blocks, all_thread_configs @@ -207,78 +216,89 @@ def generate_new_kernels(): "thread_k_blocks": thread_k // 16, "thread_n_blocks": thread_n // 16, "m_block_size_8": "true" if m_blocks == 0.5 else "false", - "stages": "pipe_stages", + "stages": 4, "group_blocks": group_blocks, "is_zp_float": "true" if is_zp_float else "false", } - result_dict[(a_type, b_type, c_type)].append(config) + if SUPPORT_SM80: + result_dict[(a_type, b_type, c_type)].append(config) + if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75: + config_sm75 = config.copy() + config_sm75["stages"] = 2 + sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75) kernel_selector_str = FILE_HEAD_COMMENT - for (a_type, b_type, c_type), config_list in result_dict.items(): - all_template_str_list = [] - for config in config_list: - s_type = config["s_type"] - template_str = jinja2.Template(TEMPLATE).render( - a_type_id=f"vllm::{a_type}.id()", - b_type_id=f"vllm::{b_type}.id()", - c_type_id=f"vllm::{c_type}.id()", - s_type_id=f"vllm::{s_type}.id()", - **config, - ) - all_template_str_list.append(template_str) - - conditions = [ - f"a_type == vllm::{a_type}", - f"b_type == vllm::{b_type}", - f"c_type == vllm::{c_type}", - f"s_type == vllm::{s_type}", - f"threads == {config['threads']}", - f"thread_m_blocks == {config['thread_m_blocks']}", - f"thread_n_blocks == {config['thread_n_blocks']}", - f"thread_k_blocks == {config['thread_k_blocks']}", - f"m_block_size_8 == {config['m_block_size_8']}", - f"group_blocks == {config['group_blocks']}", - f"is_zp_float == {config['is_zp_float']}", - ] - conditions = " && ".join(conditions) - - if kernel_selector_str == FILE_HEAD_COMMENT: - kernel_selector_str += f"if ({conditions})\n kernel = " - else: - kernel_selector_str += f"else if ({conditions})\n kernel = " - - kernel_template2 = ( - "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " - "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " - "{{thread_n_blocks}}, {{thread_k_blocks}}, " - "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " - "{{is_zp_float}}>;" - ) - - kernel_selector_str += ( - jinja2.Template(kernel_template2).render( + for result_dict_tmp in [result_dict, sm_75_result_dict]: + for (a_type, b_type, c_type), config_list in result_dict_tmp.items(): + all_template_str_list = [] + if not config_list: + continue + for config in config_list: + s_type = config["s_type"] + template_str = jinja2.Template(TEMPLATE).render( a_type_id=f"vllm::{a_type}.id()", b_type_id=f"vllm::{b_type}.id()", c_type_id=f"vllm::{c_type}.id()", s_type_id=f"vllm::{s_type}.id()", **config, ) - + "\n" - ) + all_template_str_list.append(template_str) - file_content = FILE_HEAD + "\n\n" - file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" - if a_type == "kFE4M3fn": - filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" - else: - filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + conditions = [ + f"a_type == vllm::{a_type}", + f"b_type == vllm::{b_type}", + f"c_type == vllm::{c_type}", + f"s_type == vllm::{s_type}", + f"threads == {config['threads']}", + f"thread_m_blocks == {config['thread_m_blocks']}", + f"thread_n_blocks == {config['thread_n_blocks']}", + f"thread_k_blocks == {config['thread_k_blocks']}", + f"m_block_size_8 == {config['m_block_size_8']}", + f"stages == {config['stages']}", + f"group_blocks == {config['group_blocks']}", + f"is_zp_float == {config['is_zp_float']}", + ] + conditions = " && ".join(conditions) - filename = filename.lower() + if kernel_selector_str == FILE_HEAD_COMMENT: + kernel_selector_str += f"if ({conditions})\n kernel = " + else: + kernel_selector_str += f"else if ({conditions})\n kernel = " - with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: - f.write(file_content) + kernel_template2 = ( + "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " + "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " + "{{thread_n_blocks}}, {{thread_k_blocks}}, " + "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " + "{{is_zp_float}}>;" + ) + + kernel_selector_str += ( + jinja2.Template(kernel_template2).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + + "\n" + ) + + file_content = FILE_HEAD + "\n\n" + file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" + if a_type == "kFE4M3fn": + filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + elif result_dict_tmp is sm_75_result_dict: + filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + else: + filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + + filename = filename.lower() + + with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: + f.write(file_content) if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: kernel_selector_str += ( diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 28ff06559a98a..77f319d53bc52 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -37,7 +37,7 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, @@ -148,7 +148,7 @@ typedef struct { int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, - bool has_act_order, bool is_k_full) { + bool has_act_order, bool is_k_full, int stages) { bool cache_scales_chunk = has_act_order && !is_k_full; int tb_n = th_config.thread_n; @@ -166,28 +166,29 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, if (cache_scales_chunk) { int load_groups = - tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; } else { int tb_scales = tb_groups * tb_n * 2; - return tb_scales * pipe_stages; + return tb_scales * stages; } } int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, - int has_zp, int is_zp_float) { + int has_zp, bool is_zp_float, bool is_a_8bit, + int stages) { int pack_factor = 32 / num_bits; // Get B size int tb_k = th_config.thread_k; int tb_n = th_config.thread_n; int tb_m = thread_m_blocks * 16; - int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; - int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2); + int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4; int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_bias_size = tb_n * 2; int tmp_size = @@ -196,8 +197,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); - int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + group_size, has_act_order, is_k_full, stages); + int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0; int sh_zp_size = 0; if (has_zp) { if (is_zp_float) @@ -217,7 +218,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, - int has_zp, int is_zp_float, int max_shared_mem) { + int has_zp, bool is_zp_float, bool is_a_8bit, int stages, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -242,7 +244,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, // Check that pipeline fits into cache int cache_size = get_kernel_cache_size( th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, has_zp, is_zp_float); + has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages); return cache_size <= max_shared_mem; } @@ -251,7 +253,7 @@ MarlinFuncPtr get_marlin_kernel( const vllm::ScalarType c_type, const vllm::ScalarType s_type, int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, - int threads, bool is_zp_float) { + int threads, bool is_zp_float, int stages) { int num_bits = b_type.size_bits(); auto kernel = MarlinDefault; @@ -265,7 +267,8 @@ exec_config_t determine_exec_config( const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m, int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8, int num_bits, int group_size, bool has_act_order, bool is_k_full, - bool has_zp, bool is_zp_float, int max_shared_mem, int sms) { + bool has_zp, bool is_zp_float, int is_a_8bit, int stages, + int max_shared_mem, int sms) { exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs @@ -280,13 +283,15 @@ exec_config_t determine_exec_config( if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, - is_zp_float, max_shared_mem - 512)) { + is_zp_float, is_a_8bit, stages, + max_shared_mem - 512)) { continue; } - int cache_size = get_kernel_cache_size( - th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full, has_zp, is_zp_float); + int cache_size = get_kernel_cache_size(th_config, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, + is_zp_float, is_a_8bit, stages); int group_blocks = 0; if (!has_act_order) { @@ -297,14 +302,10 @@ exec_config_t determine_exec_config( get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks, th_config.thread_n / 16, th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, group_blocks, - th_config.num_threads, is_zp_float); + th_config.num_threads, is_zp_float, stages); if (kernel == MarlinDefault) continue; - // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); - // int n_tiles = prob_n / th_config.thread_n; - // int k_tiles = prob_k / th_config.thread_k; - return {1, th_config}; } @@ -321,6 +322,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, int group_size, int dev, cudaStream_t stream, int thread_k_init, int thread_n_init, int sms, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { + bool is_a_8bit = a_type.size_bits() == 8; TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -389,8 +391,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, dev); cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, dev); - TORCH_CHECK(major_capability * 10 + minor_capability >= 80, - "marlin kernel only support Ampere or newer GPUs."); + TORCH_CHECK(major_capability * 10 + minor_capability >= 75, + "marlin kernel only support Turing or newer GPUs."); + int stages = 4; + if (major_capability == 7 && minor_capability == 5) { + stages = 2; + TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8, + "Turing only support FP16 or INT8 activation."); + } if (a_type == vllm::kFE4M3fn) { TORCH_CHECK( major_capability * 10 + minor_capability == 89 || @@ -431,7 +439,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, exec_cfg = determine_exec_config( a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, max_shared_mem, sms); + is_k_full, has_zp, is_zp_float, is_a_8bit, stages, max_shared_mem, + sms); thread_tfg = exec_cfg.tb_cfg; if (thread_tfg.thread_n != -1) { if (prob_n / thread_tfg.thread_n * @@ -440,7 +449,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, - max_shared_mem_new)) { + is_a_8bit, stages, max_shared_mem_new)) { thread_tfg = {128, 64, 128}; exec_cfg = {1, thread_tfg}; } @@ -466,7 +475,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, TORCH_CHECK( is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, - has_zp, is_zp_float, max_shared_mem_new), + has_zp, is_zp_float, is_a_8bit, stages, + max_shared_mem_new), "Invalid thread config: thread_m_blocks = ", thread_m_blocks, ", thread_k = ", thread_tfg.thread_k, ", thread_n = ", thread_tfg.thread_n, @@ -475,12 +485,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ", prob_m_split = ", prob_m_split, ", group_size = ", group_size, ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, - ", max_shared_mem_new = ", max_shared_mem_new); + ", stages = ", stages, ", max_shared_mem_new = ", max_shared_mem_new); auto kernel = get_marlin_kernel( a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, - num_threads, is_zp_float); + num_threads, is_zp_float, stages); if (kernel == MarlinDefault) { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, diff --git a/csrc/quantization/gptq_marlin/marlin.cuh b/csrc/quantization/gptq_marlin/marlin.cuh index 2505e221322dd..33fe52f605b42 100644 --- a/csrc/quantization/gptq_marlin/marlin.cuh +++ b/csrc/quantization/gptq_marlin/marlin.cuh @@ -1,17 +1,19 @@ #pragma once -#include +#ifndef _marlin_cuh + #define _marlin_cuh + #include -#include -#include -#include -#include -#include -#include + #include + #include + #include + #include + #include + #include -#ifndef MARLIN_NAMESPACE_NAME - #define MARLIN_NAMESPACE_NAME marlin -#endif + #ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin + #endif namespace MARLIN_NAMESPACE_NAME { @@ -51,9 +53,51 @@ using I4 = Vec; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -// No support for async -#else + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = + reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = + reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = + reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = + reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + reinterpret_cast(smem_ptr)[0] = + reinterpret_cast(glob_ptr)[0]; +} + +__device__ inline void cp_async_fence() {} + +template +__device__ inline void cp_async_wait() {} + + #else __device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { @@ -126,6 +170,8 @@ __device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } -#endif + #endif } // namespace MARLIN_NAMESPACE_NAME + +#endif \ No newline at end of file diff --git a/csrc/quantization/gptq_marlin/marlin_mma.h b/csrc/quantization/gptq_marlin/marlin_mma.h new file mode 100644 index 0000000000000..6ec2aaafc4392 --- /dev/null +++ b/csrc/quantization/gptq_marlin/marlin_mma.h @@ -0,0 +1,269 @@ + +#include "marlin_dtypes.cuh" + +namespace MARLIN_NAMESPACE_NAME { + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void mma( + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragC& frag_c, int idx = 0) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (!std::is_same::value || k_size != 16) { + static_assert(!use_fp16_accum); + } + + if constexpr (k_size == 16) { + if constexpr (std::is_same::value && !use_fp16_accum) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[2]), "r"(a[3]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); +#else + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +#endif + } else if constexpr (std::is_same::value && + use_fp16_accum) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + uint32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[2]), "r"(a[3]), "r"(b[1]), "r"(c[0]), "r"(c[1])); +#else + uint32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1])); +#endif + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]), + "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]), + "r"(c[1]), "r"(c[2]), "r"(c[3])); + } + } else if (k_size == 32) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[0]), "r"(b[0]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[2]), "=r"(c[3]) + : "r"(a[1]), "r"(b[0]), "r"(c[2]), "r"(c[3])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[2]), "r"(b[1]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[2]), "=r"(c[3]) + : "r"(a[3]), "r"(b[1]), "r"(c[2]), "r"(c[3])); +#else + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +#endif + } + } +} + +template +__device__ inline void mma_trans( + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + const typename MarlinScalarType::FragB& frag_b2, + typename MarlinScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + float* c = reinterpret_cast(&frag_c); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (!std::is_same::value || k_size != 16) { + static_assert(!use_fp16_accum); + } + + if constexpr (k_size == 16) { + if constexpr (std::is_same::value && !use_fp16_accum) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[1]), "r"(b2[1]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); +#else + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +#endif + } else if constexpr (std::is_same::value && + use_fp16_accum) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + uint32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[1]), "r"(b2[1]), "r"(a[1]), "r"(c[0]), "r"(c[1])); +#else + uint32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "r"(c[0]), "r"(c[1])); +#endif + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), + "r"(c[3])); + } + } else { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[0]), "r"(a[0]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[2]), "=r"(c[3]) + : "r"(b2[1]), "r"(a[0]), "r"(c[2]), "r"(c[3])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[0]), "r"(a[1]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[2]), "=r"(c[3]) + : "r"(b2[1]), "r"(a[1]), "r"(c[2]), "r"(c[3])); +#else + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +#endif + } + } +} + +} // namespace MARLIN_NAMESPACE_NAME \ No newline at end of file diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h index 22bb71e482ce8..c7b53696c1223 100644 --- a/csrc/quantization/gptq_marlin/marlin_template.h +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -26,6 +26,7 @@ #include "marlin.cuh" #include "marlin_dtypes.cuh" #include "dequant.h" +#include "marlin_mma.h" #include "core/scalar_type.hpp" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ @@ -35,7 +36,7 @@ namespace MARLIN_NAMESPACE_NAME { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 template -__device__ inline void mma( - const typename MarlinScalarType::FragA& a_frag, - const typename MarlinScalarType::FragB& frag_b, - typename MarlinScalarType::FragC& frag_c, int idx = 0) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - using scalar_t = typename MarlinScalarType::scalar_t; - if constexpr (k_size == 16) { - if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]), - "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - int32_t* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]), - "r"(c[1]), "r"(c[2]), "r"(c[3])); - } - } else if (k_size == 32) { - if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - int32_t* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); - } - } -} - -template -__device__ inline void mma_trans( - const typename MarlinScalarType::FragA& a_frag, - const typename MarlinScalarType::FragB& frag_b, - const typename MarlinScalarType::FragB& frag_b2, - typename MarlinScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - const uint32_t* b2 = reinterpret_cast(&frag_b2); - float* c = reinterpret_cast(&frag_c); - using scalar_t = typename MarlinScalarType::scalar_t; - if constexpr (k_size == 16) { - if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), - "f"(c[3])); - } else if constexpr (std::is_same::value) { - int32_t* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), - "r"(c[3])); - } - } else { - if constexpr (std::is_same::value) { - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - int32_t* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); - } - } -} - // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. template @@ -415,6 +285,17 @@ __global__ void Marlin( if constexpr (a_type_id == vllm::kFE4M3fn.id()) return; #endif + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + // Turing TensorCore only supports fp16 and int8 + if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id()) + return; + #endif + + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id(); + #else + constexpr bool use_fp16_accum = false; + #endif using Adtype = MarlinScalarType; using Cdtype = MarlinScalarType; const int4* A = A0; @@ -873,10 +754,6 @@ __global__ void Marlin( constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); int4* sh_s = sh_zp + (stages * zp_sh_stage); - // shared memory reused by reduction should be smaller than - // shared memory used by weight. - static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= - stages * b_sh_stage); int4* sh_a = sh_s + sh_s_size; // Register storage for double buffer of shared memory reads. @@ -1395,11 +1272,13 @@ __global__ void Marlin( #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { if constexpr (m_block_size_8) { - mma_trans(frag_a[k2][i], frag_b0, frag_b1, - frag_c[i][j][0]); + mma_trans(frag_a[k2][i], frag_b0, frag_b1, + frag_c[i][j][0]); } else { - mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + mma(frag_a[k2][i], frag_b0, + frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, + frag_c[i][j][1]); } } } @@ -1433,10 +1312,12 @@ __global__ void Marlin( #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k2][i], frag_b[0], - (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); - mma(frag_a[k2][i], frag_b[1], - (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]); + mma( + frag_a[k2][i], frag_b[0], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); + mma( + frag_a[k2][i], frag_b[1], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]); } if constexpr (group_blocks != -1) { @@ -1956,6 +1837,21 @@ __global__ void Marlin( // While this pattern may not be the most readable, other ways of writing // the loop seemed to noticeably worse performance after compilation. if (slice_iters == 0) { + // convert fp16 accum to fp32 for reduction + if constexpr (use_fp16_accum) { + #pragma unroll + for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) { + float* frag_c_part_float = reinterpret_cast(frag_c) + i * 4; + scalar_t* frag_c_part_half = + reinterpret_cast(frag_c_part_float); + + #pragma unroll + for (int i = 3; i >= 0; i--) { + frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]); + } + } + } + if constexpr (is_a_8bit) { float frag_a_s[2 * thread_m_blocks]; diff --git a/csrc/sampler.cu b/csrc/sampler.cu index fc2154beff9e0..d458f8e4c1d02 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill( int rowEnd = rowEnds[rowIdx]; // Local pointers to this block - outIndices += rowIdx * topK; - logits += rowIdx * stride0; + outIndices += static_cast(rowIdx) * topK; + logits += static_cast(rowIdx) * stride0; topKPerRowJob( nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK); @@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode( // Local pointers to this block if constexpr (!multipleBlocksPerRow && !mergeBlocks) { - outIndices += rowIdx * topK; + outIndices += static_cast(rowIdx) * topK; } else if constexpr (multipleBlocksPerRow) { const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192 rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192 rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize; - outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK; - outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK; + outIndices += + static_cast(rowIdx) * gridDim.y * topK + blockIdx.y * topK; + outLogits += + static_cast(rowIdx) * gridDim.y * topK + blockIdx.y * topK; } else if constexpr (mergeBlocks) { rowEnd = numBlocksToMerge * topK; - indices += rowIdx * numBlocksToMerge * topK; - outIndices += rowIdx * topK; + indices += static_cast(rowIdx) * numBlocksToMerge * topK; + outIndices += static_cast(rowIdx) * topK; } - logits += rowIdx * stride0; + logits += static_cast(rowIdx) * stride0; topKPerRowJob( diff --git a/docker/Dockerfile b/docker/Dockerfile index ae2624ace67b9..e61021b6eeb85 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -621,7 +621,7 @@ ENV UV_HTTP_TIMEOUT=500 RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \ if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \ - uv pip install --system -r /tmp/kv_connectors.txt; \ + uv pip install --system -r /tmp/kv_connectors.txt || true; \ fi ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/requirements/common.txt b/requirements/common.txt index 31c8fb404f63a..426d281c26704 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -50,5 +50,5 @@ ijson # Required for mistral streaming tool parser setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss anthropic == 0.71.0 -model-hosting-container-standards >= 0.1.9, < 1.0.0 -mcp \ No newline at end of file +model-hosting-container-standards >= 0.1.10, < 1.0.0 +mcp diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index bd326f1157d8f..960b5b4bd7ad4 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -523,6 +523,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"] list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)), ) @pytest.mark.parametrize("inductor_graph_partition", [True, False]) +# TODO: remove skip after we fix the fusion thoroughly +@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell") def test_rms_group_quant( model_name: str, model_kwargs: dict[str, Any], @@ -562,7 +564,9 @@ def test_rms_group_quant( splitting_ops=splitting_ops, # Common mode=CompilationMode.VLLM_COMPILE, - pass_config=PassConfig(eliminate_noops=True, enable_fusion=True), + pass_config=PassConfig( + fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True + ), # Inductor caches custom passes by default as well via uuid inductor_compile_config={"force_disable_caches": True}, ) diff --git a/tests/entrypoints/openai/test_transcription_validation_whisper.py b/tests/entrypoints/openai/test_transcription_validation_whisper.py index 3c507ee0a3fa7..8bf729c517f7a 100644 --- a/tests/entrypoints/openai/test_transcription_validation_whisper.py +++ b/tests/entrypoints/openai/test_transcription_validation_whisper.py @@ -244,3 +244,35 @@ async def test_audio_with_timestamp(mary_had_lamb, whisper_client): ) assert transcription.segments is not None assert len(transcription.segments) > 0 + + +@pytest.mark.asyncio +async def test_audio_with_max_tokens(whisper_client, mary_had_lamb): + transcription = await whisper_client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0, + extra_body={"max_completion_tokens": 1}, + ) + out = json.loads(transcription) + out_text = out["text"] + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(MODEL_NAME) + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) == 1 + # max_completion_tokens > max_model_len + transcription = await whisper_client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0, + extra_body={"max_completion_tokens": int(1e6)}, + ) + out = json.loads(transcription) + out_text = out["text"] + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) < 450 # ~Whisper max output len diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index d7d407484f16d..2c577237691ab 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -227,3 +227,36 @@ async def test_long_audio_request(foscolo, client_and_model): ) out = json.loads(translation)["text"].strip().lower() assert out.count("greek sea") == 2 + + +@pytest.mark.asyncio +async def test_audio_with_max_tokens(mary_had_lamb, client_and_model): + client, model_name = client_and_model + transcription = await client.audio.translations.create( + model=model_name, + file=mary_had_lamb, + response_format="text", + temperature=0.0, + extra_body={"max_completion_tokens": 1}, + ) + out = json.loads(transcription) + out_text = out["text"] + print(out_text) + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_name) + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) == 1 + # max_completion_tokens > max_model_len + transcription = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + response_format="text", + temperature=0.0, + extra_body={"max_completion_tokens": int(1e6)}, + ) + out = json.loads(transcription) + out_text = out["text"] + print(out_text) + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) < 450 # ~Whisper max output len diff --git a/tests/evals/gsm8k/README.md b/tests/evals/gsm8k/README.md index 29c5199e1e87a..dcbfd85bfeee8 100644 --- a/tests/evals/gsm8k/README.md +++ b/tests/evals/gsm8k/README.md @@ -7,9 +7,8 @@ This directory contains a replacement for the lm-eval-harness GSM8K evaluation, ### Run tests with pytest (like buildkite) ```bash -pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \ - --config-list-file=configs/models-small.txt \ - --tp-size=1 +pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \ + --config-list-file=configs/models-small.txt ``` ### Run standalone evaluation script @@ -31,5 +30,11 @@ model_name: "Qwen/Qwen2.5-1.5B-Instruct" accuracy_threshold: 0.54 # Minimum expected accuracy num_questions: 1319 # Number of questions (default: full test set) num_fewshot: 5 # Few-shot examples from train set -max_model_len: 4096 # Model context length +server_args: "--max-model-len 4096 --tensor-parallel-size 2" # Server arguments +env: # Environment variables (optional) + VLLM_USE_FLASHINFER_MOE_FP4: "1" ``` + +The `server_args` field accepts any arguments that can be passed to `vllm serve`. + +The `env` field accepts a dictionary of environment variables to set for the server process. diff --git a/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml b/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml index 7ec6a1e0be27f..72fa7e8a38c73 100644 --- a/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml +++ b/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml @@ -2,5 +2,4 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8" accuracy_threshold: 0.72 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 - +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml index caa0448f23d48..b7b59e9dcd5ce 100644 --- a/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml +++ b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml @@ -2,4 +2,4 @@ model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test" accuracy_threshold: 0.74 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 \ No newline at end of file +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml index 615aa69a2d2b6..8b3c9ff645e87 100644 --- a/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml +++ b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml @@ -2,4 +2,4 @@ model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8" accuracy_threshold: 0.31 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 \ No newline at end of file +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml index 9297bf6ddf2d3..4a1b1948acac8 100644 --- a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml +++ b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml @@ -2,4 +2,4 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16" accuracy_threshold: 0.45 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml index 5319ada30f645..5ce3af8be346a 100644 --- a/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml +++ b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml @@ -2,4 +2,4 @@ model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic" accuracy_threshold: 0.60 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 \ No newline at end of file +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml index c39fb979d98ac..5452ebe753f04 100644 --- a/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml +++ b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml @@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-0.6B-FP8" accuracy_threshold: 0.375 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 \ No newline at end of file +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml b/tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml index 6b7bdd1e65bb3..f162aa8bfe5b0 100644 --- a/tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml +++ b/tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml @@ -2,5 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-FP4" accuracy_threshold: 0.89 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 - +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml b/tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml new file mode 100644 index 0000000000000..673b473f817eb --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml @@ -0,0 +1,12 @@ +model_name: "nm-testing/Qwen3-Next-80B-A3B-Instruct-NVFP4" +accuracy_threshold: 0.75 +num_questions: 1319 +num_fewshot: 5 +server_args: >- + --enforce-eager + --max-model-len 4096 + --tensor-parallel-size 2 + --enable-expert-parallel + --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' +env: + VLLM_USE_FLASHINFER_MOE_FP4: "1" diff --git a/tests/evals/gsm8k/configs/models-blackwell.txt b/tests/evals/gsm8k/configs/models-blackwell.txt index 3c9b1084de7bc..39978aa6ffbe9 100644 --- a/tests/evals/gsm8k/configs/models-blackwell.txt +++ b/tests/evals/gsm8k/configs/models-blackwell.txt @@ -3,3 +3,4 @@ Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml Qwen1.5-MoE-W4A16-CT.yaml DeepSeek-V2-Lite-Instruct-FP8.yaml Qwen3-30B-A3B-NVFP4.yaml +Qwen3-Next-80B-A3B-NVFP4-EP2.yaml diff --git a/tests/evals/gsm8k/conftest.py b/tests/evals/gsm8k/conftest.py index 1932a13cdfc63..6f25fe6414af4 100644 --- a/tests/evals/gsm8k/conftest.py +++ b/tests/evals/gsm8k/conftest.py @@ -11,14 +11,12 @@ def pytest_addoption(parser): default="configs/models-small.txt", help="File containing list of config files to test", ) - parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size") def pytest_generate_tests(metafunc): """Generate test parameters from config files.""" if "config_filename" in metafunc.fixturenames: config_list_file = metafunc.config.getoption("--config-list-file") - tp_size = metafunc.config.getoption("--tp-size") # Handle both relative and absolute paths config_list_path = Path(config_list_file) @@ -55,9 +53,9 @@ def pytest_generate_tests(metafunc): # Generate test parameters if config_files: metafunc.parametrize( - ["config_filename", "tp_size"], - [(config_file, int(tp_size)) for config_file in config_files], - ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files], + "config_filename", + config_files, + ids=[config_file.stem for config_file in config_files], ) else: print("No config files found, test will be skipped") diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py index b5d67df7bf3db..ea6715f5cb532 100644 --- a/tests/evals/gsm8k/test_gsm8k_correctness.py +++ b/tests/evals/gsm8k/test_gsm8k_correctness.py @@ -5,30 +5,31 @@ GSM8K evaluation using vLLM server and isolated GSM8K script. Replacement for lm-eval-harness with better performance and control. Usage: -pytest -s -v test_gsm8k_correctness.py \ - --config-list-file=configs/models-small.txt \ - --tp-size=1 +pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \ + --config-list-file=configs/models-small.txt """ +import shlex + import yaml from tests.utils import RemoteOpenAIServer from .gsm8k_eval import evaluate_gsm8k -RTOL = 0.08 # Relative tolerance for accuracy comparison +TOL = 0.08 # Absolute tolerance for accuracy comparison -def launch_gsm8k_eval(eval_config, server_url, tp_size): - """Launch GSM8K evaluation using our isolated script.""" +def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict: + """Run GSM8K evaluation using our isolated script.""" # Extract host and port from server URL if "://" in server_url: server_url = server_url.split("://")[1] host_port = server_url.split("/")[0] # Remove path if present if ":" in host_port: - host, port = host_port.split(":") - port = int(port) + host, p = host_port.split(":") + port = int(p) else: host = host_port port = 8000 @@ -48,46 +49,57 @@ def launch_gsm8k_eval(eval_config, server_url, tp_size): return results -def test_gsm8k_correctness_param(config_filename, tp_size): +def test_gsm8k_correctness(config_filename): """Test GSM8K correctness for a given model configuration.""" eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8")) - # Server arguments - server_args = [ - "--max-model-len", - str(eval_config.get("max_model_len", 4096)), - "--enforce-eager", - "--trust-remote-code", - "--tensor-parallel-size", - str(tp_size), - ] + # Parse server arguments from config (use shlex to handle quoted strings) + server_args_str = eval_config.get("server_args", "") + server_args = shlex.split(server_args_str) if server_args_str else [] + + # Add standard server arguments + server_args.extend( + [ + "--trust-remote-code", + ] + ) env_dict = eval_config.get("env", None) + print(f"Starting GSM8K evaluation for model: {eval_config['model_name']}") + print(f"Expected metric threshold: {eval_config['accuracy_threshold']}") + print(f"Number of questions: {eval_config['num_questions']}") + print(f"Number of few-shot examples: {eval_config['num_fewshot']}") + print(f"Server args: {' '.join(server_args)}") + # Launch server and run evaluation with RemoteOpenAIServer( - eval_config["model_name"], server_args, env_dict=env_dict, max_wait_seconds=480 + eval_config["model_name"], + server_args, + env_dict=env_dict, + max_wait_seconds=600, ) as remote_server: server_url = remote_server.url_for("v1") + print(f"Server started at: {server_url}") - results = launch_gsm8k_eval(eval_config, server_url, tp_size) + results = run_gsm8k_eval(eval_config, server_url) - # Check accuracy against threshold - measured_accuracy = results["accuracy"] - expected_accuracy = eval_config["accuracy_threshold"] + measured_metric = results["accuracy"] + expected_metric = eval_config["accuracy_threshold"] print(f"GSM8K Results for {eval_config['model_name']}:") - print(f" Accuracy: {measured_accuracy:.3f}") - print(f" Expected: {expected_accuracy:.3f}") + print(f" Measured metric: {measured_metric:.4f}") + print(f" Expected metric: {expected_metric:.4f}") + print(f" Tolerance: {TOL:.4f}") print(f" Questions: {results['num_questions']}") print(f" Invalid rate: {results['invalid_rate']:.3f}") print(f" Latency: {results['latency']:.1f}s") print(f" QPS: {results['questions_per_second']:.1f}") - # Verify accuracy is within tolerance - assert measured_accuracy >= expected_accuracy - RTOL, ( - f"Accuracy too low: {measured_accuracy:.3f} < " - f"{expected_accuracy:.3f} - {RTOL:.3f}" + # Verify metric is within tolerance + assert measured_metric >= expected_metric - TOL, ( + f"GSM8K metric too low: {measured_metric:.4f} < " + f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}" ) print(f"✅ GSM8K test passed for {eval_config['model_name']}") diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py index e5ff2d1391b62..325159965c803 100644 --- a/tests/models/multimodal/processing/test_mllama4.py +++ b/tests/models/multimodal/processing/test_mllama4.py @@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int): total_num_patches.item() + num_tiles.item() + 3 ) # image start, image, image end - profiled_tokens = profiler.get_mm_max_contiguous_tokens( + profiled_tokens = profiler.get_mm_max_tokens( max_model_len, mm_counts=mm_counts, ) - assert total_tokens == profiled_tokens["image"] + assert total_num_patches == profiled_tokens["image"] assert total_tokens == sum( placeholder.length for placeholder in decoder_dummy_data.multi_modal_placeholders["image"] diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 636cd0ffd445e..02bb1f769baad 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory import numpy as np import pytest +import torch from PIL import Image, ImageChops from vllm.multimodal.image import convert_image_mode @@ -410,6 +411,97 @@ def test_argsort_mm_positions(case): assert modality_idxs == expected_modality_idxs +@pytest.mark.parametrize( + "is_embed,expected", + [ + (None, 5), + (torch.tensor([True, True, True, True, True]), 5), + (torch.tensor([False, False, False, False, False]), 0), + (torch.tensor([True, False, True, False, True]), 3), + (torch.tensor([True]), 1), + ], +) +def test_placeholder_range_get_num_embeds(is_embed, expected): + length = len(is_embed) if is_embed is not None else 5 + pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed) + assert pr.get_num_embeds == expected + + +@pytest.mark.parametrize( + "is_embed,expected", + [ + (None, None), + ( + torch.tensor([False, True, False, True, True]), + torch.tensor([0, 1, 1, 2, 3]), + ), + (torch.tensor([True, True, True]), torch.tensor([1, 2, 3])), + ], +) +def test_placeholder_range_embeds_cumsum(is_embed, expected): + length = len(is_embed) if is_embed is not None else 5 + pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed) + + if expected is None: + assert pr.embeds_cumsum is None + return + + assert torch.equal(pr.embeds_cumsum, expected) + # cached_property should return the same object on repeated access + assert pr.embeds_cumsum is pr.embeds_cumsum + + +@pytest.mark.parametrize( + "is_embed,start_idx,end_idx,expected", + [ + (None, 2, 4, (2, 4)), + ( + torch.tensor([False, True, False, True, True]), + 3, + 5, + (1, 3), + ), + ( + torch.tensor([False, True, False, True, True]), + 0, + 2, + (0, 1), + ), + ( + torch.tensor([True, False, True, False]), + 2, + 2, + (1, 1), + ), + ], +) +def test_placeholder_range_get_embeds_indices_in_range( + is_embed, start_idx, end_idx, expected +): + length = len(is_embed) if is_embed is not None else 5 + pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed) + assert pr.get_embeds_indices_in_range(start_idx, end_idx) == expected + + +@pytest.mark.parametrize( + "offset,is_embed,expected", + [ + (0, None, [(0, 4)]), + ( + 2, + torch.tensor([False, True, False, True, True]), + [(3, 3), (5, 6)], + ), + (0, torch.tensor([True, True, True, True]), [(0, 3)]), + (0, torch.tensor([False, False, False, False]), []), + ], +) +def test_placeholder_range_extract_embeds_range(offset, is_embed, expected): + length = len(is_embed) if is_embed is not None else 5 + pr = PlaceholderRange(offset=offset, length=length, is_embed=is_embed) + assert pr.extract_embeds_range() == expected + + @pytest.mark.asyncio @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) @pytest.mark.parametrize("num_frames", [-1, 32, 1800]) diff --git a/tests/v1/attention/test_chunked_local_attention.py b/tests/v1/attention/test_chunked_local_attention.py index faace3473a281..4529c2cfc29b6 100644 --- a/tests/v1/attention/test_chunked_local_attention.py +++ b/tests/v1/attention/test_chunked_local_attention.py @@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): ) # Call the function - result = make_local_attention_virtual_batches( + result, _ = make_local_attention_virtual_batches( attn_chunk_size, common_attn_metadata, block_size ) diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py index 8a52b5bd78977..511ff48c401ca 100644 --- a/tests/v1/core/test_encoder_cache_manager.py +++ b/tests/v1/core/test_encoder_cache_manager.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +import torch from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange from vllm.v1.core.encoder_cache_manager import EncoderCacheManager @@ -23,7 +24,7 @@ class MockRequest: ) self.mm_features.append(feature) - def get_num_encoder_tokens(self, input_id: int) -> int: + def get_num_encoder_embeds(self, input_id: int) -> int: return self._token_counts[input_id] @@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit(): num_tokens_to_schedule = 0 assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) - num_tokens_to_schedule += req.get_num_encoder_tokens(0) - compute_budget -= req.get_num_encoder_tokens(0) + num_tokens_to_schedule += req.get_num_encoder_embeds(0) + compute_budget -= req.get_num_encoder_embeds(0) assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) @@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit(): compute_budget = 10 num_tokens_to_schedule = 0 assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) - num_tokens_to_schedule += req.get_num_encoder_tokens(0) - compute_budget -= req.get_num_encoder_tokens(0) + num_tokens_to_schedule += req.get_num_encoder_embeds(0) + compute_budget -= req.get_num_encoder_embeds(0) assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) + + +def test_encoder_cache_with_is_embed_mask(): + class MockRequestWithMask(MockRequest): + def get_num_encoder_embeds(self, input_id: int) -> int: + return self.mm_features[input_id].mm_position.get_num_embeds + + is_embed = torch.zeros(100, dtype=torch.bool) + is_embed[torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])] = True + + request = MockRequestWithMask("r1", ["img1"], [100]) + request.mm_features[0] = MultiModalFeatureSpec( + data=None, + modality="image", + identifier="img1", + mm_position=PlaceholderRange(offset=0, length=100, is_embed=is_embed), + ) + + manager = EncoderCacheManager(cache_size=100) + manager.allocate(request, 0) + + assert manager.num_free_slots == 92 + assert "img1" in manager.cached + + old_size = 100 + new_size = request.mm_features[0].mm_position.get_num_embeds + assert new_size == 8 + savings_ratio = old_size / new_size + assert savings_ratio == 12.5 + + +def test_encoder_cache_mask_based_retrieval(): + class MockRequestWithMask(MockRequest): + def get_num_encoder_embeds(self, input_id: int) -> int: + return self.mm_features[input_id].mm_position.get_num_embeds + + is_embed = torch.tensor( + [False, False, True, True, False, True, True, True, False, False] + ) + + request = MockRequestWithMask("r1", ["img1"], [10]) + request.mm_features[0] = MultiModalFeatureSpec( + data=None, + modality="image", + identifier="img1", + mm_position=PlaceholderRange(offset=0, length=10, is_embed=is_embed), + ) + + manager = EncoderCacheManager(cache_size=50) + manager.allocate(request, 0) + + assert request.mm_features[0].mm_position.get_num_embeds == 5 + + start_idx = 2 + end_idx = 8 + num_embeds_before = is_embed[:start_idx].sum().item() + num_embeds_in_range = is_embed[start_idx:end_idx].sum().item() + + assert num_embeds_before == 0 + assert num_embeds_in_range == 5 + + start_idx = 0 + end_idx = 5 + num_embeds_before = is_embed[:start_idx].sum().item() if start_idx > 0 else 0 + num_embeds_in_range = is_embed[start_idx:end_idx].sum().item() + + assert num_embeds_before == 0 + assert num_embeds_in_range == 2 diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py index 1c45e7fe366ff..7a58e1c9bad03 100644 --- a/tests/v1/determinism/test_batch_invariance.py +++ b/tests/v1/determinism/test_batch_invariance.py @@ -188,7 +188,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( llm = LLM( model=model_name, tensor_parallel_size=tp_size, - # enable_prefix_caching=False, max_num_seqs=32, max_model_len=8192, dtype="bfloat16", # not everything is supported diff --git a/tests/v1/ec_connector/unit/test_ec_example_connector.py b/tests/v1/ec_connector/unit/test_ec_example_connector.py index 7e9eb21310031..9ed82e1cef823 100644 --- a/tests/v1/ec_connector/unit/test_ec_example_connector.py +++ b/tests/v1/ec_connector/unit/test_ec_example_connector.py @@ -38,7 +38,7 @@ class MockRequest: ) self.mm_features.append(feature) - def get_num_encoder_tokens(self, input_id: int) -> int: + def get_num_encoder_embeds(self, input_id: int) -> int: assert input_id < len(self._token_counts) return self._token_counts[input_id] diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 0ced0028ded9e..7e3794d408332 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -4,7 +4,7 @@ import functools import torch -from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata +from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig @@ -51,11 +51,19 @@ def create_chunked_local_attention_backend( common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, - ) -> AttentionMetadata: - common_attn_metadata = make_local_attention_virtual_batches( + ): + cm, make_virtual_batches_block_table = make_local_attention_virtual_batches( attention_chunk_size, common_attn_metadata, block_size ) - return super().build(common_prefix_len, common_attn_metadata, fast_build) + metadata = super().build(common_prefix_len, cm, fast_build) + metadata.make_virtual_batches_block_table = make_virtual_batches_block_table + return metadata + + def update_block_table( + self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor + ): + blk_table = metadata.make_virtual_batches_block_table(blk_table) + return super().update_block_table(metadata, blk_table, slot_mapping) attn_backend = subclass_attention_backend( name_prefix=prefix, diff --git a/vllm/attention/layers/mm_encoder_attention.py b/vllm/attention/layers/mm_encoder_attention.py index c9107ebcab856..8b3dee1340b9f 100644 --- a/vllm/attention/layers/mm_encoder_attention.py +++ b/vllm/attention/layers/mm_encoder_attention.py @@ -10,6 +10,7 @@ from vllm.attention.ops.vit_attn_wrappers import ( vit_flash_attn_wrapper, vit_torch_sdpa_wrapper, ) +from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import MultiModalConfig from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp @@ -101,6 +102,10 @@ class MMEncoderAttention(CustomOp): self.attn_backend, ) + if self.is_flash_attn_backend: + assert self.flash_attn_varlen_func is not None + self._fa_version = get_flash_attn_version() + logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") @classmethod @@ -204,6 +209,7 @@ class MMEncoderAttention(CustomOp): max_seqlen=max_seqlen, batch_size=bsz, is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA), + fa_version=self._fa_version, ) return output diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py index 46c7d83dfa5c2..5a74e1310133d 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -16,6 +16,7 @@ import einops import torch import torch.nn.functional as F +from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op @@ -27,11 +28,15 @@ def flash_attn_maxseqlen_wrapper( max_seqlen: torch.Tensor, batch_size: int, is_rocm_aiter: bool, + fa_version: int, ) -> torch.Tensor: + kwargs = {} if is_rocm_aiter: from aiter import flash_attn_varlen_func else: from vllm.attention.utils.fa_utils import flash_attn_varlen_func + + kwargs["fa_version"] = fa_version q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) output = flash_attn_varlen_func( q, @@ -43,6 +48,7 @@ def flash_attn_maxseqlen_wrapper( max_seqlen_k=max_seqlen.item(), dropout_p=0.0, causal=False, + **kwargs, ) context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size) return context_layer @@ -56,6 +62,7 @@ def flash_attn_maxseqlen_wrapper_fake( max_seqlen: torch.Tensor, batch_size: int, is_rocm_aiter: bool, + fa_version: int, ) -> torch.Tensor: return torch.empty_like(q) @@ -75,9 +82,10 @@ def vit_flash_attn_wrapper( max_seqlen: torch.Tensor, batch_size: int, is_rocm_aiter: bool, + fa_version: int, ) -> torch.Tensor: return torch.ops.vllm.flash_attn_maxseqlen_wrapper( - q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter + q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version ) @@ -89,6 +97,13 @@ def torch_sdpa_wrapper( v: torch.Tensor, cu_seqlens: torch.Tensor, ) -> torch.Tensor: + # Never remove the contiguous logic for ROCm + # Without it, hallucinations occur with the backend + if current_platform.is_rocm(): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + outputs = [] lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index c40dde26b741f..7a4e81cf967de 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if extra_tensors is not None: + raise NotImplementedError( + "extra_tensors is not supported for NaiveAll2AllManager" + ) sp_size = self.tp_group.world_size if is_sequence_parallel else 1 dp_metadata = get_forward_context().dp_metadata assert dp_metadata is not None @@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase): router_logits = self.naive_multicast( router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel ) + return hidden_states, router_logits def combine( @@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): """ Gather hidden_states and router_logits from all dp ranks. """ @@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase): assert dp_metadata is not None sizes = dp_metadata.get_chunk_sizes_across_dp_rank() assert sizes is not None - dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] - hidden_states, router_logits = dist_group.all_gatherv( - [hidden_states, router_logits], + + tensors_to_gather = [hidden_states, router_logits] + if extra_tensors is not None: + tensors_to_gather.extend(extra_tensors) + + gathered_tensors = dist_group.all_gatherv( + tensors_to_gather, dim=0, sizes=sizes, ) - return hidden_states, router_logits + + if extra_tensors is not None: + return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) + return gathered_tensors[0], gathered_tensors[1] def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False @@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError @@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 3a849da70e4cb..caeff54406b59 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading +from typing import Any from weakref import WeakValueDictionary import torch @@ -68,7 +69,11 @@ class All2AllManagerBase: hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ): + extra_tensors: list[torch.Tensor] | None = None, + ) -> Any: + # Subclasses should either: + # - implement handling for extra_tensors, or + # - raise a clear error if extra_tensors is not supported. raise NotImplementedError def set_num_sms(self, num_sms: int): diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index cd9c267beb5b5..9542498c453ec 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase): return output_list - def dispatch( + def dispatch( # type: ignore[override] self, hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): assert self.all2all_manager is not None - hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits, is_sequence_parallel + return self.all2all_manager.dispatch( + hidden_states, + router_logits, + is_sequence_parallel, + extra_tensors, # type: ignore[call-arg] ) - return hidden_states, router_logits def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False diff --git a/vllm/distributed/ec_transfer/ec_connector/example_connector.py b/vllm/distributed/ec_transfer/ec_connector/example_connector.py index 5f2eff5a8e6a8..c9aad9e9fc8f3 100644 --- a/vllm/distributed/ec_transfer/ec_connector/example_connector.py +++ b/vllm/distributed/ec_transfer/ec_connector/example_connector.py @@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase): Update ECConnector state after encoder cache allocation. """ mm_hash = request.mm_features[index].identifier - num_encoder_token = request.get_num_encoder_tokens(index) + num_encoder_token = request.get_num_encoder_embeds(index) # Insert mm_hash only if this block has not been recorded yet. self._mm_datas_need_loads[mm_hash] = num_encoder_token diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 338cb1f1814b5..f5ada5a009ec3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1007,10 +1007,17 @@ class GroupCoordinator: hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): if self.device_communicator is not None: - return self.device_communicator.dispatch( - hidden_states, router_logits, is_sequence_parallel + return self.device_communicator.dispatch( # type: ignore[call-arg] + hidden_states, + router_logits, + is_sequence_parallel, + extra_tensors, ) else: return hidden_states, router_logits diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a7c4980cd3674..94dde4564ea0c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -2054,6 +2054,9 @@ class TranscriptionRequest(OpenAIBaseModel): presence_penalty: float | None = 0.0 """The presence penalty to use for sampling.""" + + max_completion_tokens: int | None = None + """The maximum number of tokens to generate.""" # --8<-- [end:transcription-sampling-params] # Default sampling parameters for transcription requests. @@ -2300,6 +2303,9 @@ class TranslationRequest(OpenAIBaseModel): # Flattened stream option to simplify form data. stream_include_usage: bool | None = False stream_continuous_usage_stats: bool | None = False + + max_completion_tokens: int | None = None + """The maximum number of tokens to generate.""" # --8<-- [end:translation-extra-params] # Default sampling parameters for translation requests. diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index cea9924ebbaca..df9c06adb105a 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -293,8 +293,14 @@ class OpenAISpeechToText(OpenAIServing): try: # Unlike most decoder-only models, whisper generation length is not # constrained by the size of the input audio, which is mapped to a - # fixed-size log-mel-spectogram. - default_max_tokens = self.model_config.max_model_len + # fixed-size log-mel-spectogram. Still, allow for fewer tokens to be + # generated by respecting the extra completion tokens arg. + if request.max_completion_tokens is None: + default_max_tokens = self.model_config.max_model_len + else: + default_max_tokens = min( + self.model_config.max_model_len, request.max_completion_tokens + ) sampling_params = request.to_sampling_params( default_max_tokens, self.default_sampling_params ) diff --git a/vllm/envs.py b/vllm/envs.py index d0f2798096263..7e072a588591c 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -207,7 +207,7 @@ if TYPE_CHECKING: VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" - VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False + VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_USE_TRTLLM_ATTENTION: str | None = None VLLM_NVFP4_GEMM_BACKEND: str | None = None @@ -1430,7 +1430,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # kv-cache memory usage and enable longer contexts) # TODO(lucas): Remove this flag once latency regression is resolved. "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool( - int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0")) + int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "1")) ), # Enables support for the "store" option in the OpenAI Responses API. # When set to 1, vLLM's OpenAI server will retain the input and output diff --git a/vllm/lora/request.py b/vllm/lora/request.py index c97e435e32165..55756bdb103bd 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -14,11 +14,6 @@ class LoRARequest( """ Request for a LoRA adapter. - Note that this class should be used internally. For online - serving, it is recommended to not allow users to use this class but - instead provide another layer of abstraction to prevent users from - accessing unauthorized LoRA adapters. - lora_int_id must be globally unique for a given adapter. This is currently not enforced in vLLM. """ diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 8c9d8a2777d58..a46e3972ed8e3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -71,6 +71,18 @@ class FusedMoEMethodBase(QuantizeMethodBase): "implementation based on the prepare_finalize" ) + def prepare_dp_allgather_tensor( + self, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Hook to prepare tensors and extra tensors for DP allgather + EP dispatch.""" + raise NotImplementedError( + "Method 'prepare_dp_allgather_tensor' is not implemented in " + f"{self.__class__.__name__}." + ) + @abstractmethod def get_fused_moe_quant_config( self, layer: torch.nn.Module diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index cc3afade709d9..b39ce415a0f83 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -44,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( is_flashinfer_supporting_global_sf, ) from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe from vllm.utils.math_utils import cdiv, round_up from vllm.utils.torch_utils import ( aux_stream, @@ -1933,10 +1934,46 @@ class FusedMoE(CustomOp): ) with sp_ctx: + extra_tensors = None if do_naive_dispatch_combine: - hidden_states_combined, router_logits = get_ep_group().dispatch( - hidden_states, router_logits, self.is_sequence_parallel + # Avoid circular import + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4FusedMoE, ) + + post_quant_allgather = ( + has_flashinfer_trtllm_fused_moe() + and self.quant_method is not None + and self.dp_size > 1 + and self.use_ep + and isinstance(self.quant_method, ModelOptNvFp4FusedMoE) + ) + if post_quant_allgather: + hidden_states_to_dispatch, extra_tensors = ( + self.quant_method.prepare_dp_allgather_tensor( + self, hidden_states, router_logits + ) + ) + else: + hidden_states_to_dispatch = hidden_states + + dispatch_res = get_ep_group().dispatch( + hidden_states_to_dispatch, + router_logits, + self.is_sequence_parallel, + extra_tensors=extra_tensors, + ) + if extra_tensors is not None: + hidden_states_combined, router_logits, extra_tensors_combined = ( + dispatch_res + ) + hidden_states_combined = ( + hidden_states_combined, + extra_tensors_combined[0], + ) + else: + hidden_states_combined, router_logits = dispatch_res + # Run shared experts before matrix multiply. # because matrix multiply maybe modify the hidden_states. if has_separate_shared_experts and not use_shared_experts_stream: diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 3ed15ed7dd422..314848721a80a 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -121,7 +121,7 @@ class AWQMarlinConfig(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - return 80 + return 75 @classmethod def get_config_filenames(cls) -> list[str]: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f650a6eabbb9c..c302e465aedb7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -626,17 +626,11 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: + # If no modular kernel is provided, use cutlass_moe_fp4 for TP case + # only (no EP). from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 - assert layer.expert_map is None, ( - "Expert Parallelism / expert_map " - "is currently not supported for " - "CompressedTensorsW4A4Nvfp4MoEMethod." - ) assert self.moe_quant_config is not None - - # Cutlass moe takes in activations in BF16/Half precision - # and fp4 quantized weights loaded from the checkpoint return cutlass_moe_fp4( a=x, w1_fp4=layer.w13_weight, @@ -644,6 +638,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, quant_config=self.moe_quant_config, + expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, # TODO(bnell): derive these from arguments m=x.shape[0], diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f2b66a2beb6d7..800340ed6043c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -253,7 +253,7 @@ class Fp8Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - return 80 + return 75 @classmethod def get_config_filenames(cls) -> list[str]: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 6e5dcfe59b2f9..347c7b2008d12 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -181,7 +181,7 @@ class GPTQMarlinConfig(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - return 80 + return 75 @classmethod def get_config_filenames(cls) -> list[str]: diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index f71854e6b63c5..aa3937d4c03ff 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -871,7 +871,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase): @classmethod def get_min_capability(cls) -> int: - return 80 + return 75 @classmethod def override_quantization_method( @@ -1522,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): w2_blockscale_swizzled, requires_grad=False ) + def prepare_dp_allgather_tensor( + self, + layer: FusedMoE, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Optionally prepare extra tensors to carry through DP allgather/EP.""" + import flashinfer + + a1_gscale = layer.w13_input_scale_quant + hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize( + hidden_states, + a1_gscale, + is_sf_swizzled_layout=False, + ) + extra_tensors: list[torch.Tensor] = [hidden_states_sf] + return hidden_states_fp4, extra_tensors + def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: @@ -1576,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): e_score_correction_bias=layer.e_score_correction_bias, ) + # Hidden_states in select_experts is only used to extract metadata + if isinstance(x, tuple): + x_routing, _ = x + else: + x_routing = x topk_weights, topk_ids, _ = layer.select_experts( - hidden_states=x, + hidden_states=x_routing, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 76bce8a8d98d6..1d410316d6299 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( def flashinfer_trtllm_fp4_moe( layer: torch.nn.Module, - x: torch.Tensor, + x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], router_logits: torch.Tensor, top_k: int, global_num_experts: int, @@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe( from vllm.model_executor.models.llama4 import Llama4MoE # Quantize input to FP4 - a1_gscale = layer.w13_input_scale_quant - (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - a1_gscale, - is_sf_swizzled_layout=False, - ) + if isinstance(x, tuple): + hidden_states_fp4, hidden_states_scale_linear_fp4 = x + else: + # hidden_states is the already quantized + a1_gscale = layer.w13_input_scale_quant + (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) # Determine routing method type use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function @@ -360,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe( torch.bfloat16 ).view(torch.int16) - # Quantize input to FP4 - a1_gscale = layer.w13_input_scale_quant - (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - a1_gscale, - is_sf_swizzled_layout=False, - ) + if isinstance(x, tuple): + # Hidden_states is the already quantized + hidden_states_fp4, hidden_states_scale_linear_fp4 = x + else: + # Quantize input to FP4 + a1_gscale = layer.w13_input_scale_quant + (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) # Call TRT-LLM FP4 block-scale MoE kernel out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe( diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index 19a942a5277cc..83ef5e7e1282d 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -169,10 +169,13 @@ class DeciLMDecoderLayer(nn.Module): self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if not self._is_no_op_ffn: - ffn_mult = block_config.ffn.ffn_mult - intermediate_size = _ffn_mult_to_intermediate_size( - ffn_mult, config.hidden_size - ) + if hasattr(block_config.ffn, "ffn_mult"): + ffn_mult = block_config.ffn.ffn_mult + intermediate_size = _ffn_mult_to_intermediate_size( + ffn_mult, config.hidden_size + ) + else: + intermediate_size = block_config.ffn.intermediate_size self.mlp = LlamaMLP( hidden_size=self.hidden_size, diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 27a7ef1a44064..80e951257e536 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -713,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): mm_counts: Mapping[str, int], ) -> int: target_width, target_height = self.get_image_size_with_most_features() - video_soft_tokens = self.get_num_video_tokens( + num_video_soft_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), image_processor=None, ) - - # NOTE: By default in Qwen3-VL, one video token is converted to - # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501 - formatted_video_soft_tokens = video_soft_tokens * 12.5 - return int(formatted_video_soft_tokens) + return num_video_soft_tokens def _calculate_timestamps( self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 6b1cbbe24e2e7..fa69818a7b1f8 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import UserDict, defaultdict from collections.abc import Mapping, Sequence from dataclasses import dataclass -from functools import partial +from functools import cached_property, partial from itertools import accumulate from typing import ( TYPE_CHECKING, @@ -169,11 +169,42 @@ class PlaceholderRange: between `offset` and `offset + length` to assign embeddings to. """ - def get_num_embeds(self) -> int: + @cached_property + def embeds_cumsum(self) -> torch.Tensor | None: if self.is_embed is None: + return None + + return self.is_embed.cumsum(dim=0) + + @cached_property + def get_num_embeds(self) -> int: + if self.embeds_cumsum is None: return self.length - return int(self.is_embed.sum().item()) + return int(self.embeds_cumsum[-1]) + + def get_embeds_indices_in_range( + self, start_idx: int, end_idx: int + ) -> tuple[int, int]: + """ + Returns the starting and ending indices of the embeddings of encoder outputs + in the range of [start_idx, end_idx) in the placeholders. + + For example, given: + PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True]) + + If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get + the second and the third embeddings from the encoder output. + """ + if self.embeds_cumsum is None: + return start_idx, end_idx + + embeds_start_idx = ( + int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0 + ) + embeds_end_idx = int(self.embeds_cumsum[end_idx - 1]) + + return embeds_start_idx, embeds_end_idx def extract_embeds_range(self) -> list[tuple[int, int]]: """Extract the start and end indices of the embedded region in prompt. @@ -188,7 +219,7 @@ class PlaceholderRange: Returns full placeholder range if `is_embed` is `None`. """ if self.is_embed is None: - return [(self.offset, self.offset + self.length)] + return [(self.offset, self.offset + self.length - 1)] mask_i = self.is_embed.int() starts = torch.nonzero( diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index cb70041e9744f..a690948f759e9 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]): def _get_mm_num_tokens( self, mm_inputs: MultiModalInputs, - mm_embeddings_only: bool = True, ) -> Mapping[str, int]: placeholders_by_modality = mm_inputs["mm_placeholders"] return { - modality: sum( - item.get_num_embeds() if mm_embeddings_only else item.length - for item in placeholders - ) + modality: sum(item.get_num_embeds for item in placeholders) for modality, placeholders in placeholders_by_modality.items() } @@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]): multi_modal_placeholders=mm_inputs["mm_placeholders"], ) - def _get_mm_max_tokens( + def get_mm_max_tokens( self, seq_len: int, mm_counts: Mapping[str, int] | None = None, - mm_embeddings_only: bool = True, ) -> Mapping[str, int]: + """ + Returns the maximum number of embeddings per item of each modality, excluding + any break/text tokens in-between multimodal embeddings/encoder outputs. + """ if mm_counts is None: mm_counts = self.get_mm_limits() @@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]): } mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) - return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only) - - def get_mm_max_contiguous_tokens( - self, - seq_len: int, - mm_counts: Mapping[str, int] | None = None, - ) -> Mapping[str, int]: - """ - Returns the maximum length of the multimodal (image placeholders+text) - tokens, including any break/text tokens in-between image embeddings. - - ` [IMG] [IMG] [IMG] [IMG] [IMG] [IMG] ` - Returns 9, even when the number of image embeddings is 6. - - This is important to take into account when profiling and - initializing the encoder cache size. - """ - return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False) + return self._get_mm_num_tokens(mm_inputs) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 00a84f9dec4f7..1e7fe8648ab71 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -164,7 +164,7 @@ class MultiModalRegistry: profiler.get_mm_limits() if profiler_limits is None else profiler_limits ) - return profiler.get_mm_max_contiguous_tokens( + return profiler.get_mm_max_tokens( seq_len, {modality: 1 for modality, limit in profiler_limits.items() if limit > 0}, ) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 5019b771f4a14..1c2710be3173b 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool: ) +@functools.cache +def has_flashinfer_trtllm_fused_moe() -> bool: + """Return `True` if FlashInfer TRTLLM fused MoE is available.""" + if not has_flashinfer_moe(): + return False + required_functions = [ + ("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"), + ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"), + ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"), + ] + for module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True + + @functools.cache def has_flashinfer_cutlass_fused_moe() -> bool: """Return `True` if FlashInfer CUTLASS fused MoE is available.""" diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f5ad98cf2125c..3445e998d6371 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" +import copy from dataclasses import dataclass from typing import ClassVar @@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH ) + supports_update_block_table: bool = True def __init__( self, @@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ) return attn_metadata + def update_block_table( + self, + metadata: FlashAttentionMetadata, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> FlashAttentionMetadata: + new_metadata = copy.copy(metadata) + new_metadata.block_table = blk_table + new_metadata.slot_mapping = slot_mapping + return new_metadata + def use_cascade_attention(self, *args, **kwargs) -> bool: return use_cascade_attention(*args, **kwargs) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index bf1d8f09ab0ac..f923371283aa0 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy import itertools from dataclasses import dataclass @@ -134,6 +135,8 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata] ): + supports_update_block_table: bool = True + def __init__( self, kv_cache_spec: AttentionSpec, @@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder( num_computed_tokens_p=num_computed_tokens_p, ) return attn_metadata + + def update_block_table( + self, + metadata: Mamba2AttentionMetadata, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> Mamba2AttentionMetadata: + new_metadata = copy.copy(metadata) + prefix_caching = self.vllm_config.cache_config.enable_prefix_caching + state_indices_t = blk_table if prefix_caching else blk_table[:, 0] + num_reqs = blk_table.shape[0] + + # For CUDA graphs, copy to persistent buffer + if ( + metadata.num_prefills == 0 + and num_reqs <= self.decode_cudagraph_max_bs + and self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ): + persistent_state_indices_t = self.state_indices_tensor[:num_reqs] + persistent_state_indices_t.copy_(state_indices_t, non_blocking=True) + state_indices_t = persistent_state_indices_t + + new_metadata.state_indices_tensor = state_indices_t + return new_metadata diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 1cbe929fc57a8..56763f4b52539 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,6 +4,7 @@ import abc import enum import functools from abc import abstractmethod +from collections.abc import Callable from dataclasses import dataclass, field, fields, make_dataclass from typing import ( TYPE_CHECKING, @@ -317,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. reorder_batch_threshold: int | None = None + # Does this backend/builder support updating the block table in existing + # metadata + supports_update_block_table: bool = False @abstractmethod def __init__( @@ -387,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): """ raise NotImplementedError + def update_block_table( + self, + metadata: M, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> M: + """ + Update the block table for the attention metadata. + Faster when theres multiple kv-cache groups that create virtually the + same metadata but just with different block tables. + + Only needs to be implemented if supports_update_block_table is True. + """ + raise NotImplementedError + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> M: @@ -603,7 +622,7 @@ def make_local_attention_virtual_batches( attn_chunk_size: int, common_attn_metadata: CommonAttentionMetadata, block_size: int = 0, -) -> CommonAttentionMetadata: +) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]: query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy() seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy() block_table = common_attn_metadata.block_table_tensor @@ -715,9 +734,12 @@ def make_local_attention_virtual_batches( # tensor first, which recovers perf. batch_indices_torch = torch.from_numpy(batch_indices) block_indices_torch = torch.from_numpy(block_indices) - block_table_local = block_table[batch_indices_torch, block_indices_torch].view( - virtual_batches, -1 - ) + + # Save as a lambda so we can return this for update_block_table + make_block_table = lambda block_table: block_table[ + batch_indices_torch, block_indices_torch + ].view(virtual_batches, -1) + block_table_local = make_block_table(block_table) query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local) @@ -736,7 +758,7 @@ def make_local_attention_virtual_batches( causal=True, _seq_lens_cpu=seq_lens_cpu, _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), - ) + ), make_block_table def make_kv_sharing_fast_prefill_common_attn_metadata( diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 50f738713590b..d73c05d2cf80b 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -39,20 +39,26 @@ class EncoderCacheManager: space for new embeddings. Oldest cached embeddings with no request referenced will be first evicted. + NOTE: The EncoderCacheManager operates on the level of multimodal embeddings + instead of encoder tokens (i.e. all tokens that represent the multimodal data + in the input sequence). This means all break/text tokens in-between multimodal + embeddings are not considered with respect to the cache size and the number + of free slots. + Args: cache_size: Limit the size of the cache, measured by the number of - tokens from the input sequence. + encoder embeddings from the input sequence. Attributes: - cache_size: Total cache capacity in encoder tokens. - num_free_slots: Current available cache capacity in encoder tokens. + cache_size: Total cache capacity in encoder embeddings. + num_free_slots: Current available cache capacity in encoder embeddings. num_freeable_slots: Capacity that can be immediately reclaimed by - evicting entries with zero references (in encoder tokens). + evicting entries with zero references (in encoder embeddings). cached: Mapping from mm_hash to a set of request IDs that currently reference the cached entry. If the set is empty, the entry exists but is not referenced by any request and is eligible for reclamation. - freeable: List of tuples (mm_hash, num_tokens) representing entries + freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries whose no current running request is needed and that can be freed to make space when needed. freed: List of mm_hash strings that were actually evicted since the @@ -67,7 +73,7 @@ class EncoderCacheManager: # mm_hash of mm_data => ids of requests that reference the mm_data self.cached: dict[str, set[str]] = {} - # mm_hash of mm_data => num_encoder_tokens of the mm_data + # mm_hash of mm_data => num_encoder_embeds of the mm_data self.freeable: OrderedDict[str, int] = OrderedDict() self.freed: list[str] = [] @@ -93,8 +99,8 @@ class EncoderCacheManager: # Cached but currently not referenced by any request if not self.cached[mm_hash]: - num_tokens = self.freeable.pop(mm_hash) - self.num_freeable_slots -= num_tokens + num_encoder_embeds = self.freeable.pop(mm_hash) + self.num_freeable_slots -= num_encoder_embeds self.cached[mm_hash].add(request.request_id) return True @@ -104,7 +110,7 @@ class EncoderCacheManager: request: Request, input_id: int, encoder_compute_budget: int, - num_tokens_to_schedule: int, + num_embeds_to_schedule: int, ) -> bool: """Check if there's sufficient cache space for a multimodal input. If there is, return True and update EncoderCacheManager state. @@ -121,9 +127,9 @@ class EncoderCacheManager: Args: request: The request containing the multimodal input. input_id: Index of the multimodal input within the request. - encoder_compute_budget: Number of encoder tokens allowed to be + encoder_compute_budget: Number of encoder embeddings allowed to be computed when this method is invoked. - num_tokens_to_schedule: Number of tokens already scheduled to be + num_embeds_to_schedule: Number of encoder embeddings already scheduled to be allocated with cache space when this method is invoked. Returns: @@ -134,30 +140,30 @@ class EncoderCacheManager: Note: This method does not allocate physical memory for the encoder output but only the state of EncoderCacheManager. """ - num_tokens = request.get_num_encoder_tokens(input_id) + num_embeds = request.get_num_encoder_embeds(input_id) # Not enough compute budget - if num_tokens > encoder_compute_budget: + if num_embeds > encoder_compute_budget: return False - num_tokens += num_tokens_to_schedule + num_embeds += num_embeds_to_schedule # Enough free slots - if num_tokens <= self.num_free_slots: + if num_embeds <= self.num_free_slots: return True # Not enough reclaimable slots - if num_tokens > self.num_freeable_slots: + if num_embeds > self.num_freeable_slots: return False # Not enough free slots but enough reclaimable slots # NOTE: Eviction takes place here, but physical memory is not freed # until model runner is notified by the scheduler output. - while num_tokens > self.num_free_slots: - mm_hash, num_free_token = self.freeable.popitem(last=False) + while num_embeds > self.num_free_slots: + mm_hash, num_free_embeds = self.freeable.popitem(last=False) del self.cached[mm_hash] self.freed.append(mm_hash) - self.num_free_slots += num_free_token + self.num_free_slots += num_free_embeds return True def allocate(self, request: Request, input_id: int) -> None: @@ -176,16 +182,16 @@ class EncoderCacheManager: if mm_hash not in self.cached: self.cached[mm_hash] = set() - num_encoder_tokens = request.get_num_encoder_tokens(input_id) + num_encoder_embeds = request.get_num_encoder_embeds(input_id) # NOTE: Encoder cache should always have enough space for encoder inputs # that are scheduled since eviction takes place at can_allocate(). - assert self.num_free_slots >= num_encoder_tokens - assert self.num_freeable_slots >= num_encoder_tokens + assert self.num_free_slots >= num_encoder_embeds + assert self.num_freeable_slots >= num_encoder_embeds self.cached[mm_hash].add(request_id) - self.num_free_slots -= num_encoder_tokens - self.num_freeable_slots -= num_encoder_tokens + self.num_free_slots -= num_encoder_embeds + self.num_freeable_slots -= num_encoder_embeds def get_cached_input_ids(self, request: Request) -> set[int]: """Get all cached multimodal input IDs for a request. @@ -206,7 +212,7 @@ class EncoderCacheManager: When the reference set for the corresponding `mm_hash` becomes empty, the entry is appended to `freeable` and `num_freeable_slots` is - increased by the number of encoder tokens for that input. + increased by the number of encoder embeddings for that input. The entry is NOT physically freed until capacity is needed (e.g., by `can_allocate`). @@ -218,9 +224,9 @@ class EncoderCacheManager: return self.cached[mm_hash].discard(req_id) if not self.cached[mm_hash]: - num_tokens = request.get_num_encoder_tokens(input_id) - self.freeable[mm_hash] = num_tokens - self.num_freeable_slots += num_tokens + num_encoder_embeds = request.get_num_encoder_embeds(input_id) + self.freeable[mm_hash] = num_encoder_embeds + self.num_freeable_slots += num_encoder_embeds def free(self, request: Request) -> None: """Free all encoder input cache reference held by *request*. @@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager): request: Request, input_id: int, encoder_compute_budget: int, - num_tokens_to_schedule: int, + num_embeds_to_schedule: int, ) -> bool: - num_tokens = request.get_num_encoder_tokens(input_id) + num_encoder_embeds = request.get_num_encoder_embeds(input_id) # Not enough compute budget - if num_tokens > encoder_compute_budget: + if num_encoder_embeds > encoder_compute_budget: return False - num_tokens += num_tokens_to_schedule + num_encoder_embeds += num_embeds_to_schedule # Enough free slots - return num_tokens <= self.num_free_slots + return num_encoder_embeds <= self.num_free_slots def allocate(self, request: Request, input_id: int) -> None: - num_encoder_tokens = request.get_num_encoder_tokens(input_id) - self.num_free_slots -= num_encoder_tokens + num_encoder_embeds = request.get_num_encoder_embeds(input_id) + self.num_free_slots -= num_encoder_embeds mm_hash = request.mm_features[input_id].identifier self.freed.append(mm_hash) @@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager): return freed def free_encoder_input(self, request: Request, input_id: int) -> None: - num_tokens = request.get_num_encoder_tokens(input_id) - self.num_free_slots += num_tokens + num_encoder_embeds = request.get_num_encoder_embeds(input_id) + self.num_free_slots += num_encoder_embeds diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 754e0b9d08316..8e835ad096405 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -355,11 +355,11 @@ class Scheduler(SchedulerInterface): if preempted_encoder_inputs: # Restore encoder compute budget if the preempted # request had encoder inputs scheduled in this step. - num_tokens_to_restore = sum( - preempted_req.get_num_encoder_tokens(i) + num_embeds_to_restore = sum( + preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs ) - encoder_compute_budget += num_tokens_to_restore + encoder_compute_budget += num_embeds_to_restore req_index -= 1 else: preempted_req = self.running.pop() @@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface): # multiple encoder inputs per request), we need to create temporary # trackers for accounting at the encoder input level. mm_hashes_to_schedule = set() - num_tokens_to_schedule = 0 + num_embeds_to_schedule = 0 for i, mm_feature in enumerate(mm_features): start_pos = mm_feature.mm_position.offset num_encoder_tokens = mm_feature.mm_position.length + num_encoder_embeds = mm_feature.mm_position.get_num_embeds # The encoder output is needed if the two ranges overlap: # [num_computed_tokens, num_computed_tokens + num_new_tokens) and @@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface): ): num_new_tokens = start_pos - num_computed_tokens break - if not self.encoder_cache_manager.can_allocate( - request, i, encoder_compute_budget, num_tokens_to_schedule + request, i, encoder_compute_budget, num_embeds_to_schedule ): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should @@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface): num_new_tokens = 0 break + # Calculate the number of embeddings to schedule in the current range + # of scheduled encoder placholder tokens. + start_idx_rel = max(0, num_computed_tokens - start_pos) + end_idx_rel = min( + num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos + ) + curr_embeds_start, curr_embeds_end = ( + mm_feature.mm_position.get_embeds_indices_in_range( + start_idx_rel, + end_idx_rel, + ) + ) + # There's no embeddings in the current range of encoder placeholder tokens + # so we can skip the encoder input. + if curr_embeds_end - curr_embeds_start == 0: + continue + if self.ec_connector is not None and remote_cache_has_item[i]: mm_hashes_to_schedule.add(request.mm_features[i].identifier) external_load_encoder_input.append(i) - num_tokens_to_schedule += num_encoder_tokens + num_embeds_to_schedule += num_encoder_embeds continue - num_tokens_to_schedule += num_encoder_tokens - encoder_compute_budget -= num_encoder_tokens + num_embeds_to_schedule += num_encoder_embeds + encoder_compute_budget -= num_encoder_embeds mm_hashes_to_schedule.add(request.mm_features[i].identifier) encoder_inputs_to_schedule.append(i) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index a775e840e841c..f33059b80b894 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -209,10 +209,10 @@ class Request: def get_finished_reason(self) -> FinishReason | None: return RequestStatus.get_finished_reason(self.status) - def get_num_encoder_tokens(self, input_id: int) -> int: + def get_num_encoder_embeds(self, input_id: int) -> int: assert input_id < len(self.mm_features) - num_tokens = self.mm_features[input_id].mm_position.length - return num_tokens + num_embeds = self.mm_features[input_id].mm_position.get_num_embeds + return num_embeds def record_event( self, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9af9aa7ad2a2c..8cb814a6d053f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -170,9 +170,7 @@ from .utils import ( MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, - gather_mm_placeholders, sanity_check_mm_encoder_outputs, - scatter_mm_placeholders, ) if TYPE_CHECKING: @@ -1640,6 +1638,15 @@ class GPUModelRunner( logits_indices ) + # Cache attention metadata builds across hybrid KV-cache groups + # The only thing that changes between different hybrid KV-cache groups when the + # same metadata builder and KVCacheSpec is the same is the block table, so we + # can cache the attention metadata builds and just update the block table using + # `builder.update_block_table` if the builder supports it. + cached_attn_metadata: dict[ + tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata + ] = {} + def _build_attn_group_metadata( kv_cache_gid: int, attn_gid: int, @@ -1647,13 +1654,15 @@ class GPUModelRunner( ubid: int | None = None, ) -> None: attn_group = self.attn_groups[kv_cache_gid][attn_gid] + builder = attn_group.get_metadata_builder(ubid or 0) + cache_key = (kv_cache_groups[kv_cache_gid].kv_cache_spec, type(builder)) + cascade_attn_prefix_len = ( cascade_attn_prefix_lens[kv_cache_gid][attn_gid] if cascade_attn_prefix_lens else 0 ) - builder = attn_group.get_metadata_builder(ubid or 0) extra_attn_metadata_args = {} if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): assert ubid is None, "UBatching not supported with GDN yet" @@ -1668,12 +1677,23 @@ class GPUModelRunner( attn_metadata_i = builder.build_for_cudagraph_capture( common_attn_metadata ) + elif ( + cache_key in cached_attn_metadata + and builder.supports_update_block_table + ): + attn_metadata_i = builder.update_block_table( + cached_attn_metadata[cache_key], + common_attn_metadata.block_table_tensor, + common_attn_metadata.slot_mapping, + ) else: attn_metadata_i = builder.build( common_prefix_len=cascade_attn_prefix_len, common_attn_metadata=common_attn_metadata, **extra_attn_metadata_args, ) + if builder.supports_update_block_table: + cached_attn_metadata[cache_key] = attn_metadata_i if ubid is None: assert isinstance(attn_metadata, dict) @@ -2271,10 +2291,7 @@ class GPUModelRunner( # Cache the encoder outputs by mm_hash for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): - self.encoder_cache[mm_hash] = scatter_mm_placeholders( - output, - is_embed=pos_info.is_embed, - ) + self.encoder_cache[mm_hash] = output logger.debug("Finish execute for mm hash %s", mm_hash) self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) @@ -2325,6 +2342,13 @@ class GPUModelRunner( num_encoder_tokens, ) assert start_idx < end_idx + curr_embeds_start, curr_embeds_end = ( + pos_info.get_embeds_indices_in_range(start_idx, end_idx) + ) + # If there are no embeddings in the current range, we skip + # gathering the embeddings. + if curr_embeds_start == curr_embeds_end: + continue mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) @@ -2332,16 +2356,14 @@ class GPUModelRunner( if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] + mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end] + else: + mm_embeds_item = encoder_output[start_idx:end_idx] req_start_pos = req_start_idx + start_pos - num_computed_tokens is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( True if is_embed is None else is_embed ) - - mm_embeds_item = gather_mm_placeholders( - encoder_output[start_idx:end_idx], - is_embed=is_embed, - ) mm_embeds_req.append(mm_embeds_item) if self.is_multimodal_pruning_enabled and self.uses_mrope: @@ -4570,31 +4592,8 @@ class GPUModelRunner( dummy_encoder_outputs, expected_num_items=max_mm_items_per_batch, ) - - # NOTE: This happens when encoder cache needs to store - # the embeddings that encoder outputs are scattered onto. - # In this case we create dummy embeddings of size - # (max_tokens_for_modality, hidden_size) and scatter - # encoder output into it. - encoder_output_shape = dummy_encoder_outputs[0].shape - max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[ - dummy_modality - ] - if encoder_output_shape[0] < max_mm_tokens_per_item: - encoder_hidden_size = encoder_output_shape[-1] - expanded_outputs = [] - for output in dummy_encoder_outputs: - expanded = output.new_zeros( - (max_mm_tokens_per_item, encoder_hidden_size) - ) - num_tokens = output.shape[0] - expanded[:num_tokens].copy_(output) - expanded_outputs.append(expanded) - - dummy_encoder_outputs = expanded_outputs - - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + for i, output in enumerate(dummy_encoder_outputs): + self.encoder_cache[f"tmp_{i}"] = output # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states = self._dummy_run( diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index e9c48223d58b9..2e8afec024ce9 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -4,10 +4,12 @@ from collections import defaultdict from dataclasses import dataclass, field import torch +from typing_extensions import deprecated from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.config import ModelConfig, SchedulerConfig, VllmConfig +from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal.cache import processor_only_cache_from_config @@ -17,6 +19,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec +logger = init_logger(__name__) + class MultiModalBudget: """Helper class to calculate budget information for multi-modal models.""" @@ -198,6 +202,7 @@ def sanity_check_mm_encoder_outputs( ) +@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.") def scatter_mm_placeholders( embeds: torch.Tensor, is_embed: torch.Tensor | None, @@ -226,6 +231,7 @@ def scatter_mm_placeholders( return placeholders +@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.") def gather_mm_placeholders( placeholders: torch.Tensor, is_embed: torch.Tensor | None,