diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
index 2c8df6144bf4..5b007e5ea328 100644
--- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+#include "core/registration.h"
+
 #include
 #include
@@ -418,3 +420,7 @@ void cutlass_fp4_group_mm(
       "12.8 or above.");
 #endif
 }
+
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
+  m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index ebd28e735088..64a345eb66cc 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -397,7 +397,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
       " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()",
       {stride_tag});
-  ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
+  // conditionally compiled so impl registration is in source file
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
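For context on the change: the op schema stays in torch_bindings.cpp, which is always compiled, while the CUDA implementation is now registered from the kernel's own translation unit, which is only built when CUTLASS / CUDA 12.8 support is available. TORCH_LIBRARY_IMPL_EXPAND is vLLM's macro wrapper (from core/registration.h) around PyTorch's TORCH_LIBRARY_IMPL. Below is a minimal sketch of the def/impl split using plain PyTorch macros; the namespace my_ext and the op group_mm are invented for illustration and are not the real vLLM names.

// Sketch only: schema definition and device implementation can live in
// different translation units. "my_ext" and "group_mm" are hypothetical.
#include <torch/all.h>
#include <torch/library.h>

// Hypothetical CUDA-backed implementation (kernel launch omitted).
void group_mm(torch::Tensor& out, const torch::Tensor& a,
              const torch::Tensor& b) {
  // ... dispatch to the CUDA kernel here ...
}

// Always-compiled bindings TU: declare the schema only. The op is visible
// to torch.ops even if no backend kernel ends up being registered.
TORCH_LIBRARY_FRAGMENT(my_ext, ops) {
  ops.def("group_mm(Tensor! out, Tensor a, Tensor b) -> ()");
}

// Conditionally compiled kernel TU: attach the CUDA implementation.
// If this file is excluded from the build, calling the op fails at
// dispatch time ("no kernel registered") rather than at link time.
TORCH_LIBRARY_IMPL(my_ext, CUDA, m) {
  m.impl("group_mm", &group_mm);
}

This mirrors the diff: torch_bindings.cpp keeps only the ops.def(...) schema, and the conditionally compiled nvfp4 source owns the m.impl(...) via TORCH_LIBRARY_IMPL_EXPAND.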