Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-03-25 04:34:40 +08:00)
cleanup at::Tag::needs_fixed_stride_order (#28974)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent 322cb02872
commit a903d59ffa
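For context: `at::Tag::needs_fixed_stride_order` tells torch.compile to preserve the stride order of an op's inputs instead of normalizing them. The hunks below drop the tag from both the CPU (torch::kCPU) and CUDA (torch::kCUDA) op registrations. A minimal self-contained sketch of the pattern being deleted, assuming a hypothetical extension my_ext with a hypothetical op my_gemm (not vLLM code; the def/impl signatures mirror the ones used in the diff):

#include <torch/torch.h>
#include <torch/library.h>
#include <torch/version.h>

// Hypothetical kernel whose second argument must keep the caller's
// (possibly column-major) layout.
at::Tensor my_gemm(const at::Tensor& a, const at::Tensor& b) {
  return at::matmul(a, b);
}

// The version-gated tag this commit deletes: only PyTorch 2.6 needs it;
// elsewhere the macro is empty and {stride_tag} collapses to an empty
// tag list.
#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
  #define stride_tag at::Tag::needs_fixed_stride_order
#else
  #define stride_tag
#endif

TORCH_LIBRARY(my_ext, ops) {
  ops.def("my_gemm(Tensor a, Tensor b) -> Tensor", {stride_tag});
  ops.impl("my_gemm", &my_gemm);
}

Note that the CPU bindings used a local `at::Tag stride_tag` variable rather than the macro; the commit removes both forms.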
@@ -172,7 +172,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quantization
 #if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
     defined(__powerpc64__)
-  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
   // Helper function to release oneDNN handlers
   ops.def("release_dnnl_matmul_handler(int handler) -> ()",
           &release_dnnl_matmul_handler);
@@ -208,15 +207,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
-      "Tensor? azp) -> ()",
-      {stride_tag});
+      "Tensor? azp) -> ()");
   ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
 
   // Compute int8 quantized tensor and scaling factor
   ops.def(
       "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
-      "Tensor!? azp) -> ()",
-      {stride_tag});
+      "Tensor!? azp) -> ()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
            &dynamic_scaled_int8_quant);
 #endif
@@ -20,18 +20,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
   //
-
-  // The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
-  // so we need
-  // to override this for many GEMMs with the following tag. Otherwise,
-  // torch.compile will force all input tensors to be contiguous(), which
-  // will break many custom ops that require column-major weight matrices.
-  // This was a bug and PyTorch 2.7 has since fixed this.
-#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
-  #define stride_tag at::Tag::needs_fixed_stride_order
-#else
-  #define stride_tag
-#endif
 
   ops.def(
       "persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
       "y_q, Tensor! y_s,"
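The comment block deleted in this hunk carries the rationale for the whole commit: PyTorch 2.6 defaulted custom ops under torch.compile to "requires_contiguous", silently copying non-contiguous inputs. A small standalone libtorch sketch (illustrative only, not vLLM code) of what that copy does to a column-major weight:

#include <torch/torch.h>
#include <iostream>

int main() {
  // A column-major weight, as stride-sensitive GEMM kernels often expect:
  // the transpose of a row-major contiguous matrix.
  torch::Tensor w = torch::randn({4, 8}).t();  // shape [8, 4], strides (1, 8)
  std::cout << "contiguous? " << w.is_contiguous() << "\n";  // 0 (false)

  // What an implicit contiguous() call does: materialize a row-major copy.
  // Same values, but the memory layout the kernel depended on is gone.
  torch::Tensor wc = w.contiguous();  // shape [8, 4], strides (4, 1)
  std::cout << "strides: " << wc.stride(0) << ", " << wc.stride(1) << "\n";
  return 0;
}

A kernel that reads the raw data pointer assuming column-major strides would traverse wc in the wrong order; that is why the GEMM schemas below carried the tag on 2.6, and why the tag can be dropped now that the 2.6-only workaround is obsolete.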
@@ -241,15 +229,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quantized GEMM for AWQ.
   ops.def(
       "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters) -> Tensor",
-      {stride_tag});
+      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
 
   // Dequantization for AWQ.
   ops.def(
       "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
-      {stride_tag});
+      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
 
   // Note about marlin kernel 'workspace' arguments:
@@ -271,8 +257,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
       "Tensor b_scales, Tensor workspace, "
       "int b_q_type, "
-      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
-      {stride_tag});
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
   // conditionally compiled so impl in source file
 
   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
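Several schemas in this diff are followed by "// conditionally compiled so impl in source file" instead of an ops.impl(...) call. A sketch of that split, with hypothetical names and both halves in one translation unit for brevity (in the real tree the TORCH_LIBRARY_IMPL half lives next to the kernel and is only compiled into builds that include it):

#include <torch/library.h>
#include <ATen/ATen.h>

// Stand-in kernel body for the sketch (hypothetical op, not a vLLM one).
at::Tensor fancy_gemm(const at::Tensor& a, const at::Tensor& b) {
  return at::matmul(a, b);
}

// Half 1: schema registration, always compiled, so the op always exists.
TORCH_LIBRARY(my_ext2, ops) {
  ops.def("fancy_gemm(Tensor a, Tensor b) -> Tensor");
}

// Half 2: implementation registration, compiled only with the kernel.
TORCH_LIBRARY_IMPL(my_ext2, CPU, ops) {
  ops.impl("fancy_gemm", &fancy_gemm);
}

Because the schema is always registered, a build without the kernel still exposes the op; calling it then fails at dispatch time rather than at library-load time.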
@@ -298,8 +283,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor? channel_scales,"
       " Tensor? token_scales,"
       " str? schedule"
-      ") -> Tensor",
-      {stride_tag});
+      ") -> Tensor");
   ops.def(
       "machete_prepack_B("
       " Tensor B,"
@@ -319,8 +303,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
       "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
       "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
-      "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
-      {stride_tag});
+      "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
   // conditionally compiled so impl registration is in source file
 
   // gptq_marlin repack from GPTQ.
@@ -346,8 +329,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor token_scales,"
       " ScalarType? out_type,"
       " str? maybe_schedule"
-      ") -> Tensor",
-      {stride_tag});
+      ") -> Tensor");
   // pack scales
   ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
   // encode and reorder weight matrix
@@ -394,24 +376,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor block_scale_a, Tensor block_scale_b,"
-      " Tensor alpha) -> ()",
-      {stride_tag});
+      " Tensor alpha) -> ()");
   ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
 
   // cutlass blockwise scaledgroup GEMM
   ops.def(
       "cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, "
       "Tensor scales_a, Tensor scales_b, "
-      "Tensor problem_sizes, Tensor expert_offsets) -> ()",
-      {stride_tag});
+      "Tensor problem_sizes, Tensor expert_offsets) -> ()");
   // conditionally compiled so impl registration is in source file
 
   // cutlass nvfp4 block scaled group GEMM
   ops.def(
       "cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
-      " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()",
-      {stride_tag});
+      " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
   // conditionally compiled so impl registration is in source file
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -419,8 +398,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "cutlass_scaled_mm(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()",
-      {stride_tag});
+      " Tensor b_scales, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
 
   // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
@@ -429,8 +407,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
       " Tensor b_scales, Tensor azp_adj,"
-      " Tensor? azp, Tensor? bias) -> ()",
-      {stride_tag});
+      " Tensor? azp, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
 
   // Check if cutlass scaled_mm is supported for CUDA devices of the given
@@ -449,8 +426,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
       " Tensor problem_sizes, Tensor a_strides, "
       " Tensor b_strides, Tensor c_strides, bool per_act_token, "
-      " bool per_out_ch) -> ()",
-      {stride_tag});
+      " bool per_out_ch) -> ()");
   ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
 
   // A function that computes data required to run fused MoE with w8a8 grouped
@@ -464,8 +440,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor! problem_sizes1, Tensor! problem_sizes2, "
       " Tensor! input_permutation, "
       " Tensor! output_permutation, int num_experts, "
-      " int n, int k, Tensor? blockscale_offsets) -> ()",
-      {stride_tag});
+      " int n, int k, Tensor? blockscale_offsets) -> "
+      "()");
   ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
 
   // A function that computes problem sizes for each expert's multiplication
@@ -476,8 +452,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor! problem_sizes1, "
       " Tensor! problem_sizes2, "
       " int num_experts, int n, int k, "
-      " Tensor? blockscale_offsets) -> ()",
-      {stride_tag});
+      " Tensor? blockscale_offsets) -> ()");
   ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
            &get_cutlass_moe_mm_problem_sizes);
 
@@ -492,8 +467,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor! problem_sizes2, "
       " Tensor expert_num_tokens, "
       " int num_local_experts, int padded_m, "
-      " int n, int k) -> ()",
-      {stride_tag});
+      " int n, int k) -> ()");
   ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA,
            &get_cutlass_pplx_moe_mm_data);
 
@@ -517,8 +491,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
       " Tensor bt_nzs,"
       " Tensor bt_meta, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()",
-      {stride_tag});
+      " Tensor b_scales, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
 
   // CUTLASS sparse matrix compressor
@@ -567,8 +540,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
       "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
       "use_v2_format, int bit) "
-      "-> Tensor",
-      {stride_tag});
+      "-> Tensor");
   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
 
   // Post processing for GPTQ.