Mamba V2 Test not Asserting Failures. (#21379)

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
This commit is contained in:
Yu Chin Fabian Lim 2025-07-23 04:40:27 -04:00 committed by GitHub
parent accac82928
commit 32ec9e2f2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 10 deletions

View File

@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
gate_states[..., local_rank * N:(local_rank + 1) * N],
)
ref_output = mixer_single_gpu(hidden_states, gate_states)
torch.allclose(output,
ref_output[..., local_rank * N:(local_rank + 1) * N],
atol=1e-3,
rtol=1e-3)
torch.testing.assert_close(output,
ref_output[...,
local_rank * N:(local_rank + 1) * N],
atol=5e-3,
rtol=1e-3)

View File

@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
# this tests the kernels on a single example (no batching)
# TODO: the bfloat16 case requires higher thresholds. To be investigated
if itype == torch.bfloat16:
atol, rtol = 5e-2, 5e-2
else:
atol, rtol = 8e-3, 5e-3
# set seed
batch_size = 1 # batch_size
# ssd_minimal_discrete requires chunk_size divide seqlen
@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
return_final_states=True)
# just test the last in sequence
torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
# just test the last head
# NOTE, in the kernel we always cast states to fp32
torch.allclose(final_state[:, -1],
final_state_min[:, -1].to(torch.float32),
atol=1e-3,
rtol=1e-3)
torch.testing.assert_close(final_state[:, -1],
final_state_min[:, -1].to(torch.float32),
atol=atol,
rtol=rtol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
# TODO: the irregular chunk size cases have some issues and require higher
# tolerance. This is to be invesigated
if chunk_size not in {8, 256}:
atol, rtol = 5e-1, 5e-1
else:
atol, rtol = 5e-3, 5e-3
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken: dict = {} # map: eg -> pointer to last taken sample
@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0]
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
# update states
states = new_states