diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c9a120976b1c6..b595b0aa60b63 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -130,13 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ") -> ()"); ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); - // Compute MLA decode using cutlass. - ops.def( - "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," - " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," - " Tensor page_table, float scale) -> ()"); - ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); - // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( @@ -450,6 +443,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]"); ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress); + // CUTLASS MLA decode + ops.def( + "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," + " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," + " Tensor page_table, float scale) -> ()"); + ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); + // Mamba selective scan kernel ops.def( "selective_scan_fwd(Tensor! u, Tensor! delta,"