mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 13:15:48 +08:00
Mamba V2 Test not Asserting Failures. (#21379)
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
This commit is contained in:
parent
accac82928
commit
32ec9e2f2a
@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
|
||||
gate_states[..., local_rank * N:(local_rank + 1) * N],
|
||||
)
|
||||
ref_output = mixer_single_gpu(hidden_states, gate_states)
|
||||
torch.allclose(output,
|
||||
ref_output[..., local_rank * N:(local_rank + 1) * N],
|
||||
atol=1e-3,
|
||||
rtol=1e-3)
|
||||
torch.testing.assert_close(output,
|
||||
ref_output[...,
|
||||
local_rank * N:(local_rank + 1) * N],
|
||||
atol=5e-3,
|
||||
rtol=1e-3)
|
||||
|
||||
@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
||||
|
||||
# this tests the kernels on a single example (no batching)
|
||||
|
||||
# TODO: the bfloat16 case requires higher thresholds. To be investigated
|
||||
|
||||
if itype == torch.bfloat16:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
else:
|
||||
atol, rtol = 8e-3, 5e-3
|
||||
|
||||
# set seed
|
||||
batch_size = 1 # batch_size
|
||||
# ssd_minimal_discrete requires chunk_size divide seqlen
|
||||
@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
||||
return_final_states=True)
|
||||
|
||||
# just test the last in sequence
|
||||
torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)
|
||||
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
|
||||
|
||||
# just test the last head
|
||||
# NOTE, in the kernel we always cast states to fp32
|
||||
torch.allclose(final_state[:, -1],
|
||||
final_state_min[:, -1].to(torch.float32),
|
||||
atol=1e-3,
|
||||
rtol=1e-3)
|
||||
torch.testing.assert_close(final_state[:, -1],
|
||||
final_state_min[:, -1].to(torch.float32),
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
|
||||
@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
|
||||
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
|
||||
|
||||
# TODO: the irregular chunk size cases have some issues and require higher
|
||||
# tolerance. This is to be invesigated
|
||||
if chunk_size not in {8, 256}:
|
||||
atol, rtol = 5e-1, 5e-1
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
# hold state during the cutting process so we know if an
|
||||
# example has been exhausted and needs to cycle
|
||||
last_taken: dict = {} # map: eg -> pointer to last taken sample
|
||||
@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
# just test one dim and dstate
|
||||
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
|
||||
Y_min_eg = Y_min[i][:, 0, 0]
|
||||
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
|
||||
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
|
||||
|
||||
# update states
|
||||
states = new_states
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user