mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-23 00:37:04 +08:00
Merge branch 'main' into mlm-full-lora-support
This commit is contained in:
commit
fe104bd63c
@ -654,7 +654,7 @@ steps:
|
||||
- vllm/model_executor/layers/quantization
|
||||
autorun_on_main: true
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
||||
|
||||
- label: OpenAI API correctness # 22min
|
||||
timeout_in_minutes: 30
|
||||
@ -1064,7 +1064,7 @@ steps:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
14
.github/mergify.yml
vendored
14
.github/mergify.yml
vendored
@ -235,6 +235,20 @@ pull_request_rules:
|
||||
add:
|
||||
- rocm
|
||||
|
||||
- name: label-cpu
|
||||
description: Automatically apply cpu label
|
||||
conditions:
|
||||
- label != stale
|
||||
- files~=^(?!.*kv_offload)(?!.*cpu_offload).*\bcpu.*
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- cpu
|
||||
assign:
|
||||
users:
|
||||
- "fadara01"
|
||||
- "aditew01"
|
||||
|
||||
- name: label-structured-output
|
||||
description: Automatically apply structured-output label
|
||||
conditions:
|
||||
|
||||
109
CMakeLists.txt
109
CMakeLists.txt
@ -357,6 +357,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# marlin arches for fp16 output
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
|
||||
# marlin has limited support for turing
|
||||
cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
|
||||
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
|
||||
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
||||
# marlin arches for fp8 input
|
||||
@ -364,8 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
|
||||
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
|
||||
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
|
||||
# marlin arches for other files
|
||||
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
|
||||
|
||||
if (MARLIN_ARCHS)
|
||||
if (MARLIN_OTHER_ARCHS)
|
||||
|
||||
#
|
||||
# For the Marlin kernels we automatically generate sources for various
|
||||
@ -406,25 +410,39 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Marlin generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
if (MARLIN_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
|
||||
endif()
|
||||
|
||||
if (MARLIN_SM75_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_SM75_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_SM75_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_SM75_KERNEL_SRC})
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
|
||||
|
||||
if (MARLIN_FP8_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
|
||||
@ -446,14 +464,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_SRCS}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
CUDA_ARCHS "${MARLIN_OTHER_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
set_source_files_properties(${MARLIN_SRCS}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
|
||||
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_OTHER_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
@ -980,12 +998,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# note that we always set `use_atomic_add=False` for moe marlin now,
|
||||
# so we don't need 9.0 for bf16 atomicAdd PTX
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
|
||||
# moe marlin has limited support for turing
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
|
||||
# moe marlin arches for fp8 input
|
||||
# - sm80 doesn't support fp8 computation
|
||||
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
|
||||
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
# moe marlin arches for other files
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_OTHER_ARCHS)
|
||||
|
||||
#
|
||||
# For the Marlin MOE kernels we automatically generate sources for various
|
||||
@ -1026,16 +1048,29 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
|
||||
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
|
||||
endif()
|
||||
|
||||
if (MARLIN_MOE_SM75_ARCHS)
|
||||
file(GLOB MARLIN_MOE_SM75_SRC "csrc/moe/marlin_moe_wna16/sm75_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SM75_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_SM75_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SM75_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SM75_SRC})
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
|
||||
|
||||
if (MARLIN_MOE_FP8_ARCHS)
|
||||
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
|
||||
@ -1049,7 +1084,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
|
||||
set(MARLIN_MOE_OTHER_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_OTHER_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_OTHER_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_OTHER_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_OTHER_SRC}")
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_OTHER_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
|
||||
@ -13,8 +13,8 @@ from vllm.triton_utils import triton
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
batch_size_range = [1, 16, 32, 64, 128]
|
||||
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
batch_size_range = [1, 16, 128]
|
||||
seq_len_range = [1, 16, 64, 1024, 4096]
|
||||
intermediate_size = [3072, 9728, 12288]
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))
|
||||
|
||||
|
||||
@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
|
||||
const scalar_t& y) {
|
||||
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
|
||||
}
|
||||
// Activation and gating kernel template.
|
||||
|
||||
// Check if all pointers are 16-byte aligned for int4 vectorized access
|
||||
__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
|
||||
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
|
||||
}
|
||||
|
||||
// Activation and gating kernel template.
|
||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
||||
bool act_first>
|
||||
__global__ void act_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const int d) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
|
||||
const scalar_t* x_ptr = input + token_idx * 2 * d;
|
||||
const scalar_t* y_ptr = x_ptr + d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access.
|
||||
// All three pointers must be 16-byte aligned for safe int4 operations.
|
||||
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
|
||||
is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
|
||||
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
|
||||
auto* xp = reinterpret_cast<scalar_t*>(&x);
|
||||
auto* yp = reinterpret_cast<scalar_t*>(&y);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
|
||||
VLLM_LDG(&y_ptr[i]));
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
|
||||
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
|
||||
out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
|
||||
__global__ void act_and_mul_kernel_with_param(
|
||||
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
|
||||
const float param) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x, param) * y;
|
||||
const scalar_t* x_ptr = input + token_idx * 2 * d;
|
||||
const scalar_t* y_ptr = x_ptr + d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access
|
||||
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
|
||||
is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
|
||||
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
|
||||
auto* xp = reinterpret_cast<scalar_t*>(&x);
|
||||
auto* yp = reinterpret_cast<scalar_t*>(&y);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = ACT_FN(xp[j], param) * yp[j];
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
|
||||
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
|
||||
out_ptr[idx] = ACT_FN(x, param) * y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
|
||||
float alpha, float limit) {
|
||||
// clamp gate: min=None, max=limit
|
||||
const float gate_f = (float)gate;
|
||||
const float clamped_gate = gate_f > limit ? limit : gate_f;
|
||||
|
||||
// clamp up: min=-limit, max=limit
|
||||
const float up_f = (float)up;
|
||||
const float clamped_up =
|
||||
up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
|
||||
|
||||
// glu = gate * sigmoid(gate * alpha)
|
||||
const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
|
||||
const float glu = clamped_gate * sigmoid_val;
|
||||
|
||||
// (up + 1) * glu
|
||||
return (T)((clamped_up + 1.0f) * glu);
|
||||
// Clamp gate to (-inf, limit] and up to [-limit, limit]
|
||||
const float g = fminf((float)gate, limit);
|
||||
const float u = fmaxf(fminf((float)up, limit), -limit);
|
||||
// glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
|
||||
return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha)));
|
||||
}
|
||||
|
||||
// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
|
||||
template <typename scalar_t,
|
||||
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
|
||||
const float)>
|
||||
__global__ void swigluoai_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const scalar_t* __restrict__ input, // [..., 2 * d] (interleaved)
|
||||
const int d, const float alpha, const float limit) {
|
||||
// For interleaved data: input has 2*d elements per token (gate/up pairs)
|
||||
// output has d elements per token
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
constexpr int PAIRS = VEC_SIZE / 2; // Number of gate/up pairs per int4 load
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
// TODO: Vectorize loads and stores.
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
// gate = x[..., ::2] (even indices)
|
||||
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
|
||||
// up = x[..., 1::2] (odd indices)
|
||||
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
|
||||
const scalar_t* in_ptr = input + token_idx * 2 * d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
|
||||
// Check alignment for 128-bit vectorized access on input.
|
||||
// For output we use int2 (64-bit) which has 8-byte alignment requirement.
|
||||
const bool in_aligned = is_16byte_aligned(in_ptr);
|
||||
const bool out_aligned =
|
||||
(reinterpret_cast<uintptr_t>(out_ptr) & 7) == 0; // 8-byte for int2
|
||||
|
||||
if (in_aligned && out_aligned && d >= PAIRS) {
|
||||
// Fast path: vectorized loop
|
||||
// Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs
|
||||
// Each int2 store writes PAIRS output elements
|
||||
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
|
||||
int2* out_vec = reinterpret_cast<int2*>(out_ptr);
|
||||
const int num_vecs = d / PAIRS;
|
||||
const int vec_end = num_vecs * PAIRS;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 v = VLLM_LDG(&in_vec[i]);
|
||||
int2 r;
|
||||
auto* vp = reinterpret_cast<scalar_t*>(&v);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < PAIRS; j++) {
|
||||
rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]),
|
||||
VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit);
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
// gate = x[..., ::2] (even indices)
|
||||
const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]);
|
||||
// up = x[..., 1::2] (odd indices)
|
||||
const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]);
|
||||
out_ptr[idx] = ACT_FN(gate, up, alpha, limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,10 +324,41 @@ __global__ void activation_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., d]
|
||||
const int d) {
|
||||
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x);
|
||||
const scalar_t* in_ptr = input + token_idx * d;
|
||||
scalar_t* out_ptr = out + token_idx * d;
|
||||
|
||||
// Check alignment for 128-bit vectorized access
|
||||
const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr);
|
||||
|
||||
if (aligned && d >= VEC_SIZE) {
|
||||
// Fast path: 128-bit vectorized loop
|
||||
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
|
||||
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
|
||||
const int num_vecs = d / VEC_SIZE;
|
||||
const int vec_end = num_vecs * VEC_SIZE;
|
||||
|
||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||
int4 v = VLLM_LDG(&in_vec[i]), r;
|
||||
auto* vp = reinterpret_cast<scalar_t*>(&v);
|
||||
auto* rp = reinterpret_cast<scalar_t*>(&r);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
rp[j] = ACT_FN(vp[j]);
|
||||
}
|
||||
out_vec[i] = r;
|
||||
}
|
||||
// Scalar cleanup for remaining elements
|
||||
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
|
||||
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback for unaligned data or small d
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&in_ptr[idx]);
|
||||
out_ptr[idx] = ACT_FN(x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) {
|
||||
|
||||
template <ScoringFunc SF, typename T>
|
||||
__device__ inline T apply_scoring(T val) {
|
||||
if constexpr (SF == SCORING_SIGMOID) {
|
||||
if constexpr (SF == SCORING_NONE) {
|
||||
return val;
|
||||
} else if constexpr (SF == SCORING_SIGMOID) {
|
||||
return apply_sigmoid(val);
|
||||
} else {
|
||||
static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
|
||||
"Unsupported ScoringFunc in apply_scoring");
|
||||
return val;
|
||||
}
|
||||
}
|
||||
@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
if (if_proceed_next_topk) {
|
||||
float scale = routed_scaling_factor;
|
||||
if (renormalize) {
|
||||
scale /= topk_sum;
|
||||
}
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float base = cuda_cast<float, T>(s_topk_value[i]);
|
||||
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
|
||||
: (base * routed_scaling_factor);
|
||||
float value = base * scale;
|
||||
topk_indices[i] = s_topk_idx[i];
|
||||
topk_values[i] = value;
|
||||
}
|
||||
|
||||
1
csrc/moe/marlin_moe_wna16/.gitignore
vendored
1
csrc/moe/marlin_moe_wna16/.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
kernel_*.cu
|
||||
|
||||
@ -10,6 +10,8 @@ import jinja2
|
||||
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
SUPPORT_SM75 = False
|
||||
SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
if arch == 75:
|
||||
SUPPORT_SM75 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
@ -157,6 +163,7 @@ def remove_old_kernels():
|
||||
|
||||
def generate_new_kernels():
|
||||
result_dict = {}
|
||||
sm_75_result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
@ -174,6 +181,8 @@ def generate_new_kernels():
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
|
||||
sm_75_result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
@ -197,78 +206,89 @@ def generate_new_kernels():
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"stages": 4,
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if SUPPORT_SM80:
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
|
||||
config_sm75 = config.copy()
|
||||
config_sm75["stages"] = 2
|
||||
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
for result_dict_tmp in [result_dict, sm_75_result_dict]:
|
||||
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
|
||||
all_template_str_list = []
|
||||
if not config_list:
|
||||
continue
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"stages == {config['stages']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
filename = filename.lower()
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
elif result_dict_tmp is sm_75_result_dict:
|
||||
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
#include "quantization/gptq_marlin/marlin.cuh"
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "quantization/gptq_marlin/dequant.h"
|
||||
#include "quantization/gptq_marlin/marlin_mma.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@ -35,7 +36,7 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
|
||||
@ -84,146 +85,6 @@ __global__ void Marlin(
|
||||
|
||||
#else
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
template <int count, vllm::ScalarTypeId type_id>
|
||||
@ -439,9 +300,20 @@ __global__ void Marlin(
|
||||
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
// Turing TensorCore only supports fp16 and int8
|
||||
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
|
||||
return;
|
||||
#endif
|
||||
|
||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
using Adtype = MarlinScalarType<a_type_id>;
|
||||
using Cdtype = MarlinScalarType<c_type_id>;
|
||||
|
||||
@ -618,7 +490,22 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
|
||||
if constexpr (moe_block_size >= 16)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16);
|
||||
if constexpr (moe_block_size >= 8)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8);
|
||||
if constexpr (moe_block_size >= 4)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4);
|
||||
if constexpr (moe_block_size >= 2)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2);
|
||||
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1);
|
||||
block_num_valid_tokens = local_count;
|
||||
#else
|
||||
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
|
||||
#endif
|
||||
|
||||
if (lane_id == 0)
|
||||
reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
|
||||
@ -1018,10 +905,6 @@ __global__ void Marlin(
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
// shared memory reused by reduction should be smaller than
|
||||
// shared memory used by weight.
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
@ -1545,11 +1428,13 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
} else {
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
|
||||
frag_c[i][j][0]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
|
||||
frag_c[i][j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1583,10 +1468,12 @@ __global__ void Marlin(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
}
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
@ -2132,6 +2019,21 @@ __global__ void Marlin(
|
||||
// While this pattern may not be the most readable, other ways of writing
|
||||
// the loop seemed to noticeably worse performance after compilation.
|
||||
if (slice_iters == 0) {
|
||||
// convert fp16 accum to fp32 for reduction
|
||||
if constexpr (use_fp16_accum) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
|
||||
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
|
||||
scalar_t* frag_c_part_half =
|
||||
reinterpret_cast<scalar_t*>(frag_c_part_float);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_a_8bit) {
|
||||
float frag_a_s[2 * thread_m_blocks];
|
||||
|
||||
|
||||
@ -142,7 +142,7 @@ typedef struct {
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool has_act_order, bool is_k_full, int stages) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
@ -160,13 +160,13 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
return tb_scales * stages;
|
||||
}
|
||||
}
|
||||
|
||||
@ -174,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n,
|
||||
int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full, int has_zp,
|
||||
int is_zp_float, bool is_a_8bit) {
|
||||
int is_zp_float, bool is_a_8bit, int stages) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
@ -185,8 +185,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
|
||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
int sh_block_meta_size = tb_m * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
@ -195,8 +195,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
group_size, has_act_order, is_k_full, stages);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
@ -217,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
|
||||
int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, int has_zp, int is_zp_float,
|
||||
int max_shared_mem, bool is_a_8bit) {
|
||||
bool is_a_8bit, int stages, int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
@ -243,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int cache_size =
|
||||
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
@ -252,7 +252,7 @@ MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int threads, bool is_zp_float, int stages) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
|
||||
@ -266,8 +266,8 @@ exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
|
||||
bool is_a_8bit) {
|
||||
bool is_k_full, bool has_zp, bool is_zp_float, bool is_a_8bit, int stages,
|
||||
int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@ -284,15 +284,15 @@ exec_config_t determine_exec_config(
|
||||
|
||||
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
|
||||
is_a_8bit)) {
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
is_a_8bit);
|
||||
is_a_8bit, stages);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
@ -303,7 +303,7 @@ exec_config_t determine_exec_config(
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
th_config.num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
@ -433,8 +433,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
|
||||
"marlin kernel only support Turing or newer GPUs.");
|
||||
int stages = 4;
|
||||
if (major_capability == 7 && minor_capability == 5) {
|
||||
stages = 2;
|
||||
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
|
||||
"Turing only support FP16 or INT8 activation.");
|
||||
}
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
|
||||
"FP8 only support Ada Lovelace or newer GPUs.");
|
||||
@ -461,8 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
|
||||
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
|
||||
is_a_8bit);
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem, sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
}
|
||||
|
||||
@ -479,7 +485,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
|
||||
prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem, is_a_8bit),
|
||||
is_a_8bit, stages, max_shared_mem),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
@ -493,12 +499,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int sh_cache_size =
|
||||
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
|
||||
1
csrc/quantization/gptq_marlin/.gitignore
vendored
1
csrc/quantization/gptq_marlin/.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
kernel_*.cu
|
||||
|
||||
@ -67,7 +67,7 @@ where `scale_factor * multiplier` can be computed at weight loading.
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
|
||||
@ -10,6 +10,8 @@ import jinja2
|
||||
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
SUPPORT_SM75 = False
|
||||
SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
if arch == 75:
|
||||
SUPPORT_SM75 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
@ -166,6 +172,7 @@ def remove_old_kernels():
|
||||
|
||||
def generate_new_kernels():
|
||||
result_dict = {}
|
||||
sm_75_result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
@ -184,6 +191,8 @@ def generate_new_kernels():
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
|
||||
sm_75_result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
@ -207,78 +216,89 @@ def generate_new_kernels():
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"stages": 4,
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "true" if is_zp_float else "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if SUPPORT_SM80:
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
|
||||
config_sm75 = config.copy()
|
||||
config_sm75["stages"] = 2
|
||||
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
for result_dict_tmp in [result_dict, sm_75_result_dict]:
|
||||
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
|
||||
all_template_str_list = []
|
||||
if not config_list:
|
||||
continue
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"stages == {config['stages']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
filename = filename.lower()
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
elif result_dict_tmp is sm_75_result_dict:
|
||||
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
|
||||
@ -37,7 +37,7 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
|
||||
|
||||
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int const* __restrict__ perm_int_ptr,
|
||||
@ -148,7 +148,7 @@ typedef struct {
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool has_act_order, bool is_k_full, int stages) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
@ -166,28 +166,29 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
return tb_scales * stages;
|
||||
}
|
||||
}
|
||||
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float) {
|
||||
int has_zp, bool is_zp_float, bool is_a_8bit,
|
||||
int stages) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
@ -196,8 +197,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
group_size, has_act_order, is_k_full, stages);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
@ -217,7 +218,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float, int max_shared_mem) {
|
||||
int has_zp, bool is_zp_float, bool is_a_8bit, int stages,
|
||||
int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
@ -242,7 +244,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
// Check that pipeline fits into cache
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
@ -251,7 +253,7 @@ MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int threads, bool is_zp_float, int stages) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
|
||||
@ -265,7 +267,8 @@ exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
|
||||
int num_bits, int group_size, bool has_act_order, bool is_k_full,
|
||||
bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
|
||||
bool has_zp, bool is_zp_float, int is_a_8bit, int stages,
|
||||
int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@ -280,13 +283,15 @@ exec_config_t determine_exec_config(
|
||||
|
||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, max_shared_mem - 512)) {
|
||||
is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
int cache_size = get_kernel_cache_size(th_config, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, is_a_8bit, stages);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
@ -297,14 +302,10 @@ exec_config_t determine_exec_config(
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
th_config.num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
// int m_tiles = div_ceil(prob_m, thread_m_blocks * 16);
|
||||
// int n_tiles = prob_n / th_config.thread_n;
|
||||
// int k_tiles = prob_k / th_config.thread_k;
|
||||
|
||||
return {1, th_config};
|
||||
}
|
||||
|
||||
@ -321,6 +322,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k_init,
|
||||
int thread_n_init, int sms, bool use_atomic_add,
|
||||
bool use_fp32_reduce, bool is_zp_float) {
|
||||
bool is_a_8bit = a_type.size_bits() == 8;
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
@ -389,8 +391,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
|
||||
"marlin kernel only support Turing or newer GPUs.");
|
||||
int stages = 4;
|
||||
if (major_capability == 7 && minor_capability == 5) {
|
||||
stages = 2;
|
||||
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
|
||||
"Turing only support FP16 or INT8 activation.");
|
||||
}
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(
|
||||
major_capability * 10 + minor_capability == 89 ||
|
||||
@ -431,7 +439,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
|
||||
thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages, max_shared_mem,
|
||||
sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
if (thread_tfg.thread_n != -1) {
|
||||
if (prob_n / thread_tfg.thread_n *
|
||||
@ -440,7 +449,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem_new)) {
|
||||
is_a_8bit, stages, max_shared_mem_new)) {
|
||||
thread_tfg = {128, 64, 128};
|
||||
exec_cfg = {1, thread_tfg};
|
||||
}
|
||||
@ -466,7 +475,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
TORCH_CHECK(
|
||||
is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n,
|
||||
prob_k, num_bits, group_size, has_act_order, is_k_full,
|
||||
has_zp, is_zp_float, max_shared_mem_new),
|
||||
has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem_new),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
@ -475,12 +485,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
", prob_m_split = ", prob_m_split, ", group_size = ", group_size,
|
||||
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem_new = ", max_shared_mem_new);
|
||||
", stages = ", stages, ", max_shared_mem_new = ", max_shared_mem_new);
|
||||
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
|
||||
@ -1,17 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#ifndef _marlin_cuh
|
||||
#define _marlin_cuh
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
@ -51,9 +53,51 @@ using I4 = Vec<int, 4>;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
// Variant used when __CUDA_ARCH__ < 800: performs the copy as an ordinary
// 32-bit load followed by a store instead of an asynchronous copy.
//
//   smem_ptr - destination; must be valid for a 4-byte write.
//   glob_ptr - source; must be valid for a 4-byte read.
//   pred     - when false the call does nothing.
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  if (!pred) return;
  const int32_t* src = static_cast<const int32_t*>(glob_ptr);
  int32_t* dst = static_cast<int32_t*>(smem_ptr);
  *dst = *src;
}
|
||||
|
||||
// Variant used when __CUDA_ARCH__ < 800: performs the copy as an ordinary
// 64-bit load followed by a store instead of an asynchronous copy.
//
//   smem_ptr - destination; must be valid for an 8-byte write.
//   glob_ptr - source; must be valid for an 8-byte read.
//   pred     - when false the call does nothing.
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  if (!pred) return;
  const int64_t* src = static_cast<const int64_t*>(glob_ptr);
  int64_t* dst = static_cast<int64_t*>(smem_ptr);
  *dst = *src;
}
|
||||
|
||||
// Variant used when __CUDA_ARCH__ < 800: copies one int4 (four 32-bit words)
// with an ordinary load/store instead of an asynchronous copy.
//
//   smem_ptr - destination; must be valid for an int4 write.
//   glob_ptr - source; must be valid for an int4 read.
//   pred     - when false the call does nothing.
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  if (!pred) return;
  const int4* src = static_cast<const int4*>(glob_ptr);
  int4* dst = static_cast<int4*>(smem_ptr);
  *dst = *src;
}
|
||||
|
||||
// Variant used when __CUDA_ARCH__ < 800: copies one int4 (four 32-bit words)
// with an ordinary load/store. In this branch the body is the same as
// cp_async4_ca_pred.
//
//   smem_ptr - destination; must be valid for an int4 write.
//   glob_ptr - source; must be valid for an int4 read.
//   pred     - when false the call does nothing.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  if (!pred) return;
  const int4* src = static_cast<const int4*>(glob_ptr);
  int4* dst = static_cast<int4*>(smem_ptr);
  *dst = *src;
}
|
||||
|
||||
// Variant used when __CUDA_ARCH__ < 800: unconditionally copies one int4
// (four 32-bit words) from glob_ptr to smem_ptr with an ordinary load/store.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int4 value = *static_cast<const int4*>(glob_ptr);
  *static_cast<int4*>(smem_ptr) = value;
}
|
||||
|
||||
// Variant used when __CUDA_ARCH__ < 800: intentionally empty. The copy
// helpers of this branch are plain load/store assignments, so this function
// only exists to keep call sites identical across architectures.
__device__ inline void cp_async_fence() {
}
|
||||
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {}
|
||||
|
||||
#else
|
||||
|
||||
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
@ -126,6 +170,8 @@ __device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
269
csrc/quantization/gptq_marlin/marlin_mma.h
Normal file
269
csrc/quantization/gptq_marlin/marlin_mma.h
Normal file
@ -0,0 +1,269 @@
|
||||
|
||||
#include "marlin_dtypes.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
|
||||
static_assert(!use_fp16_accum);
|
||||
}
|
||||
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
#else
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value &&
|
||||
use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
|
||||
#else
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[1]), "r"(b[0]), "r"(c[2]), "r"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[2]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[3]), "r"(b[1]), "r"(c[2]), "r"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
|
||||
static_assert(!use_fp16_accum);
|
||||
}
|
||||
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
#else
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value &&
|
||||
use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
|
||||
#else
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b2[1]), "r"(a[0]), "r"(c[2]), "r"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b2[1]), "r"(a[1]), "r"(c[2]), "r"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
@ -26,6 +26,7 @@
|
||||
#include "marlin.cuh"
|
||||
#include "marlin_dtypes.cuh"
|
||||
#include "dequant.h"
|
||||
#include "marlin_mma.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@ -35,7 +36,7 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
|
||||
@ -75,137 +76,6 @@ __global__ void Marlin(
|
||||
|
||||
#else
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
template <int count, vllm::ScalarTypeId type_id>
|
||||
@ -415,6 +285,17 @@ __global__ void Marlin(
|
||||
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
// Turing TensorCore only supports fp16 and int8
|
||||
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
|
||||
return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
using Adtype = MarlinScalarType<a_type_id>;
|
||||
using Cdtype = MarlinScalarType<c_type_id>;
|
||||
const int4* A = A0;
|
||||
@ -873,10 +754,6 @@ __global__ void Marlin(
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
// shared memory reused by reduction should be smaller than
|
||||
// shared memory used by weight.
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
@ -1395,11 +1272,13 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
} else {
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
|
||||
frag_c[i][j][0]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
|
||||
frag_c[i][j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1433,10 +1312,12 @@ __global__ void Marlin(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
}
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
@ -1956,6 +1837,21 @@ __global__ void Marlin(
|
||||
// While this pattern may not be the most readable, other ways of writing
|
||||
// the loop seemed to noticeably worsen performance after compilation.
|
||||
if (slice_iters == 0) {
|
||||
// convert fp16 accum to fp32 for reduction
|
||||
if constexpr (use_fp16_accum) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
|
||||
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
|
||||
scalar_t* frag_c_part_half =
|
||||
reinterpret_cast<scalar_t*>(frag_c_part_float);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_a_8bit) {
|
||||
float frag_a_s[2 * thread_m_blocks];
|
||||
|
||||
|
||||
@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
|
||||
int rowEnd = rowEnds[rowIdx];
|
||||
|
||||
// Local pointers to this block
|
||||
outIndices += rowIdx * topK;
|
||||
logits += rowIdx * stride0;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||
|
||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
|
||||
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
|
||||
@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
||||
|
||||
// Local pointers to this block
|
||||
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
|
||||
outIndices += rowIdx * topK;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
} else if constexpr (multipleBlocksPerRow) {
|
||||
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
|
||||
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
|
||||
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
|
||||
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||
outIndices +=
|
||||
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||
outLogits +=
|
||||
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||
} else if constexpr (mergeBlocks) {
|
||||
rowEnd = numBlocksToMerge * topK;
|
||||
indices += rowIdx * numBlocksToMerge * topK;
|
||||
outIndices += rowIdx * topK;
|
||||
indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
}
|
||||
logits += rowIdx * stride0;
|
||||
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||
|
||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
|
||||
multipleBlocksPerRow, mergeBlocks>(
|
||||
|
||||
@ -621,7 +621,7 @@ ENV UV_HTTP_TIMEOUT=500
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \
|
||||
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
|
||||
uv pip install --system -r /tmp/kv_connectors.txt; \
|
||||
uv pip install --system -r /tmp/kv_connectors.txt || true; \
|
||||
fi
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
@ -50,5 +50,5 @@ ijson # Required for mistral streaming tool parser
|
||||
setproctitle # Used to set process names for better debugging and monitoring
|
||||
openai-harmony >= 0.0.3 # Required for gpt-oss
|
||||
anthropic == 0.71.0
|
||||
model-hosting-container-standards >= 0.1.9, < 1.0.0
|
||||
mcp
|
||||
model-hosting-container-standards >= 0.1.10, < 1.0.0
|
||||
mcp
|
||||
|
||||
@ -523,6 +523,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
|
||||
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
# TODO: remove skip after we fix the fusion thoroughly
|
||||
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
|
||||
def test_rms_group_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
@ -562,7 +564,9 @@ def test_rms_group_quant(
|
||||
splitting_ops=splitting_ops,
|
||||
# Common
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True),
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
|
||||
),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
)
|
||||
|
||||
@ -244,3 +244,35 @@ async def test_audio_with_timestamp(mary_had_lamb, whisper_client):
|
||||
)
|
||||
assert transcription.segments is not None
|
||||
assert len(transcription.segments) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_with_max_tokens(whisper_client, mary_had_lamb):
    """Check that `max_completion_tokens` bounds the transcription length.

    Two requests are sent for the same audio: one capped at a single
    completion token, and one with a cap far above what the model can emit.
    """

    async def _transcribe(limit: int) -> str:
        # Issue one transcription request and return the decoded "text" field.
        raw = await whisper_client.audio.transcriptions.create(
            model=MODEL_NAME,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            temperature=0.0,
            extra_body={"max_completion_tokens": limit},
        )
        return json.loads(raw)["text"]

    capped_text = await _transcribe(1)

    # Imported lazily, after the first request, exactly as the test needs it.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def _num_tokens(text: str) -> int:
        # Re-tokenize the returned text without special tokens.
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    assert _num_tokens(capped_text) == 1

    # max_completion_tokens > max_model_len
    uncapped_text = await _transcribe(int(1e6))
    assert _num_tokens(uncapped_text) < 450  # ~Whisper max output len
|
||||
|
||||
@ -227,3 +227,36 @@ async def test_long_audio_request(foscolo, client_and_model):
|
||||
)
|
||||
out = json.loads(translation)["text"].strip().lower()
|
||||
assert out.count("greek sea") == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_with_max_tokens(mary_had_lamb, client_and_model):
    """Check that `max_completion_tokens` bounds the translation length.

    Both requests go through the translations endpoint: the first is capped
    at a single completion token, the second uses a cap far larger than the
    model can produce.
    """
    client, model_name = client_and_model
    translation = await client.audio.translations.create(
        model=model_name,
        file=mary_had_lamb,
        response_format="text",
        temperature=0.0,
        extra_body={"max_completion_tokens": 1},
    )
    out = json.loads(translation)
    out_text = out["text"]
    print(out_text)
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_name)
    out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
    assert len(out_tokens) == 1
    # max_completion_tokens > max_model_len
    # NOTE: this request previously hit the transcriptions endpoint, so the
    # oversized-limit case was never exercised for translations.
    translation = await client.audio.translations.create(
        model=model_name,
        file=mary_had_lamb,
        response_format="text",
        temperature=0.0,
        extra_body={"max_completion_tokens": int(1e6)},
    )
    out = json.loads(translation)
    out_text = out["text"]
    print(out_text)
    out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
    assert len(out_tokens) < 450  # ~Whisper max output len
|
||||
|
||||
@ -7,9 +7,8 @@ This directory contains a replacement for the lm-eval-harness GSM8K evaluation,
|
||||
### Run tests with pytest (like buildkite)
|
||||
|
||||
```bash
|
||||
pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt \
|
||||
--tp-size=1
|
||||
pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt
|
||||
```
|
||||
|
||||
### Run standalone evaluation script
|
||||
@ -31,5 +30,11 @@ model_name: "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
accuracy_threshold: 0.54 # Minimum expected accuracy
|
||||
num_questions: 1319 # Number of questions (default: full test set)
|
||||
num_fewshot: 5 # Few-shot examples from train set
|
||||
max_model_len: 4096 # Model context length
|
||||
server_args: "--max-model-len 4096 --tensor-parallel-size 2" # Server arguments
|
||||
env: # Environment variables (optional)
|
||||
VLLM_USE_FLASHINFER_MOE_FP4: "1"
|
||||
```
|
||||
|
||||
The `server_args` field accepts any arguments that can be passed to `vllm serve`.
|
||||
|
||||
The `env` field accepts a dictionary of environment variables to set for the server process.
|
||||
|
||||
@ -2,5 +2,4 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
|
||||
accuracy_threshold: 0.72
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,4 +2,4 @@ model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
|
||||
accuracy_threshold: 0.74
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,4 +2,4 @@ model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
|
||||
accuracy_threshold: 0.31
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,4 +2,4 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
|
||||
accuracy_threshold: 0.45
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,4 +2,4 @@ model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
|
||||
accuracy_threshold: 0.60
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-0.6B-FP8"
|
||||
accuracy_threshold: 0.375
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@ -2,5 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-FP4"
|
||||
accuracy_threshold: 0.89
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
12
tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
Normal file
12
tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
Normal file
@ -0,0 +1,12 @@
|
||||
model_name: "nm-testing/Qwen3-Next-80B-A3B-Instruct-NVFP4"
|
||||
accuracy_threshold: 0.75
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
server_args: >-
|
||||
--enforce-eager
|
||||
--max-model-len 4096
|
||||
--tensor-parallel-size 2
|
||||
--enable-expert-parallel
|
||||
--speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}'
|
||||
env:
|
||||
VLLM_USE_FLASHINFER_MOE_FP4: "1"
|
||||
@ -3,3 +3,4 @@ Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
|
||||
Qwen1.5-MoE-W4A16-CT.yaml
|
||||
DeepSeek-V2-Lite-Instruct-FP8.yaml
|
||||
Qwen3-30B-A3B-NVFP4.yaml
|
||||
Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
|
||||
|
||||
@ -11,14 +11,12 @@ def pytest_addoption(parser):
|
||||
default="configs/models-small.txt",
|
||||
help="File containing list of config files to test",
|
||||
)
|
||||
parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size")
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
"""Generate test parameters from config files."""
|
||||
if "config_filename" in metafunc.fixturenames:
|
||||
config_list_file = metafunc.config.getoption("--config-list-file")
|
||||
tp_size = metafunc.config.getoption("--tp-size")
|
||||
|
||||
# Handle both relative and absolute paths
|
||||
config_list_path = Path(config_list_file)
|
||||
@ -55,9 +53,9 @@ def pytest_generate_tests(metafunc):
|
||||
# Generate test parameters
|
||||
if config_files:
|
||||
metafunc.parametrize(
|
||||
["config_filename", "tp_size"],
|
||||
[(config_file, int(tp_size)) for config_file in config_files],
|
||||
ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files],
|
||||
"config_filename",
|
||||
config_files,
|
||||
ids=[config_file.stem for config_file in config_files],
|
||||
)
|
||||
else:
|
||||
print("No config files found, test will be skipped")
|
||||
|
||||
@ -5,30 +5,31 @@ GSM8K evaluation using vLLM server and isolated GSM8K script.
|
||||
Replacement for lm-eval-harness with better performance and control.
|
||||
|
||||
Usage:
|
||||
pytest -s -v test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt \
|
||||
--tp-size=1
|
||||
pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt
|
||||
"""
|
||||
|
||||
import shlex
|
||||
|
||||
import yaml
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
from .gsm8k_eval import evaluate_gsm8k
|
||||
|
||||
RTOL = 0.08 # Relative tolerance for accuracy comparison
|
||||
TOL = 0.08 # Absolute tolerance for accuracy comparison
|
||||
|
||||
|
||||
def launch_gsm8k_eval(eval_config, server_url, tp_size):
|
||||
"""Launch GSM8K evaluation using our isolated script."""
|
||||
def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
|
||||
"""Run GSM8K evaluation using our isolated script."""
|
||||
# Extract host and port from server URL
|
||||
if "://" in server_url:
|
||||
server_url = server_url.split("://")[1]
|
||||
|
||||
host_port = server_url.split("/")[0] # Remove path if present
|
||||
if ":" in host_port:
|
||||
host, port = host_port.split(":")
|
||||
port = int(port)
|
||||
host, p = host_port.split(":")
|
||||
port = int(p)
|
||||
else:
|
||||
host = host_port
|
||||
port = 8000
|
||||
@ -48,46 +49,57 @@ def launch_gsm8k_eval(eval_config, server_url, tp_size):
|
||||
return results
|
||||
|
||||
|
||||
def test_gsm8k_correctness_param(config_filename, tp_size):
|
||||
def test_gsm8k_correctness(config_filename):
|
||||
"""Test GSM8K correctness for a given model configuration."""
|
||||
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
|
||||
|
||||
# Server arguments
|
||||
server_args = [
|
||||
"--max-model-len",
|
||||
str(eval_config.get("max_model_len", 4096)),
|
||||
"--enforce-eager",
|
||||
"--trust-remote-code",
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
]
|
||||
# Parse server arguments from config (use shlex to handle quoted strings)
|
||||
server_args_str = eval_config.get("server_args", "")
|
||||
server_args = shlex.split(server_args_str) if server_args_str else []
|
||||
|
||||
# Add standard server arguments
|
||||
server_args.extend(
|
||||
[
|
||||
"--trust-remote-code",
|
||||
]
|
||||
)
|
||||
|
||||
env_dict = eval_config.get("env", None)
|
||||
|
||||
print(f"Starting GSM8K evaluation for model: {eval_config['model_name']}")
|
||||
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
|
||||
print(f"Number of questions: {eval_config['num_questions']}")
|
||||
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
|
||||
print(f"Server args: {' '.join(server_args)}")
|
||||
|
||||
# Launch server and run evaluation
|
||||
with RemoteOpenAIServer(
|
||||
eval_config["model_name"], server_args, env_dict=env_dict, max_wait_seconds=480
|
||||
eval_config["model_name"],
|
||||
server_args,
|
||||
env_dict=env_dict,
|
||||
max_wait_seconds=600,
|
||||
) as remote_server:
|
||||
server_url = remote_server.url_for("v1")
|
||||
print(f"Server started at: {server_url}")
|
||||
|
||||
results = launch_gsm8k_eval(eval_config, server_url, tp_size)
|
||||
results = run_gsm8k_eval(eval_config, server_url)
|
||||
|
||||
# Check accuracy against threshold
|
||||
measured_accuracy = results["accuracy"]
|
||||
expected_accuracy = eval_config["accuracy_threshold"]
|
||||
measured_metric = results["accuracy"]
|
||||
expected_metric = eval_config["accuracy_threshold"]
|
||||
|
||||
print(f"GSM8K Results for {eval_config['model_name']}:")
|
||||
print(f" Accuracy: {measured_accuracy:.3f}")
|
||||
print(f" Expected: {expected_accuracy:.3f}")
|
||||
print(f" Measured metric: {measured_metric:.4f}")
|
||||
print(f" Expected metric: {expected_metric:.4f}")
|
||||
print(f" Tolerance: {TOL:.4f}")
|
||||
print(f" Questions: {results['num_questions']}")
|
||||
print(f" Invalid rate: {results['invalid_rate']:.3f}")
|
||||
print(f" Latency: {results['latency']:.1f}s")
|
||||
print(f" QPS: {results['questions_per_second']:.1f}")
|
||||
|
||||
# Verify accuracy is within tolerance
|
||||
assert measured_accuracy >= expected_accuracy - RTOL, (
|
||||
f"Accuracy too low: {measured_accuracy:.3f} < "
|
||||
f"{expected_accuracy:.3f} - {RTOL:.3f}"
|
||||
# Verify metric is within tolerance
|
||||
assert measured_metric >= expected_metric - TOL, (
|
||||
f"GSM8K metric too low: {measured_metric:.4f} < "
|
||||
f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}"
|
||||
)
|
||||
|
||||
print(f"✅ GSM8K test passed for {eval_config['model_name']}")
|
||||
|
||||
@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
total_num_patches.item() + num_tiles.item() + 3
|
||||
) # image start, image, image end
|
||||
|
||||
profiled_tokens = profiler.get_mm_max_contiguous_tokens(
|
||||
profiled_tokens = profiler.get_mm_max_tokens(
|
||||
max_model_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
|
||||
assert total_tokens == profiled_tokens["image"]
|
||||
assert total_num_patches == profiled_tokens["image"]
|
||||
assert total_tokens == sum(
|
||||
placeholder.length
|
||||
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
|
||||
|
||||
@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
@ -410,6 +411,97 @@ def test_argsort_mm_positions(case):
|
||||
assert modality_idxs == expected_modality_idxs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("is_embed", "expected"),
    [
        (None, 5),
        (torch.tensor([True, True, True, True, True]), 5),
        (torch.tensor([False, False, False, False, False]), 0),
        (torch.tensor([True, False, True, False, True]), 3),
        (torch.tensor([True]), 1),
    ],
)
def test_placeholder_range_get_num_embeds(is_embed, expected):
    """`get_num_embeds` matches the expected count for each mask.

    With no mask the range is built with a length of 5; otherwise the length
    is taken from the mask itself.
    """
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0, length=num_positions, is_embed=is_embed
    )
    assert placeholder.get_num_embeds == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("is_embed", "expected"),
    [
        (None, None),
        (
            torch.tensor([False, True, False, True, True]),
            torch.tensor([0, 1, 1, 2, 3]),
        ),
        (torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
    ],
)
def test_placeholder_range_embeds_cumsum(is_embed, expected):
    """`embeds_cumsum` is None without a mask, else equals the expected tensor."""
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0, length=num_positions, is_embed=is_embed
    )

    if expected is None:
        assert placeholder.embeds_cumsum is None
    else:
        assert torch.equal(placeholder.embeds_cumsum, expected)
        # cached_property should return the same object on repeated access
        assert placeholder.embeds_cumsum is placeholder.embeds_cumsum
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("is_embed", "start_idx", "end_idx", "expected"),
    [
        (None, 2, 4, (2, 4)),
        (torch.tensor([False, True, False, True, True]), 3, 5, (1, 3)),
        (torch.tensor([False, True, False, True, True]), 0, 2, (0, 1)),
        (torch.tensor([True, False, True, False]), 2, 2, (1, 1)),
    ],
)
def test_placeholder_range_get_embeds_indices_in_range(
    is_embed, start_idx, end_idx, expected
):
    """`get_embeds_indices_in_range` returns the expected pair per window.

    Without a mask the expected pair equals the requested window; with a mask
    the expected pair is given in terms of embed counts.
    """
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0, length=num_positions, is_embed=is_embed
    )
    actual = placeholder.get_embeds_indices_in_range(start_idx, end_idx)
    assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("offset", "is_embed", "expected"),
    [
        (0, None, [(0, 4)]),
        (2, torch.tensor([False, True, False, True, True]), [(3, 3), (5, 6)]),
        (0, torch.tensor([True, True, True, True]), [(0, 3)]),
        (0, torch.tensor([False, False, False, False]), []),
    ],
)
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
    """`extract_embeds_range` yields the expected list of (start, end) pairs.

    Per the expected values in these cases, the pairs are inclusive and are
    shifted by `offset`; an all-False mask yields an empty list.
    """
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=offset, length=num_positions, is_embed=is_embed
    )
    assert placeholder.extract_embeds_range() == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||
|
||||
@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
|
||||
)
|
||||
|
||||
# Call the function
|
||||
result = make_local_attention_virtual_batches(
|
||||
result, _ = make_local_attention_virtual_batches(
|
||||
attn_chunk_size, common_attn_metadata, block_size
|
||||
)
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
|
||||
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
||||
@ -23,7 +24,7 @@ class MockRequest:
|
||||
)
|
||||
self.mm_features.append(feature)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
return self._token_counts[input_id]
|
||||
|
||||
|
||||
@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit():
|
||||
|
||||
num_tokens_to_schedule = 0
|
||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
||||
compute_budget -= req.get_num_encoder_tokens(0)
|
||||
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||
compute_budget -= req.get_num_encoder_embeds(0)
|
||||
|
||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||
|
||||
@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit():
|
||||
compute_budget = 10
|
||||
num_tokens_to_schedule = 0
|
||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
||||
compute_budget -= req.get_num_encoder_tokens(0)
|
||||
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||
compute_budget -= req.get_num_encoder_embeds(0)
|
||||
|
||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||
|
||||
|
||||
def test_encoder_cache_with_is_embed_mask():
    """Allocating a masked feature consumes only as many slots as embeds.

    A 100-position placeholder with 8 embed positions is allocated into a
    100-slot cache; 92 slots must remain free afterwards.
    """

    class MockRequestWithMask(MockRequest):
        # Report the embed count from the placeholder's mask instead of the
        # static token counts held by the base mock.
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    embed_positions = torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])
    is_embed = torch.zeros(100, dtype=torch.bool)
    is_embed[embed_positions] = True

    masked_position = PlaceholderRange(offset=0, length=100, is_embed=is_embed)
    request = MockRequestWithMask("r1", ["img1"], [100])
    request.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=masked_position,
    )

    manager = EncoderCacheManager(cache_size=100)
    manager.allocate(request, 0)

    assert manager.num_free_slots == 92
    assert "img1" in manager.cached

    full_length = 100
    embed_count = request.mm_features[0].mm_position.get_num_embeds
    assert embed_count == 8
    # 100 / 8 is exactly representable, so float equality is safe here.
    savings_ratio = full_length / embed_count
    assert savings_ratio == 12.5
|
||||
|
||||
|
||||
def test_encoder_cache_mask_based_retrieval():
    """Embed counts derived from the mask line up with the allocated feature.

    After allocating a 10-position placeholder with 5 embed positions, the
    mask is sliced directly to count embeds before and inside two windows.
    """

    class MockRequestWithMask(MockRequest):
        # Report the embed count from the placeholder's mask instead of the
        # static token counts held by the base mock.
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    is_embed = torch.tensor(
        [False, False, True, True, False, True, True, True, False, False]
    )

    masked_position = PlaceholderRange(offset=0, length=10, is_embed=is_embed)
    request = MockRequestWithMask("r1", ["img1"], [10])
    request.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=masked_position,
    )

    manager = EncoderCacheManager(cache_size=50)
    manager.allocate(request, 0)

    assert request.mm_features[0].mm_position.get_num_embeds == 5

    # Window [2, 8): no embeds precede it, all five fall inside it.
    lo, hi = 2, 8
    embeds_before = is_embed[:lo].sum().item()
    embeds_inside = is_embed[lo:hi].sum().item()

    assert embeds_before == 0
    assert embeds_inside == 5

    # Window [0, 5): nothing can precede index 0; two embeds fall inside.
    lo, hi = 0, 5
    embeds_before = is_embed[:lo].sum().item() if lo > 0 else 0
    embeds_inside = is_embed[lo:hi].sum().item()

    assert embeds_before == 0
    assert embeds_inside == 2
|
||||
|
||||
@ -188,7 +188,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
# enable_prefix_caching=False,
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16", # not everything is supported
|
||||
|
||||
@ -38,7 +38,7 @@ class MockRequest:
|
||||
)
|
||||
self.mm_features.append(feature)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
assert input_id < len(self._token_counts)
|
||||
return self._token_counts[input_id]
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import functools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
@ -51,11 +51,19 @@ def create_chunked_local_attention_backend(
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> AttentionMetadata:
|
||||
common_attn_metadata = make_local_attention_virtual_batches(
|
||||
):
|
||||
cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
|
||||
attention_chunk_size, common_attn_metadata, block_size
|
||||
)
|
||||
return super().build(common_prefix_len, common_attn_metadata, fast_build)
|
||||
metadata = super().build(common_prefix_len, cm, fast_build)
|
||||
metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
|
||||
return metadata
|
||||
|
||||
def update_block_table(
    self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor
):
    """Convert the block table via the metadata's callable, then delegate.

    `metadata.make_virtual_batches_block_table` is applied to `blk_table`
    and the result is forwarded to the parent implementation together with
    the unchanged `metadata` and `slot_mapping`.
    """
    # The conversion callable is expected to already be attached to metadata.
    virtual_blk_table = metadata.make_virtual_batches_block_table(blk_table)
    return super().update_block_table(metadata, virtual_blk_table, slot_mapping)
|
||||
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
|
||||
@ -10,6 +10,7 @@ from vllm.attention.ops.vit_attn_wrappers import (
|
||||
vit_flash_attn_wrapper,
|
||||
vit_torch_sdpa_wrapper,
|
||||
)
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
@ -101,6 +102,10 @@ class MMEncoderAttention(CustomOp):
|
||||
self.attn_backend,
|
||||
)
|
||||
|
||||
if self.is_flash_attn_backend:
|
||||
assert self.flash_attn_varlen_func is not None
|
||||
self._fa_version = get_flash_attn_version()
|
||||
|
||||
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
|
||||
|
||||
@classmethod
|
||||
@ -204,6 +209,7 @@ class MMEncoderAttention(CustomOp):
|
||||
max_seqlen=max_seqlen,
|
||||
batch_size=bsz,
|
||||
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
|
||||
fa_version=self._fa_version,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ import einops
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
@ -27,11 +28,15 @@ def flash_attn_maxseqlen_wrapper(
|
||||
max_seqlen: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
fa_version: int,
|
||||
) -> torch.Tensor:
|
||||
kwargs = {}
|
||||
if is_rocm_aiter:
|
||||
from aiter import flash_attn_varlen_func
|
||||
else:
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
|
||||
kwargs["fa_version"] = fa_version
|
||||
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||
output = flash_attn_varlen_func(
|
||||
q,
|
||||
@ -43,6 +48,7 @@ def flash_attn_maxseqlen_wrapper(
|
||||
max_seqlen_k=max_seqlen.item(),
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
|
||||
return context_layer
|
||||
@ -56,6 +62,7 @@ def flash_attn_maxseqlen_wrapper_fake(
|
||||
max_seqlen: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
fa_version: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q)
|
||||
|
||||
@ -75,9 +82,10 @@ def vit_flash_attn_wrapper(
|
||||
max_seqlen: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
fa_version: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
|
||||
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
|
||||
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version
|
||||
)
|
||||
|
||||
|
||||
@ -89,6 +97,13 @@ def torch_sdpa_wrapper(
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Never remove the contiguous logic for ROCm
|
||||
# Without it, hallucinations occur with the backend
|
||||
if current_platform.is_rocm():
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
|
||||
outputs = []
|
||||
|
||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
|
||||
@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if extra_tensors is not None:
|
||||
raise NotImplementedError(
|
||||
"extra_tensors is not supported for NaiveAll2AllManager"
|
||||
)
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
router_logits = self.naive_multicast(
|
||||
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(
|
||||
@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
assert dp_metadata is not None
|
||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
assert sizes is not None
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
hidden_states, router_logits = dist_group.all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
|
||||
tensors_to_gather = [hidden_states, router_logits]
|
||||
if extra_tensors is not None:
|
||||
tensors_to_gather.extend(extra_tensors)
|
||||
|
||||
gathered_tensors = dist_group.all_gatherv(
|
||||
tensors_to_gather,
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
if extra_tensors is not None:
|
||||
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
|
||||
return gathered_tensors[0], gathered_tensors[1]
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
@ -68,7 +69,11 @@ class All2AllManagerBase:
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
):
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> Any:
|
||||
# Subclasses should either:
|
||||
# - implement handling for extra_tensors, or
|
||||
# - raise a clear error if extra_tensors is not supported.
|
||||
raise NotImplementedError
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
|
||||
@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
return output_list
|
||||
|
||||
def dispatch(
|
||||
def dispatch( # type: ignore[override]
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits, is_sequence_parallel
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors, # type: ignore[call-arg]
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
|
||||
@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
|
||||
Update ECConnector state after encoder cache allocation.
|
||||
"""
|
||||
mm_hash = request.mm_features[index].identifier
|
||||
num_encoder_token = request.get_num_encoder_tokens(index)
|
||||
num_encoder_token = request.get_num_encoder_embeds(index)
|
||||
# Insert mm_hash only if this block has not been recorded yet.
|
||||
self._mm_datas_need_loads[mm_hash] = num_encoder_token
|
||||
|
||||
|
||||
@ -1007,10 +1007,17 @@ class GroupCoordinator:
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.dispatch(
|
||||
hidden_states, router_logits, is_sequence_parallel
|
||||
return self.device_communicator.dispatch( # type: ignore[call-arg]
|
||||
hidden_states,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors,
|
||||
)
|
||||
else:
|
||||
return hidden_states, router_logits
|
||||
|
||||
@ -2054,6 +2054,9 @@ class TranscriptionRequest(OpenAIBaseModel):
|
||||
|
||||
presence_penalty: float | None = 0.0
|
||||
"""The presence penalty to use for sampling."""
|
||||
|
||||
max_completion_tokens: int | None = None
|
||||
"""The maximum number of tokens to generate."""
|
||||
# --8<-- [end:transcription-sampling-params]
|
||||
|
||||
# Default sampling parameters for transcription requests.
|
||||
@ -2300,6 +2303,9 @@ class TranslationRequest(OpenAIBaseModel):
|
||||
# Flattened stream option to simplify form data.
|
||||
stream_include_usage: bool | None = False
|
||||
stream_continuous_usage_stats: bool | None = False
|
||||
|
||||
max_completion_tokens: int | None = None
|
||||
"""The maximum number of tokens to generate."""
|
||||
# --8<-- [end:translation-extra-params]
|
||||
|
||||
# Default sampling parameters for translation requests.
|
||||
|
||||
@ -293,8 +293,14 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
try:
|
||||
# Unlike most decoder-only models, whisper generation length is not
|
||||
# constrained by the size of the input audio, which is mapped to a
|
||||
# fixed-size log-mel-spectrogram.
|
||||
default_max_tokens = self.model_config.max_model_len
|
||||
# fixed-size log-mel-spectrogram. Still, allow for fewer tokens to be
|
||||
# generated by respecting the extra completion tokens arg.
|
||||
if request.max_completion_tokens is None:
|
||||
default_max_tokens = self.model_config.max_model_len
|
||||
else:
|
||||
default_max_tokens = min(
|
||||
self.model_config.max_model_len, request.max_completion_tokens
|
||||
)
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens, self.default_sampling_params
|
||||
)
|
||||
|
||||
@ -207,7 +207,7 @@ if TYPE_CHECKING:
|
||||
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
|
||||
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
|
||||
VLLM_LOOPBACK_IP: str = ""
|
||||
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
|
||||
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True
|
||||
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
|
||||
VLLM_USE_TRTLLM_ATTENTION: str | None = None
|
||||
VLLM_NVFP4_GEMM_BACKEND: str | None = None
|
||||
@ -1430,7 +1430,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# kv-cache memory usage and enable longer contexts)
|
||||
# TODO(lucas): Remove this flag once latency regression is resolved.
|
||||
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool(
|
||||
int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0"))
|
||||
int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "1"))
|
||||
),
|
||||
# Enables support for the "store" option in the OpenAI Responses API.
|
||||
# When set to 1, vLLM's OpenAI server will retain the input and output
|
||||
|
||||
@ -14,11 +14,6 @@ class LoRARequest(
|
||||
"""
|
||||
Request for a LoRA adapter.
|
||||
|
||||
Note that this class should be used internally. For online
|
||||
serving, it is recommended to not allow users to use this class but
|
||||
instead provide another layer of abstraction to prevent users from
|
||||
accessing unauthorized LoRA adapters.
|
||||
|
||||
lora_int_id must be globally unique for a given adapter.
|
||||
This is currently not enforced in vLLM.
|
||||
"""
|
||||
|
||||
@ -71,6 +71,18 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
"implementation based on the prepare_finalize"
|
||||
)
|
||||
|
||||
def prepare_dp_allgather_tensor(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
|
||||
raise NotImplementedError(
|
||||
"Method 'prepare_dp_allgather_tensor' is not implemented in "
|
||||
f"{self.__class__.__name__}."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
|
||||
@ -44,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
is_flashinfer_supporting_global_sf,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.utils.torch_utils import (
|
||||
aux_stream,
|
||||
@ -1933,10 +1934,46 @@ class FusedMoE(CustomOp):
|
||||
)
|
||||
|
||||
with sp_ctx:
|
||||
extra_tensors = None
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states_combined, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits, self.is_sequence_parallel
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptNvFp4FusedMoE,
|
||||
)
|
||||
|
||||
post_quant_allgather = (
|
||||
has_flashinfer_trtllm_fused_moe()
|
||||
and self.quant_method is not None
|
||||
and self.dp_size > 1
|
||||
and self.use_ep
|
||||
and isinstance(self.quant_method, ModelOptNvFp4FusedMoE)
|
||||
)
|
||||
if post_quant_allgather:
|
||||
hidden_states_to_dispatch, extra_tensors = (
|
||||
self.quant_method.prepare_dp_allgather_tensor(
|
||||
self, hidden_states, router_logits
|
||||
)
|
||||
)
|
||||
else:
|
||||
hidden_states_to_dispatch = hidden_states
|
||||
|
||||
dispatch_res = get_ep_group().dispatch(
|
||||
hidden_states_to_dispatch,
|
||||
router_logits,
|
||||
self.is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
if extra_tensors is not None:
|
||||
hidden_states_combined, router_logits, extra_tensors_combined = (
|
||||
dispatch_res
|
||||
)
|
||||
hidden_states_combined = (
|
||||
hidden_states_combined,
|
||||
extra_tensors_combined[0],
|
||||
)
|
||||
else:
|
||||
hidden_states_combined, router_logits = dispatch_res
|
||||
|
||||
# Run shared experts before matrix multiply.
|
||||
# because the matrix multiply may modify hidden_states.
|
||||
if has_separate_shared_experts and not use_shared_experts_stream:
|
||||
|
||||
@ -121,7 +121,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
|
||||
@ -626,17 +626,11 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
|
||||
# only (no EP).
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||
|
||||
assert layer.expert_map is None, (
|
||||
"Expert Parallelism / expert_map "
|
||||
"is currently not supported for "
|
||||
"CompressedTensorsW4A4Nvfp4MoEMethod."
|
||||
)
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
# Cutlass moe takes in activations in BF16/Half precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
return cutlass_moe_fp4(
|
||||
a=x,
|
||||
w1_fp4=layer.w13_weight,
|
||||
@ -644,6 +638,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
# TODO(bnell): derive these from arguments
|
||||
m=x.shape[0],
|
||||
|
||||
@ -253,7 +253,7 @@ class Fp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
|
||||
@ -181,7 +181,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
|
||||
@ -871,7 +871,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
@ -1522,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
w2_blockscale_swizzled, requires_grad=False
|
||||
)
|
||||
|
||||
def prepare_dp_allgather_tensor(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
|
||||
import flashinfer
|
||||
|
||||
a1_gscale = layer.w13_input_scale_quant
|
||||
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
|
||||
hidden_states,
|
||||
a1_gscale,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
extra_tensors: list[torch.Tensor] = [hidden_states_sf]
|
||||
return hidden_states_fp4, extra_tensors
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
@ -1576,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
)
|
||||
|
||||
# Hidden_states in select_experts is only used to extract metadata
|
||||
if isinstance(x, tuple):
|
||||
x_routing, _ = x
|
||||
else:
|
||||
x_routing = x
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
hidden_states=x,
|
||||
hidden_states=x_routing,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
|
||||
@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
|
||||
|
||||
def flashinfer_trtllm_fp4_moe(
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe(
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
|
||||
# Quantize input to FP4
|
||||
a1_gscale = layer.w13_input_scale_quant
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
||||
x,
|
||||
a1_gscale,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
if isinstance(x, tuple):
|
||||
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
|
||||
else:
|
||||
# x is not yet quantized here; quantize it to FP4
|
||||
a1_gscale = layer.w13_input_scale_quant
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
||||
x,
|
||||
a1_gscale,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
|
||||
# Determine routing method type
|
||||
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
|
||||
@ -360,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe(
|
||||
torch.bfloat16
|
||||
).view(torch.int16)
|
||||
|
||||
# Quantize input to FP4
|
||||
a1_gscale = layer.w13_input_scale_quant
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
||||
x,
|
||||
a1_gscale,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
if isinstance(x, tuple):
|
||||
# hidden_states is already quantized
|
||||
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
|
||||
else:
|
||||
# Quantize input to FP4
|
||||
a1_gscale = layer.w13_input_scale_quant
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
||||
x,
|
||||
a1_gscale,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
|
||||
# Call TRT-LLM FP4 block-scale MoE kernel
|
||||
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
|
||||
|
||||
@ -169,10 +169,13 @@ class DeciLMDecoderLayer(nn.Module):
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
if not self._is_no_op_ffn:
|
||||
ffn_mult = block_config.ffn.ffn_mult
|
||||
intermediate_size = _ffn_mult_to_intermediate_size(
|
||||
ffn_mult, config.hidden_size
|
||||
)
|
||||
if hasattr(block_config.ffn, "ffn_mult"):
|
||||
ffn_mult = block_config.ffn.ffn_mult
|
||||
intermediate_size = _ffn_mult_to_intermediate_size(
|
||||
ffn_mult, config.hidden_size
|
||||
)
|
||||
else:
|
||||
intermediate_size = block_config.ffn.intermediate_size
|
||||
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
|
||||
@ -713,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
video_soft_tokens = self.get_num_video_tokens(
|
||||
num_video_soft_tokens = self.get_num_video_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
|
||||
image_processor=None,
|
||||
)
|
||||
|
||||
# NOTE: By default in Qwen3-VL, one video token is converted to
|
||||
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
|
||||
formatted_video_soft_tokens = video_soft_tokens * 12.5
|
||||
return int(formatted_video_soft_tokens)
|
||||
return num_video_soft_tokens
|
||||
|
||||
def _calculate_timestamps(
|
||||
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
|
||||
|
||||
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from functools import cached_property, partial
|
||||
from itertools import accumulate
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -169,11 +169,42 @@ class PlaceholderRange:
|
||||
between `offset` and `offset + length` to assign embeddings to.
|
||||
"""
|
||||
|
||||
def get_num_embeds(self) -> int:
|
||||
@cached_property
|
||||
def embeds_cumsum(self) -> torch.Tensor | None:
|
||||
if self.is_embed is None:
|
||||
return None
|
||||
|
||||
return self.is_embed.cumsum(dim=0)
|
||||
|
||||
@cached_property
|
||||
def get_num_embeds(self) -> int:
|
||||
if self.embeds_cumsum is None:
|
||||
return self.length
|
||||
|
||||
return int(self.is_embed.sum().item())
|
||||
return int(self.embeds_cumsum[-1])
|
||||
|
||||
def get_embeds_indices_in_range(
|
||||
self, start_idx: int, end_idx: int
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the starting and ending indices of the embeddings of encoder outputs
|
||||
in the range of [start_idx, end_idx) in the placeholders.
|
||||
|
||||
For example, given:
|
||||
PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])
|
||||
|
||||
If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
|
||||
the second and the third embeddings from the encoder output.
|
||||
"""
|
||||
if self.embeds_cumsum is None:
|
||||
return start_idx, end_idx
|
||||
|
||||
embeds_start_idx = (
|
||||
int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
|
||||
)
|
||||
embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])
|
||||
|
||||
return embeds_start_idx, embeds_end_idx
|
||||
|
||||
def extract_embeds_range(self) -> list[tuple[int, int]]:
|
||||
"""Extract the start and end indices of the embedded region in prompt.
|
||||
@ -188,7 +219,7 @@ class PlaceholderRange:
|
||||
Returns full placeholder range if `is_embed` is `None`.
|
||||
"""
|
||||
if self.is_embed is None:
|
||||
return [(self.offset, self.offset + self.length)]
|
||||
return [(self.offset, self.offset + self.length - 1)]
|
||||
|
||||
mask_i = self.is_embed.int()
|
||||
starts = torch.nonzero(
|
||||
|
||||
@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]):
|
||||
def _get_mm_num_tokens(
|
||||
self,
|
||||
mm_inputs: MultiModalInputs,
|
||||
mm_embeddings_only: bool = True,
|
||||
) -> Mapping[str, int]:
|
||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||
|
||||
return {
|
||||
modality: sum(
|
||||
item.get_num_embeds() if mm_embeddings_only else item.length
|
||||
for item in placeholders
|
||||
)
|
||||
modality: sum(item.get_num_embeds for item in placeholders)
|
||||
for modality, placeholders in placeholders_by_modality.items()
|
||||
}
|
||||
|
||||
@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]):
|
||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||
)
|
||||
|
||||
def _get_mm_max_tokens(
|
||||
def get_mm_max_tokens(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_embeddings_only: bool = True,
|
||||
) -> Mapping[str, int]:
|
||||
"""
|
||||
Returns the maximum number of embeddings per item of each modality, excluding
|
||||
any break/text tokens in-between multimodal embeddings/encoder outputs.
|
||||
"""
|
||||
if mm_counts is None:
|
||||
mm_counts = self.get_mm_limits()
|
||||
|
||||
@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]):
|
||||
}
|
||||
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
|
||||
|
||||
def get_mm_max_contiguous_tokens(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
) -> Mapping[str, int]:
|
||||
"""
|
||||
Returns the maximum length of the multimodal (image placeholders+text)
|
||||
tokens, including any break/text tokens in-between image embeddings.
|
||||
|
||||
`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
|
||||
Returns 9, even when the number of image embeddings is 6.
|
||||
|
||||
This is important to take into account when profiling and
|
||||
initializing the encoder cache size.
|
||||
"""
|
||||
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
|
||||
return self._get_mm_num_tokens(mm_inputs)
|
||||
|
||||
@ -164,7 +164,7 @@ class MultiModalRegistry:
|
||||
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
|
||||
)
|
||||
|
||||
return profiler.get_mm_max_contiguous_tokens(
|
||||
return profiler.get_mm_max_tokens(
|
||||
seq_len,
|
||||
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
|
||||
)
|
||||
|
||||
@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool:
|
||||
)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def has_flashinfer_trtllm_fused_moe() -> bool:
|
||||
"""Return `True` if FlashInfer TRTLLM fused MoE is available."""
|
||||
if not has_flashinfer_moe():
|
||||
return False
|
||||
required_functions = [
|
||||
("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"),
|
||||
("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
|
||||
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
|
||||
]
|
||||
for module_name, attr_name in required_functions:
|
||||
mod = _get_submodule(module_name)
|
||||
if not mod or not hasattr(mod, attr_name):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@functools.cache
|
||||
def has_flashinfer_cutlass_fused_moe() -> bool:
|
||||
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashAttention."""
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
if get_flash_attn_version() == 3
|
||||
else AttentionCGSupport.UNIFORM_BATCH
|
||||
)
|
||||
supports_update_block_table: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def update_block_table(
|
||||
self,
|
||||
metadata: FlashAttentionMetadata,
|
||||
blk_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> FlashAttentionMetadata:
|
||||
new_metadata = copy.copy(metadata)
|
||||
new_metadata.block_table = blk_table
|
||||
new_metadata.slot_mapping = slot_mapping
|
||||
return new_metadata
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return use_cascade_attention(*args, **kwargs)
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
|
||||
@ -134,6 +135,8 @@ class Mamba2AttentionMetadata:
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
|
||||
):
|
||||
supports_update_block_table: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder(
|
||||
num_computed_tokens_p=num_computed_tokens_p,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def update_block_table(
|
||||
self,
|
||||
metadata: Mamba2AttentionMetadata,
|
||||
blk_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> Mamba2AttentionMetadata:
|
||||
new_metadata = copy.copy(metadata)
|
||||
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
|
||||
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
|
||||
num_reqs = blk_table.shape[0]
|
||||
|
||||
# For CUDA graphs, copy to persistent buffer
|
||||
if (
|
||||
metadata.num_prefills == 0
|
||||
and num_reqs <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
|
||||
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
|
||||
state_indices_t = persistent_state_indices_t
|
||||
|
||||
new_metadata.state_indices_tensor = state_indices_t
|
||||
return new_metadata
|
||||
|
||||
@ -4,6 +4,7 @@ import abc
|
||||
import enum
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field, fields, make_dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -317,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
# length that will be pulled into the front of the batch.
|
||||
reorder_batch_threshold: int | None = None
|
||||
# Does this backend/builder support updating the block table in existing
|
||||
# metadata
|
||||
supports_update_block_table: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
@ -387,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_block_table(
|
||||
self,
|
||||
metadata: M,
|
||||
blk_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> M:
|
||||
"""
|
||||
Update the block table for the attention metadata.
|
||||
Faster when there are multiple kv-cache groups that create virtually the
|
||||
same metadata but just with different block tables.
|
||||
|
||||
Only needs to be implemented if supports_update_block_table is True.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> M:
|
||||
@ -603,7 +622,7 @@ def make_local_attention_virtual_batches(
|
||||
attn_chunk_size: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
block_size: int = 0,
|
||||
) -> CommonAttentionMetadata:
|
||||
) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
|
||||
query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
|
||||
seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
@ -715,9 +734,12 @@ def make_local_attention_virtual_batches(
|
||||
# tensor first, which recovers perf.
|
||||
batch_indices_torch = torch.from_numpy(batch_indices)
|
||||
block_indices_torch = torch.from_numpy(block_indices)
|
||||
block_table_local = block_table[batch_indices_torch, block_indices_torch].view(
|
||||
virtual_batches, -1
|
||||
)
|
||||
|
||||
# Save as a lambda so we can return this for update_block_table
|
||||
make_block_table = lambda block_table: block_table[
|
||||
batch_indices_torch, block_indices_torch
|
||||
].view(virtual_batches, -1)
|
||||
block_table_local = make_block_table(block_table)
|
||||
|
||||
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
|
||||
seq_lens_cpu = torch.from_numpy(seqlens_k_local)
|
||||
@ -736,7 +758,7 @@ def make_local_attention_virtual_batches(
|
||||
causal=True,
|
||||
_seq_lens_cpu=seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
|
||||
)
|
||||
), make_block_table
|
||||
|
||||
|
||||
def make_kv_sharing_fast_prefill_common_attn_metadata(
|
||||
|
||||
@ -39,20 +39,26 @@ class EncoderCacheManager:
|
||||
space for new embeddings.
|
||||
Oldest cached embeddings with no request referenced will be first evicted.
|
||||
|
||||
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
|
||||
instead of encoder tokens (i.e. all tokens that represent the multimodal data
|
||||
in the input sequence). This means all break/text tokens in-between multimodal
|
||||
embeddings are not considered with respect to the cache size and the number
|
||||
of free slots.
|
||||
|
||||
Args:
|
||||
cache_size: Limit the size of the cache, measured by the number of
|
||||
tokens from the input sequence.
|
||||
encoder embeddings from the input sequence.
|
||||
|
||||
Attributes:
|
||||
cache_size: Total cache capacity in encoder tokens.
|
||||
num_free_slots: Current available cache capacity in encoder tokens.
|
||||
cache_size: Total cache capacity in encoder embeddings.
|
||||
num_free_slots: Current available cache capacity in encoder embeddings.
|
||||
num_freeable_slots: Capacity that can be immediately reclaimed by
|
||||
evicting entries with zero references (in encoder tokens).
|
||||
evicting entries with zero references (in encoder embeddings).
|
||||
cached: Mapping from mm_hash to a set of request IDs that currently
|
||||
reference the cached entry. If the set is empty, the entry exists
|
||||
but is not referenced by any request and is eligible for
|
||||
reclamation.
|
||||
freeable: List of tuples (mm_hash, num_tokens) representing entries
|
||||
freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
|
||||
that no currently running request needs and that can be freed to
|
||||
make space when needed.
|
||||
freed: List of mm_hash strings that were actually evicted since the
|
||||
@ -67,7 +73,7 @@ class EncoderCacheManager:
|
||||
# mm_hash of mm_data => ids of requests that reference the mm_data
|
||||
self.cached: dict[str, set[str]] = {}
|
||||
|
||||
# mm_hash of mm_data => num_encoder_tokens of the mm_data
|
||||
# mm_hash of mm_data => num_encoder_embeds of the mm_data
|
||||
self.freeable: OrderedDict[str, int] = OrderedDict()
|
||||
self.freed: list[str] = []
|
||||
|
||||
@ -93,8 +99,8 @@ class EncoderCacheManager:
|
||||
|
||||
# Cached but currently not referenced by any request
|
||||
if not self.cached[mm_hash]:
|
||||
num_tokens = self.freeable.pop(mm_hash)
|
||||
self.num_freeable_slots -= num_tokens
|
||||
num_encoder_embeds = self.freeable.pop(mm_hash)
|
||||
self.num_freeable_slots -= num_encoder_embeds
|
||||
|
||||
self.cached[mm_hash].add(request.request_id)
|
||||
return True
|
||||
@ -104,7 +110,7 @@ class EncoderCacheManager:
|
||||
request: Request,
|
||||
input_id: int,
|
||||
encoder_compute_budget: int,
|
||||
num_tokens_to_schedule: int,
|
||||
num_embeds_to_schedule: int,
|
||||
) -> bool:
|
||||
"""Check if there's sufficient cache space for a multimodal input.
|
||||
If there is, return True and update EncoderCacheManager state.
|
||||
@ -121,9 +127,9 @@ class EncoderCacheManager:
|
||||
Args:
|
||||
request: The request containing the multimodal input.
|
||||
input_id: Index of the multimodal input within the request.
|
||||
encoder_compute_budget: Number of encoder tokens allowed to be
|
||||
encoder_compute_budget: Number of encoder embeddings allowed to be
|
||||
computed when this method is invoked.
|
||||
num_tokens_to_schedule: Number of tokens already scheduled to be
|
||||
num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
|
||||
allocated with cache space when this method is invoked.
|
||||
|
||||
Returns:
|
||||
@ -134,30 +140,30 @@ class EncoderCacheManager:
|
||||
Note: This method does not allocate physical memory for the encoder
|
||||
output but only the state of EncoderCacheManager.
|
||||
"""
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
num_embeds = request.get_num_encoder_embeds(input_id)
|
||||
|
||||
# Not enough compute budget
|
||||
if num_tokens > encoder_compute_budget:
|
||||
if num_embeds > encoder_compute_budget:
|
||||
return False
|
||||
|
||||
num_tokens += num_tokens_to_schedule
|
||||
num_embeds += num_embeds_to_schedule
|
||||
|
||||
# Enough free slots
|
||||
if num_tokens <= self.num_free_slots:
|
||||
if num_embeds <= self.num_free_slots:
|
||||
return True
|
||||
|
||||
# Not enough reclaimable slots
|
||||
if num_tokens > self.num_freeable_slots:
|
||||
if num_embeds > self.num_freeable_slots:
|
||||
return False
|
||||
|
||||
# Not enough free slots but enough reclaimable slots
|
||||
# NOTE: Eviction takes place here, but physical memory is not freed
|
||||
# until model runner is notified by the scheduler output.
|
||||
while num_tokens > self.num_free_slots:
|
||||
mm_hash, num_free_token = self.freeable.popitem(last=False)
|
||||
while num_embeds > self.num_free_slots:
|
||||
mm_hash, num_free_embeds = self.freeable.popitem(last=False)
|
||||
del self.cached[mm_hash]
|
||||
self.freed.append(mm_hash)
|
||||
self.num_free_slots += num_free_token
|
||||
self.num_free_slots += num_free_embeds
|
||||
return True
|
||||
|
||||
def allocate(self, request: Request, input_id: int) -> None:
|
||||
@ -176,16 +182,16 @@ class EncoderCacheManager:
|
||||
if mm_hash not in self.cached:
|
||||
self.cached[mm_hash] = set()
|
||||
|
||||
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
|
||||
# NOTE: Encoder cache should always have enough space for encoder inputs
|
||||
# that are scheduled since eviction takes place at can_allocate().
|
||||
assert self.num_free_slots >= num_encoder_tokens
|
||||
assert self.num_freeable_slots >= num_encoder_tokens
|
||||
assert self.num_free_slots >= num_encoder_embeds
|
||||
assert self.num_freeable_slots >= num_encoder_embeds
|
||||
|
||||
self.cached[mm_hash].add(request_id)
|
||||
self.num_free_slots -= num_encoder_tokens
|
||||
self.num_freeable_slots -= num_encoder_tokens
|
||||
self.num_free_slots -= num_encoder_embeds
|
||||
self.num_freeable_slots -= num_encoder_embeds
|
||||
|
||||
def get_cached_input_ids(self, request: Request) -> set[int]:
|
||||
"""Get all cached multimodal input IDs for a request.
|
||||
@ -206,7 +212,7 @@ class EncoderCacheManager:
|
||||
|
||||
When the reference set for the corresponding `mm_hash` becomes empty,
|
||||
the entry is appended to `freeable` and `num_freeable_slots` is
|
||||
increased by the number of encoder tokens for that input.
|
||||
increased by the number of encoder embeddings for that input.
|
||||
|
||||
The entry is NOT physically freed until capacity is needed (e.g., by
|
||||
`can_allocate`).
|
||||
@ -218,9 +224,9 @@ class EncoderCacheManager:
|
||||
return
|
||||
self.cached[mm_hash].discard(req_id)
|
||||
if not self.cached[mm_hash]:
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
self.freeable[mm_hash] = num_tokens
|
||||
self.num_freeable_slots += num_tokens
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
self.freeable[mm_hash] = num_encoder_embeds
|
||||
self.num_freeable_slots += num_encoder_embeds
|
||||
|
||||
def free(self, request: Request) -> None:
|
||||
"""Free all encoder input cache references held by *request*.
|
||||
@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
|
||||
request: Request,
|
||||
input_id: int,
|
||||
encoder_compute_budget: int,
|
||||
num_tokens_to_schedule: int,
|
||||
num_embeds_to_schedule: int,
|
||||
) -> bool:
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
# Not enough compute budget
|
||||
if num_tokens > encoder_compute_budget:
|
||||
if num_encoder_embeds > encoder_compute_budget:
|
||||
return False
|
||||
|
||||
num_tokens += num_tokens_to_schedule
|
||||
num_encoder_embeds += num_embeds_to_schedule
|
||||
# Enough free slots
|
||||
return num_tokens <= self.num_free_slots
|
||||
return num_encoder_embeds <= self.num_free_slots
|
||||
|
||||
def allocate(self, request: Request, input_id: int) -> None:
|
||||
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
||||
self.num_free_slots -= num_encoder_tokens
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
self.num_free_slots -= num_encoder_embeds
|
||||
|
||||
mm_hash = request.mm_features[input_id].identifier
|
||||
self.freed.append(mm_hash)
|
||||
@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
|
||||
return freed
|
||||
|
||||
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
self.num_free_slots += num_tokens
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
self.num_free_slots += num_encoder_embeds
|
||||
|
||||
@ -355,11 +355,11 @@ class Scheduler(SchedulerInterface):
|
||||
if preempted_encoder_inputs:
|
||||
# Restore encoder compute budget if the preempted
|
||||
# request had encoder inputs scheduled in this step.
|
||||
num_tokens_to_restore = sum(
|
||||
preempted_req.get_num_encoder_tokens(i)
|
||||
num_embeds_to_restore = sum(
|
||||
preempted_req.get_num_encoder_embeds(i)
|
||||
for i in preempted_encoder_inputs
|
||||
)
|
||||
encoder_compute_budget += num_tokens_to_restore
|
||||
encoder_compute_budget += num_embeds_to_restore
|
||||
req_index -= 1
|
||||
else:
|
||||
preempted_req = self.running.pop()
|
||||
@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface):
|
||||
# multiple encoder inputs per request), we need to create temporary
|
||||
# trackers for accounting at the encoder input level.
|
||||
mm_hashes_to_schedule = set()
|
||||
num_tokens_to_schedule = 0
|
||||
num_embeds_to_schedule = 0
|
||||
for i, mm_feature in enumerate(mm_features):
|
||||
start_pos = mm_feature.mm_position.offset
|
||||
num_encoder_tokens = mm_feature.mm_position.length
|
||||
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
|
||||
|
||||
# The encoder output is needed if the two ranges overlap:
|
||||
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
||||
@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface):
|
||||
):
|
||||
num_new_tokens = start_pos - num_computed_tokens
|
||||
break
|
||||
|
||||
if not self.encoder_cache_manager.can_allocate(
|
||||
request, i, encoder_compute_budget, num_tokens_to_schedule
|
||||
request, i, encoder_compute_budget, num_embeds_to_schedule
|
||||
):
|
||||
# The encoder cache is full or the encoder budget is exhausted.
|
||||
# NOTE(woosuk): We assume that the encoder input tokens should
|
||||
@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface):
|
||||
num_new_tokens = 0
|
||||
break
|
||||
|
||||
# Calculate the number of embeddings to schedule in the current range
|
||||
# of scheduled encoder placeholder tokens.
|
||||
start_idx_rel = max(0, num_computed_tokens - start_pos)
|
||||
end_idx_rel = min(
|
||||
num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
|
||||
)
|
||||
curr_embeds_start, curr_embeds_end = (
|
||||
mm_feature.mm_position.get_embeds_indices_in_range(
|
||||
start_idx_rel,
|
||||
end_idx_rel,
|
||||
)
|
||||
)
|
||||
# There are no embeddings in the current range of encoder placeholder tokens
|
||||
# so we can skip the encoder input.
|
||||
if curr_embeds_end - curr_embeds_start == 0:
|
||||
continue
|
||||
|
||||
if self.ec_connector is not None and remote_cache_has_item[i]:
|
||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||
external_load_encoder_input.append(i)
|
||||
num_tokens_to_schedule += num_encoder_tokens
|
||||
num_embeds_to_schedule += num_encoder_embeds
|
||||
continue
|
||||
|
||||
num_tokens_to_schedule += num_encoder_tokens
|
||||
encoder_compute_budget -= num_encoder_tokens
|
||||
num_embeds_to_schedule += num_encoder_embeds
|
||||
encoder_compute_budget -= num_encoder_embeds
|
||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||
encoder_inputs_to_schedule.append(i)
|
||||
|
||||
|
||||
@ -209,10 +209,10 @@ class Request:
|
||||
def get_finished_reason(self) -> FinishReason | None:
|
||||
return RequestStatus.get_finished_reason(self.status)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
assert input_id < len(self.mm_features)
|
||||
num_tokens = self.mm_features[input_id].mm_position.length
|
||||
return num_tokens
|
||||
num_embeds = self.mm_features[input_id].mm_position.get_num_embeds
|
||||
return num_embeds
|
||||
|
||||
def record_event(
|
||||
self,
|
||||
|
||||
@ -170,9 +170,7 @@ from .utils import (
|
||||
MultiModalBudget,
|
||||
add_kv_sharing_layers_to_kv_cache_groups,
|
||||
bind_kv_cache,
|
||||
gather_mm_placeholders,
|
||||
sanity_check_mm_encoder_outputs,
|
||||
scatter_mm_placeholders,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -1640,6 +1638,15 @@ class GPUModelRunner(
|
||||
logits_indices
|
||||
)
|
||||
|
||||
# Cache attention metadata builds across hybrid KV-cache groups
|
||||
# The only thing that changes between different hybrid KV-cache groups that share
|
||||
# the same metadata builder and KVCacheSpec is the block table, so we
|
||||
# can cache the attention metadata builds and just update the block table using
|
||||
# `builder.update_block_table` if the builder supports it.
|
||||
cached_attn_metadata: dict[
|
||||
tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata
|
||||
] = {}
|
||||
|
||||
def _build_attn_group_metadata(
|
||||
kv_cache_gid: int,
|
||||
attn_gid: int,
|
||||
@ -1647,13 +1654,15 @@ class GPUModelRunner(
|
||||
ubid: int | None = None,
|
||||
) -> None:
|
||||
attn_group = self.attn_groups[kv_cache_gid][attn_gid]
|
||||
builder = attn_group.get_metadata_builder(ubid or 0)
|
||||
cache_key = (kv_cache_groups[kv_cache_gid].kv_cache_spec, type(builder))
|
||||
|
||||
cascade_attn_prefix_len = (
|
||||
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
|
||||
if cascade_attn_prefix_lens
|
||||
else 0
|
||||
)
|
||||
|
||||
builder = attn_group.get_metadata_builder(ubid or 0)
|
||||
extra_attn_metadata_args = {}
|
||||
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
|
||||
assert ubid is None, "UBatching not supported with GDN yet"
|
||||
@ -1668,12 +1677,23 @@ class GPUModelRunner(
|
||||
attn_metadata_i = builder.build_for_cudagraph_capture(
|
||||
common_attn_metadata
|
||||
)
|
||||
elif (
|
||||
cache_key in cached_attn_metadata
|
||||
and builder.supports_update_block_table
|
||||
):
|
||||
attn_metadata_i = builder.update_block_table(
|
||||
cached_attn_metadata[cache_key],
|
||||
common_attn_metadata.block_table_tensor,
|
||||
common_attn_metadata.slot_mapping,
|
||||
)
|
||||
else:
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=cascade_attn_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
**extra_attn_metadata_args,
|
||||
)
|
||||
if builder.supports_update_block_table:
|
||||
cached_attn_metadata[cache_key] = attn_metadata_i
|
||||
|
||||
if ubid is None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
@ -2271,10 +2291,7 @@ class GPUModelRunner(
|
||||
|
||||
# Cache the encoder outputs by mm_hash
|
||||
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
|
||||
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
|
||||
output,
|
||||
is_embed=pos_info.is_embed,
|
||||
)
|
||||
self.encoder_cache[mm_hash] = output
|
||||
logger.debug("Finish execute for mm hash %s", mm_hash)
|
||||
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
|
||||
|
||||
@ -2325,6 +2342,13 @@ class GPUModelRunner(
|
||||
num_encoder_tokens,
|
||||
)
|
||||
assert start_idx < end_idx
|
||||
curr_embeds_start, curr_embeds_end = (
|
||||
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
|
||||
)
|
||||
# If there are no embeddings in the current range, we skip
|
||||
# gathering the embeddings.
|
||||
if curr_embeds_start == curr_embeds_end:
|
||||
continue
|
||||
|
||||
mm_hash = mm_feature.identifier
|
||||
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||
@ -2332,16 +2356,14 @@ class GPUModelRunner(
|
||||
|
||||
if (is_embed := pos_info.is_embed) is not None:
|
||||
is_embed = is_embed[start_idx:end_idx]
|
||||
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
|
||||
else:
|
||||
mm_embeds_item = encoder_output[start_idx:end_idx]
|
||||
|
||||
req_start_pos = req_start_idx + start_pos - num_computed_tokens
|
||||
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
|
||||
True if is_embed is None else is_embed
|
||||
)
|
||||
|
||||
mm_embeds_item = gather_mm_placeholders(
|
||||
encoder_output[start_idx:end_idx],
|
||||
is_embed=is_embed,
|
||||
)
|
||||
mm_embeds_req.append(mm_embeds_item)
|
||||
|
||||
if self.is_multimodal_pruning_enabled and self.uses_mrope:
|
||||
@ -4570,31 +4592,8 @@ class GPUModelRunner(
|
||||
dummy_encoder_outputs,
|
||||
expected_num_items=max_mm_items_per_batch,
|
||||
)
|
||||
|
||||
# NOTE: This happens when encoder cache needs to store
|
||||
# the embeddings that encoder outputs are scattered onto.
|
||||
# In this case we create dummy embeddings of size
|
||||
# (max_tokens_for_modality, hidden_size) and scatter
|
||||
# encoder output into it.
|
||||
encoder_output_shape = dummy_encoder_outputs[0].shape
|
||||
max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
|
||||
dummy_modality
|
||||
]
|
||||
if encoder_output_shape[0] < max_mm_tokens_per_item:
|
||||
encoder_hidden_size = encoder_output_shape[-1]
|
||||
expanded_outputs = []
|
||||
for output in dummy_encoder_outputs:
|
||||
expanded = output.new_zeros(
|
||||
(max_mm_tokens_per_item, encoder_hidden_size)
|
||||
)
|
||||
num_tokens = output.shape[0]
|
||||
expanded[:num_tokens].copy_(output)
|
||||
expanded_outputs.append(expanded)
|
||||
|
||||
dummy_encoder_outputs = expanded_outputs
|
||||
|
||||
# Cache the dummy encoder outputs.
|
||||
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
||||
for i, output in enumerate(dummy_encoder_outputs):
|
||||
self.encoder_cache[f"tmp_{i}"] = output
|
||||
|
||||
# Add `is_profile` here to pre-allocate communication buffers
|
||||
hidden_states, last_hidden_states = self._dummy_run(
|
||||
|
||||
@ -4,10 +4,12 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||
@ -17,6 +19,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalBudget:
|
||||
"""Helper class to calculate budget information for multi-modal models."""
|
||||
@ -198,6 +202,7 @@ def sanity_check_mm_encoder_outputs(
|
||||
)
|
||||
|
||||
|
||||
@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.")
|
||||
def scatter_mm_placeholders(
|
||||
embeds: torch.Tensor,
|
||||
is_embed: torch.Tensor | None,
|
||||
@ -226,6 +231,7 @@ def scatter_mm_placeholders(
|
||||
return placeholders
|
||||
|
||||
|
||||
@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.")
|
||||
def gather_mm_placeholders(
|
||||
placeholders: torch.Tensor,
|
||||
is_embed: torch.Tensor | None,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user