From 06ffc7e1d35b3f754e46439babfed564822bbb75 Mon Sep 17 00:00:00 2001
From: TY-AMD
Date: Wed, 30 Apr 2025 01:26:42 +0800
Subject: [PATCH] [Misc][ROCm] Exclude `cutlass_mla_decode` for ROCm build (#17289)

Signed-off-by: Tianyuan Wu
---
 csrc/torch_bindings.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index c9a120976b1c6..b595b0aa60b63 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -130,13 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       ") -> ()");
   ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
 
-  // Compute MLA decode using cutlass.
-  ops.def(
-      "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
-      " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
-      " Tensor page_table, float scale) -> ()");
-  ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
-
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
@@ -450,6 +443,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
   ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
 
+  // CUTLASS MLA decode
+  ops.def(
+      "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
+      " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
+      " Tensor page_table, float scale) -> ()");
+  ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
+
   // Mamba selective scan kernel
   ops.def(
       "selective_scan_fwd(Tensor! u, Tensor! delta,"