[CI/Build] Conditionally register cutlass_fp4_group_mm to fix building on Hopper (#26138)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-10-02 23:32:38 -04:00 committed by GitHub
parent 2aaa423842
commit 5d5146eee3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 1 deletions

View File

@ -14,6 +14,8 @@
* limitations under the License.
*/
#include "core/registration.h"
#include <torch/all.h>
#include <cutlass/arch/arch.h>
@ -418,3 +420,7 @@ void cutlass_fp4_group_mm(
"12.8 or above.");
#endif
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
}

View File

@ -397,7 +397,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()",
{stride_tag});
ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
// conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias