[CI Failure] Fix test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe (#24750)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-09-13 03:29:19 -04:00 committed by simon-mo
parent da3fa78dc9
commit 26b999c71a
2 changed files with 3 additions and 2 deletions

View File

@@ -43,6 +43,7 @@ void sm100_cutlass_mla_decode(
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
     torch::Tensor const& workspace,
+    double sm_scale,
     int64_t num_kv_splits) {
   TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
 }

View File

@@ -771,11 +771,11 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
     w13_ref = dequant_mxfp4_batches(
         w13_q.view(torch.uint8),
         w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
-            num_experts, 2 * intermediate_size, hidden_size)
+            num_experts, 2 * intermediate_size, hidden_size).to(device)
     w2_ref = dequant_mxfp4_batches(
         w2_q.view(torch.uint8),
         w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
-            num_experts, hidden_size, intermediate_size)
+            num_experts, hidden_size, intermediate_size).to(device)
     # Quantize activations for SM100 path and dequantize for reference
     hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)