mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 08:04:27 +08:00
Mamba V2 Test not Asserting Failures. (#21379)
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
This commit is contained in:
parent
accac82928
commit
32ec9e2f2a
@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
|
|||||||
gate_states[..., local_rank * N:(local_rank + 1) * N],
|
gate_states[..., local_rank * N:(local_rank + 1) * N],
|
||||||
)
|
)
|
||||||
ref_output = mixer_single_gpu(hidden_states, gate_states)
|
ref_output = mixer_single_gpu(hidden_states, gate_states)
|
||||||
torch.allclose(output,
|
torch.testing.assert_close(output,
|
||||||
ref_output[..., local_rank * N:(local_rank + 1) * N],
|
ref_output[...,
|
||||||
atol=1e-3,
|
local_rank * N:(local_rank + 1) * N],
|
||||||
rtol=1e-3)
|
atol=5e-3,
|
||||||
|
rtol=1e-3)
|
||||||
|
|||||||
@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
|||||||
|
|
||||||
# this tests the kernels on a single example (no batching)
|
# this tests the kernels on a single example (no batching)
|
||||||
|
|
||||||
|
# TODO: the bfloat16 case requires higher thresholds. To be investigated
|
||||||
|
|
||||||
|
if itype == torch.bfloat16:
|
||||||
|
atol, rtol = 5e-2, 5e-2
|
||||||
|
else:
|
||||||
|
atol, rtol = 8e-3, 5e-3
|
||||||
|
|
||||||
# set seed
|
# set seed
|
||||||
batch_size = 1 # batch_size
|
batch_size = 1 # batch_size
|
||||||
# ssd_minimal_discrete requires chunk_size divide seqlen
|
# ssd_minimal_discrete requires chunk_size divide seqlen
|
||||||
@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
|||||||
return_final_states=True)
|
return_final_states=True)
|
||||||
|
|
||||||
# just test the last in sequence
|
# just test the last in sequence
|
||||||
torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)
|
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
|
||||||
|
|
||||||
# just test the last head
|
# just test the last head
|
||||||
# NOTE, in the kernel we always cast states to fp32
|
# NOTE, in the kernel we always cast states to fp32
|
||||||
torch.allclose(final_state[:, -1],
|
torch.testing.assert_close(final_state[:, -1],
|
||||||
final_state_min[:, -1].to(torch.float32),
|
final_state_min[:, -1].to(torch.float32),
|
||||||
atol=1e-3,
|
atol=atol,
|
||||||
rtol=1e-3)
|
rtol=rtol)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
|
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
|
||||||
@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
|||||||
|
|
||||||
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
|
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
|
||||||
|
|
||||||
|
# TODO: the irregular chunk size cases have some issues and require higher
|
||||||
|
# tolerance. This is to be invesigated
|
||||||
|
if chunk_size not in {8, 256}:
|
||||||
|
atol, rtol = 5e-1, 5e-1
|
||||||
|
else:
|
||||||
|
atol, rtol = 5e-3, 5e-3
|
||||||
|
|
||||||
# hold state during the cutting process so we know if an
|
# hold state during the cutting process so we know if an
|
||||||
# example has been exhausted and needs to cycle
|
# example has been exhausted and needs to cycle
|
||||||
last_taken: dict = {} # map: eg -> pointer to last taken sample
|
last_taken: dict = {} # map: eg -> pointer to last taken sample
|
||||||
@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
|||||||
# just test one dim and dstate
|
# just test one dim and dstate
|
||||||
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
|
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
|
||||||
Y_min_eg = Y_min[i][:, 0, 0]
|
Y_min_eg = Y_min[i][:, 0, 0]
|
||||||
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
|
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
# update states
|
# update states
|
||||||
states = new_states
|
states = new_states
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user