diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index b07d20bab7dd9..e0e3ef71b485f 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -172,7 +172,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantization #if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \ defined(__powerpc64__) - at::Tag stride_tag = at::Tag::needs_fixed_stride_order; // Helper function to release oneDNN handlers ops.def("release_dnnl_matmul_handler(int handler) -> ()", &release_dnnl_matmul_handler); @@ -208,15 +207,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute int8 quantized tensor for given scaling factor. ops.def( "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," - "Tensor? azp) -> ()", - {stride_tag}); + "Tensor? azp) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor ops.def( "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " - "Tensor!? azp) -> ()", - {stride_tag}); + "Tensor!? azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCPU, &dynamic_scaled_int8_quant); #endif diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c3ae06a30e3e8..5af74c2c2a6b0 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -20,18 +20,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops // - // The default behavior in PyTorch 2.6 was changed to "requires_contiguous", - // so we need - // to override this for many GEMMs with the following tag. Otherwise, - // torch.compile will force all input tensors to be contiguous(), which - // will break many custom ops that require column-major weight matrices. - // This was a bug and PyTorch 2.7 has since fixed this. -#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6 - #define stride_tag at::Tag::needs_fixed_stride_order -#else - #define stride_tag -#endif - ops.def( "persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! " "y_q, Tensor! y_s," @@ -241,15 +229,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantized GEMM for AWQ. ops.def( "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, " - "Tensor _zeros, SymInt split_k_iters) -> Tensor", - {stride_tag}); + "Tensor _zeros, SymInt split_k_iters) -> Tensor"); ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); // Dequantization for AWQ. ops.def( "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, " - "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor", - {stride_tag}); + "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor"); ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); // Note about marlin kernel 'workspace' arguments: @@ -271,8 +257,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, " "Tensor b_scales, Tensor workspace, " "int b_q_type, " - "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor", - {stride_tag}); + "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor"); // conditionally compiled so impl in source file // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. @@ -298,8 +283,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor? channel_scales," " Tensor? token_scales," " str? schedule" - ") -> Tensor", - {stride_tag}); + ") -> Tensor"); ops.def( "machete_prepack_B(" " Tensor B," @@ -319,8 +303,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " - "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor", - {stride_tag}); + "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor"); // conditionally compiled so impl registration is in source file // gptq_marlin repack from GPTQ. @@ -346,8 +329,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor token_scales," " ScalarType? out_type," " str? maybe_schedule" - ") -> Tensor", - {stride_tag}); + ") -> Tensor"); // pack scales ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor"); // encode and reorder weight matrix @@ -394,24 +376,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," " Tensor block_scale_a, Tensor block_scale_b," - " Tensor alpha) -> ()", - {stride_tag}); + " Tensor alpha) -> ()"); ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); // cutlass blockwise scaledgroup GEMM ops.def( "cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, " "Tensor scales_a, Tensor scales_b, " - "Tensor problem_sizes, Tensor expert_offsets) -> ()", - {stride_tag}); + "Tensor problem_sizes, Tensor expert_offsets) -> ()"); // conditionally compiled so impl registration is in source file // cutlass nvfp4 block scaled group GEMM ops.def( "cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b," " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas," - " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()", - {stride_tag}); + " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()"); // conditionally compiled so impl registration is in source file // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column @@ -419,8 +398,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "cutlass_scaled_mm(Tensor! out, Tensor a," " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()", - {stride_tag}); + " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm); // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column @@ -429,8 +407,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "cutlass_scaled_mm_azp(Tensor! out, Tensor a," " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()", - {stride_tag}); + " Tensor? azp, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp); // Check if cutlass scaled_mm is supported for CUDA devices of the given @@ -449,8 +426,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " " Tensor problem_sizes, Tensor a_strides, " " Tensor b_strides, Tensor c_strides, bool per_act_token, " - " bool per_out_ch) -> ()", - {stride_tag}); + " bool per_out_ch) -> ()"); ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm); // A function that computes data required to run fused MoE with w8a8 grouped @@ -464,8 +440,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! problem_sizes1, Tensor! problem_sizes2, " " Tensor! input_permutation, " " Tensor! output_permutation, int num_experts, " - " int n, int k, Tensor? blockscale_offsets) -> ()", - {stride_tag}); + " int n, int k, Tensor? blockscale_offsets) -> " + "()"); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); // A function that computes problem sizes for each expert's multiplication @@ -476,8 +452,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! problem_sizes1, " " Tensor! problem_sizes2, " " int num_experts, int n, int k, " - " Tensor? blockscale_offsets) -> ()", - {stride_tag}); + " Tensor? blockscale_offsets) -> ()"); ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, &get_cutlass_moe_mm_problem_sizes); @@ -492,8 +467,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! problem_sizes2, " " Tensor expert_num_tokens, " " int num_local_experts, int padded_m, " - " int n, int k) -> ()", - {stride_tag}); + " int n, int k) -> ()"); ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA, &get_cutlass_pplx_moe_mm_data); @@ -517,8 +491,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "cutlass_scaled_sparse_mm(Tensor! out, Tensor a," " Tensor bt_nzs," " Tensor bt_meta, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()", - {stride_tag}); + " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm); // CUTLASS sparse matrix compressor @@ -567,8 +540,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, " "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool " "use_v2_format, int bit) " - "-> Tensor", - {stride_tag}); + "-> Tensor"); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); // Post processing for GPTQ.