From d31a64712489d7c079fe48515c7ddd8a60bc0e71 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 16 Jul 2025 01:27:29 -0400 Subject: [PATCH] [BugFix] Fix import error on non-blackwell machines (#21020) Signed-off-by: Lucas Wilkinson --- csrc/attention/mla/sm100_cutlass_mla_kernel.cu | 10 ++++++++++ csrc/ops.h | 13 ------------- csrc/torch_bindings.cpp | 5 ++--- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index 0d57ff4cc7cb2..e0e95d06290df 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -18,6 +18,7 @@ limitations under the License. * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 * by Alcanderian JieXin Liang */ +#include "core/registration.h" #include #include @@ -270,4 +271,13 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba } #endif + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) { + m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size); +} + // clang-format on diff --git a/csrc/ops.h b/csrc/ops.h index 20ad163dc0d65..7f3e6b6923a3f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -167,19 +167,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor const& seq_lens, torch::Tensor const& page_table, double scale); -void sm100_cutlass_mla_decode( - torch::Tensor const& out, torch::Tensor const& q_nope, - torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, torch::Tensor const& page_table, - torch::Tensor const& workspace, double sm_scale, - int64_t num_kv_splits = - 1 /* Set to 1 to avoid cuda_graph issue by default. */); - -int64_t sm100_cutlass_mla_get_workspace_size( - int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, - int64_t num_kv_splits = - 1 /* Set to 1 to avoid cuda_graph issue by default. */); - torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); #ifndef USE_ROCM diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 370edc2014934..23e9212a2f1d1 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -521,15 +521,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor page_table, Tensor workspace, float " "scale," " int num_kv_splits) -> ()"); - ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode); + // conditionally compiled so impl in source file // SM100 CUTLASS MLA workspace ops.def( "sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches," " int sm_count, int num_kv_splits) " "-> int"); - ops.impl("sm100_cutlass_mla_get_workspace_size", - &sm100_cutlass_mla_get_workspace_size); + // conditionally compiled so impl in source file // Compute NVFP4 block quantized tensor. ops.def(