mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 06:55:01 +08:00
[Bugfix] Mamba2 SSD varlen bug fix initstates decay, improve test, assert chunk pwr 2 (#21783)
Signed-off-by: Rishi Astra <40644327+RishiAstra@users.noreply.github.com>
This commit is contained in:
parent
1ece7f30ba
commit
46ae7f6666
@ -187,7 +187,7 @@ def generate_continuous_batched_examples(example_lens_by_batch,
|
|||||||
[torch.float32, torch.float16, torch.bfloat16])
|
[torch.float32, torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
|
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
|
||||||
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
|
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
|
||||||
@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
|
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
|
||||||
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
||||||
itype):
|
itype):
|
||||||
|
|
||||||
@ -253,15 +253,15 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
|||||||
(8, 8, 16, 32, 16),
|
(8, 8, 16, 32, 16),
|
||||||
]), # more examples with varied lengths
|
]), # more examples with varied lengths
|
||||||
|
|
||||||
# odd chunk_size
|
|
||||||
(64, 29, 2, [(11, 4), (13, 23), (19, 22),
|
|
||||||
(21, 15)]), # irregular sizes
|
|
||||||
|
|
||||||
# large-ish chunk_size (256)
|
# large-ish chunk_size (256)
|
||||||
(64, 256, 1, [(5, ), (1, ), (1, ),
|
(64, 256, 1, [(5, ), (1, ), (1, ),
|
||||||
(1, )]), # irregular sizes with small sequences
|
(1, )]), # irregular sizes with small sequences
|
||||||
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
|
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
|
||||||
(1, 2)]), # irregular sizes with small sequences
|
(1, 2)]), # irregular sizes with small sequences
|
||||||
|
|
||||||
|
# we also need to test some large seqlen
|
||||||
|
# to catch errors with init states decay
|
||||||
|
(768, 128, 2, [(138, 225), (138, 225)]),
|
||||||
])
|
])
|
||||||
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||||
itype):
|
itype):
|
||||||
@ -271,10 +271,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
|||||||
|
|
||||||
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
|
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
|
||||||
|
|
||||||
# TODO: the irregular chunk size cases have some issues and require higher
|
# This test can have larger error for longer sequences
|
||||||
# tolerance. This is to be investigated
|
if seqlen > 256:
|
||||||
if chunk_size not in {8, 256}:
|
atol, rtol = 1e-2, 5e-3
|
||||||
atol, rtol = 5e-1, 5e-1
|
|
||||||
else:
|
else:
|
||||||
atol, rtol = 5e-3, 5e-3
|
atol, rtol = 5e-3, 5e-3
|
||||||
|
|
||||||
|
|||||||
@ -290,10 +290,8 @@ def _chunk_scan_fwd_kernel(
|
|||||||
# get the cs at the offset boundary
|
# get the cs at the offset boundary
|
||||||
# - c_off == 0 is a passthrough
|
# - c_off == 0 is a passthrough
|
||||||
dA_cs_m_boundary = tl.load(
|
dA_cs_m_boundary = tl.load(
|
||||||
dA_cumsum_ptr +
|
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
|
||||||
(pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
|
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
|
||||||
mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
|
|
||||||
and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
|
|
||||||
other=0.0).to(tl.float32)
|
other=0.0).to(tl.float32)
|
||||||
|
|
||||||
if HAS_SEQ_IDX:
|
if HAS_SEQ_IDX:
|
||||||
|
|||||||
@ -21,6 +21,10 @@ from .ssd_state_passing import _state_passing_fwd
|
|||||||
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
|
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
|
||||||
|
|
||||||
|
|
||||||
|
def is_int_pow_2(n):
    """Return True when ``n`` is an ``int`` that is a positive power of two.

    Non-``int`` values (floats, strings, ``None``, ...) and non-positive
    integers yield ``False``. Because the check is ``isinstance(n, int)``,
    ``bool`` values are treated as the integers 0 and 1.
    """
    if not isinstance(n, int):
        return False
    if n <= 0:
        return False
    # A power of two has exactly one set bit; n - 1 flips that bit and sets
    # every lower bit, so the bitwise AND is zero only for powers of two.
    return (n & (n - 1)) == 0
|
||||||
|
|
||||||
|
|
||||||
def _mamba_chunk_scan_combined_fwd(x,
|
def _mamba_chunk_scan_combined_fwd(x,
|
||||||
dt,
|
dt,
|
||||||
A,
|
A,
|
||||||
@ -38,6 +42,7 @@ def _mamba_chunk_scan_combined_fwd(x,
|
|||||||
dt_softplus=False,
|
dt_softplus=False,
|
||||||
dt_limit=(0.0, float("inf")),
|
dt_limit=(0.0, float("inf")),
|
||||||
out=None):
|
out=None):
|
||||||
|
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
|
||||||
batch, seqlen, nheads, headdim = x.shape
|
batch, seqlen, nheads, headdim = x.shape
|
||||||
_, _, ngroups, dstate = B.shape
|
_, _, ngroups, dstate = B.shape
|
||||||
assert nheads % ngroups == 0
|
assert nheads % ngroups == 0
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user