[CI/Build] Conditionally register cutlass_fp4_group_mm to fix building on Hopper (#26138)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Michael Goin 2025-10-02 23:32:38 -04:00 committed by yewentao256
parent 2ea7d48656
commit 173c8a9520
2 changed files with 7 additions and 1 deletions

View File

@ -14,6 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include "core/registration.h"
#include <torch/all.h> #include <torch/all.h>
#include <cutlass/arch/arch.h> #include <cutlass/arch/arch.h>
@ -418,3 +420,7 @@ void cutlass_fp4_group_mm(
"12.8 or above."); "12.8 or above.");
#endif #endif
} }
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
}

View File

@ -397,7 +397,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas," " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()", " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()",
{stride_tag}); {stride_tag});
ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm); // conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias // quantization, as well as bias