[Model] Mamba2 varlen refactor (#21467)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 87ee8535a6
commit 62ae26c870
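For orientation before the diff: the refactor removes the implicit batch dimension from the Mamba2 chunk-scan path and makes the varlen metadata mandatory. Below is a minimal sketch of the new call pattern, assembled from the updated single-example test in this commit; the tensor sizes are illustrative placeholders, and this is only a sketch of the call shape, not the kernel's full contract.

import torch

from vllm.model_executor.layers.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined_varlen)
from vllm.v1.attention.backends.mamba2_attn import (
    _query_start_loc_to_chunk_indices_offsets)

# Illustrative sizes only (assumption); any consistent values work.
seqlen, nheads, headdim, ngroups, dstate, chunk_size = 128, 4, 64, 1, 16, 32
X = torch.randn(seqlen, nheads, headdim, device='cuda')   # note: no batch dim
dt = torch.rand(seqlen, nheads, device='cuda')
A = -torch.rand(nheads, device='cuda')
B = torch.randn(seqlen, ngroups, dstate, device='cuda')
C = torch.randn(seqlen, ngroups, dstate, device='cuda')

# A varlen "batch" of one sequence covering all tokens.
cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
chunk_indices, chunk_offsets = \
    _query_start_loc_to_chunk_indices_offsets(
        cu_seqlens, chunk_size, cu_seqlens[-1])

Y = torch.empty_like(X)  # output is written in place via out=
final_state = mamba_chunk_scan_combined_varlen(X,
                                               dt,
                                               A,
                                               B,
                                               C,
                                               chunk_size,
                                               D=None,
                                               cu_seqlens=cu_seqlens,
                                               seq_idx=seq_idx,
                                               chunk_indices=chunk_indices,
                                               chunk_offsets=chunk_offsets,
                                               out=Y)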
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat

 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
-mamba_chunk_scan_combined)
+mamba_chunk_scan_combined_varlen)
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.mamba2_attn import (
 _query_start_loc_to_chunk_indices_offsets)

@@ -185,9 +185,14 @@ def generate_continuous_batched_examples(example_lens_by_batch,
 IND_S = [x % full_length for x in IND_E]
 IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]

+# varlen has implicit batch=1
+dt2 = dt2.squeeze(0)
+X2 = X2.squeeze(0)
+B2 = B2.squeeze(0)
+C2 = C2.squeeze(0)
 yield ([Y_min[s, IND_S[s]:IND_E[s]]
 for s in range(num_examples)] if return_naive_ref else None,
-cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
+cu_seqlens, seq_idx, (A, dt2, X2, B2, C2))


 @pytest.mark.parametrize("itype",

@@ -198,7 +203,7 @@ def generate_continuous_batched_examples(example_lens_by_batch,
 def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
 itype):

-# this tests the kernels on a single example (no batching)
+# this tests the kernels on a single example (bs=1)

 # TODO: the bfloat16 case requires higher thresholds. To be investigated

@@ -219,23 +224,40 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,

 Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
 B, C, chunk_size)

+cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
+seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
+
+chunk_indices, chunk_offsets = \
+_query_start_loc_to_chunk_indices_offsets(
+cu_seqlens, chunk_size, cu_seqlens[-1])
+
+# varlen has implicit batch=1
+X = X.squeeze(0)
+dt = dt.squeeze(0)
+A = A.squeeze(0)
+B = B.squeeze(0)
+C = C.squeeze(0)
 Y = torch.empty_like(X)
-final_state = mamba_chunk_scan_combined(X,
+final_state = mamba_chunk_scan_combined_varlen(X,
 dt,
 A,
 B,
 C,
 chunk_size,
 D=None,
-return_final_states=True,
-out=Y)
+cu_seqlens=cu_seqlens,
+seq_idx=seq_idx,
+chunk_indices=chunk_indices,
+chunk_offsets=chunk_offsets,
+out=Y)

 # just test the last in sequence
-torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
+torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)

 # just test the last head
 # NOTE, in the kernel we always cast states to fp32
-torch.testing.assert_close(final_state[:, -1],
+torch.testing.assert_close(final_state[:, -1].to(torch.float32),
 final_state_min[:, -1].to(torch.float32),
 atol=atol,
 rtol=rtol)

@@ -300,7 +322,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
 cu_seqlens, chunk_size, cu_seqlens[-1])

 Y = torch.empty_like(X)
-new_states = mamba_chunk_scan_combined(
+new_states = mamba_chunk_scan_combined_varlen(
 X,
 dt,
 A,

@@ -312,7 +334,6 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
 seq_idx=seq_idx,
 chunk_indices=chunk_indices,
 chunk_offsets=chunk_offsets,
-return_varlen_states=True,
 initial_states=states,
 out=Y,
 )

@@ -321,7 +342,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
 for i in range(num_examples):

 # just test one dim and dstate
-Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
+Y_eg = Y[cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
 Y_min_eg = Y_min[i][:, 0, 0]
 torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)

@@ -386,7 +407,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
 _query_start_loc_to_chunk_indices_offsets(
 cu_seqlens, chunk_size, cu_seqlens[-1])
 Y_ref = torch.empty_like(X)
-state_ref = mamba_chunk_scan_combined(
+state_ref = mamba_chunk_scan_combined_varlen(
 X,
 dt,
 A,

@@ -398,7 +419,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
 seq_idx=seq_idx,
 chunk_indices=chunk_indices,
 chunk_offsets=chunk_offsets,
-return_varlen_states=True,
 initial_states=None,
 out=Y_ref,
 )

@@ -414,27 +434,27 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
 chunked_seq_idx = torch.repeat_interleave(
 torch.arange(len(chunked_seqlens), device=device),
 chunked_seqlens,
-output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32)
+output_size=chunked_cu_seqlens[-1]).to(torch.int32)
 chunked_input_seq_len = chunked_cu_seqlens[-1]
-X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
-dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
-B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
-C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
+X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
+dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
+B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
+C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
 for i in range(num_sequences):
 # fmt: off
-chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501
+chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501

-X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
-dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
-B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
-C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
+X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
+dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
+B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
+C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
 # fmt: on

 chunk_indices, chunk_offsets = \
 _query_start_loc_to_chunk_indices_offsets(
 chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
 Y_partial = torch.empty_like(X_chunked)
-partial_state = mamba_chunk_scan_combined(
+partial_state = mamba_chunk_scan_combined_varlen(
 X_chunked,
 dt_chunked,
 A,

@@ -446,7 +466,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
 seq_idx=chunked_seq_idx,
 chunk_indices=chunk_indices,
 chunk_offsets=chunk_offsets,
-return_varlen_states=True,
 initial_states=None,
 out=Y_partial,
 )

@@ -461,29 +480,28 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
 remaining_chunked_seq_idx = torch.repeat_interleave(
 torch.arange(len(remaining_chunked_seqlens), device=device),
 remaining_chunked_seqlens,
-output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to(
-torch.int32)
+output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
 remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
 # fmt: off
-remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
-remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
-remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
-remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
+remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501
+remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] # noqa: E501
+remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] # noqa: E501
+remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] # noqa: E501
 for i in range(num_sequences):
-remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501
+remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501

-remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
-remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
-remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
-remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501
+remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
+remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
+remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
+remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501

 # assert input chunking is correct
 concat_chunk_f = lambda pt1, pt2, i: torch.cat([
-pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
-pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
+pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
+pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
 ],
-dim=1)
-concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501
+dim=0)
+concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0) # noqa: E501
 # fmt: on

 assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)

@@ -498,7 +516,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
 remaining_chunked_cu_seqlens[-1])

 Y_chunked = torch.empty_like(remaining_X_chunked)
-state_chunked = mamba_chunk_scan_combined(
+state_chunked = mamba_chunk_scan_combined_varlen(
 remaining_X_chunked,
 remaining_dt_chunked,
 A,

@@ -510,7 +528,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
 seq_idx=remaining_chunked_seq_idx,
 chunk_indices=chunk_indices,
 chunk_offsets=chunk_offsets,
-return_varlen_states=True,
 initial_states=partial_state,
 out=Y_chunked,
 )

@@ -518,17 +535,17 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):

 # kernel chunked is same as kernel overall
 for i in range(num_sequences):
-Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
-Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
+Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...]
+Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...]
 torch.testing.assert_close(
-Y_seq[:, :chunked_seqlens[i], ...],
-Y_ref_seq[:, :chunked_seqlens[i], ...],
+Y_seq[:chunked_seqlens[i], ...],
+Y_ref_seq[:chunked_seqlens[i], ...],
 atol=atol,
 rtol=rtol,
 msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023
 torch.testing.assert_close(
-Y_seq[:, chunked_seqlens[i]:, ...],
-Y_ref_seq[:, chunked_seqlens[i]:, ...],
+Y_seq[chunked_seqlens[i]:, ...],
+Y_ref_seq[chunked_seqlens[i]:, ...],
 atol=atol,
 rtol=rtol,
 msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023
@@ -29,7 +29,7 @@ from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
 selective_state_update)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
-mamba_chunk_scan_combined)
+mamba_chunk_scan_combined_varlen)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
 LoaderFunction, composed_weight_loader, sharded_weight_loader)

@@ -504,6 +504,7 @@ class MambaMixer2(MambaBase, CustomOp):
 seq_idx_p = attn_metadata.seq_idx_p
 chunk_indices_p = attn_metadata.chunk_indices_p
 chunk_offsets_p = attn_metadata.chunk_offsets_p
+query_start_loc_p = attn_metadata.query_start_loc_p

 # 1. Gated MLP's linear projection
 projected_states, _ = self.in_proj(hidden_states)

@@ -545,6 +546,7 @@ class MambaMixer2(MambaBase, CustomOp):
 out, _ = self.out_proj(hidden_states)
 return out

+# NOTE: V0 put prefill before decode, v1 puts decode before prefill
 num_prefills = attn_metadata.num_prefills # request count
 num_decodes = attn_metadata.num_decode_tokens # token count (=request)
 num_prefill_tokens = attn_metadata.num_prefill_tokens # token count

@@ -570,9 +572,6 @@ class MambaMixer2(MambaBase, CustomOp):
 [num_decodes, num_prefills],
 dim=0,
 )
-query_start_loc_p = (
-attn_metadata.query_start_loc[-num_prefills - 1:] -
-num_decodes if has_prefill else None)

 # Preallocate output tensor to avoid memcpy cost for merging prefill
 # and decode outputs

@@ -620,15 +619,15 @@ class MambaMixer2(MambaBase, CustomOp):
 ssm_state[state_indices_tensor_p], 0)

 # NOTE: final output is an in-place update of out tensor
-varlen_state = mamba_chunk_scan_combined(
-hidden_states_p.view(1, num_prefill_tokens,
+varlen_states = mamba_chunk_scan_combined_varlen(
+hidden_states_p.view(num_prefill_tokens,
 self.num_heads // self.tp_size,
 self.head_dim),
-dt_p.unsqueeze(0),
+dt_p,
 self.A,
-B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+B_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
 -1),
-C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+C_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
 -1),
 chunk_size=chunk_size,
 D=self.D,

@@ -639,17 +638,15 @@ class MambaMixer2(MambaBase, CustomOp):
 chunk_offsets=chunk_offsets_p,
 cu_seqlens=query_start_loc_p,
 initial_states=initial_states,
-return_varlen_states=True,
-return_final_states=False,
 dt_softplus=True,
 dt_limit=(0.0, float("inf")),
-out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
+out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
 self.head_dim),
 state_dtype=ssm_state.dtype)

 # update ssm states
 # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
-ssm_state[state_indices_tensor_p] = varlen_state
+ssm_state[state_indices_tensor_p] = varlen_states

 # Process decode requests
 if has_decode:
@@ -427,7 +427,7 @@ def causal_conv1d_fn(
 batch_ptr = metadata.batch_ptr
 token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
 else:
-seqlens = np.diff(query_start_loc.to('cpu'))
+seqlens = query_start_loc.diff().to('cpu')
 args = seqlens
 MAX_NUM_PROGRAMS = 1024

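A small illustration of the replacement above, with made-up values: torch.Tensor.diff reproduces what np.diff did, but the difference is computed on the original tensor and only the result is moved to the CPU.

import torch

# query_start_loc holds cumulative token offsets, here for requests of 3, 4 and 5 tokens.
query_start_loc = torch.tensor([0, 3, 7, 12])
seqlens = query_start_loc.diff().to('cpu')  # tensor([3, 4, 5]), still a torch.Tensor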
@@ -99,34 +99,28 @@ def _bmm_chunk_fwd_kernel(
 seq_idx_ptr,
 # Matrix dimensions
 seqlen,
-chunk_size,
-K,
-ngroups,
-stride_a_batch,
-stride_a_seqlen,
-stride_a_head,
-stride_ak,
-stride_b_batch,
-stride_b_seqlen,
-stride_b_head,
-stride_bk,
-stride_out_batch,
-stride_out_chunk,
-stride_out_head,
-stride_outm,
-stride_outn,
-stride_seq_idx_batch,
-stride_seq_idx_seqlen,
+chunk_size: tl.constexpr,
+K: tl.constexpr,
+ngroups: tl.constexpr,
+stride_a_seqlen: tl.int64,
+stride_a_head: tl.int64,
+stride_ak: tl.constexpr,
+stride_b_seqlen: tl.int64,
+stride_b_head: tl.int64,
+stride_bk: tl.constexpr,
+stride_out_chunk: tl.int64,
+stride_out_head: tl.int64,
+stride_outm: tl.int64,
+stride_outn: tl.constexpr,
+stride_seq_idx_seqlen: tl.constexpr,
 # Meta-parameters
 IS_CAUSAL: tl.constexpr,
 dot_dtype: tl.constexpr,
-HAS_SEQ_IDX: tl.constexpr,
 BLOCK_SIZE_M: tl.constexpr,
 BLOCK_SIZE_N: tl.constexpr,
 BLOCK_SIZE_K: tl.constexpr,
 ):
-pid_b = tl.program_id(axis=1)
-pid_ch = tl.program_id(axis=2).to(tl.int64)
+pid_ch = tl.program_id(axis=1).to(tl.int64)
 pid_c = pid_ch // ngroups
 pid_h = pid_ch - pid_c * ngroups
 num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)

@@ -135,10 +129,10 @@ def _bmm_chunk_fwd_kernel(
 if IS_CAUSAL:
 if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
 return
-a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
-b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
-if HAS_SEQ_IDX:
-seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
+a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
+b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
+seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen

 offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
 offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

@@ -150,6 +144,8 @@ def _bmm_chunk_fwd_kernel(
 chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

 acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+# compute a * b.T
 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
 a = tl.load(a_ptrs,
 mask=(offs_m[:, None] < chunk_size_limit) &

@@ -165,18 +161,19 @@ def _bmm_chunk_fwd_kernel(

 offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
 offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-if HAS_SEQ_IDX:
-chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
-seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
-mask=offs_m < chunk_size_limit,
-other=-1)
-seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
-mask=offs_n < chunk_size_limit,
-other=-2)
-acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
-out = acc.to(out_ptr.dtype.element_ty)
-
-out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
+# Zero out the results that are not from the same request
+# in the varlen batch
+seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
+mask=offs_m < chunk_size_limit,
+other=-1)
+seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
+mask=offs_n < chunk_size_limit,
+other=-2)
+acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
+
+out = acc.to(out_ptr.dtype.element_ty)
+out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
 out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
 offs_n[None, :] * stride_outn)
 tl.store(out_ptrs,

@@ -185,78 +182,61 @@ def _bmm_chunk_fwd_kernel(
 (offs_n[None, :] < chunk_size))


-def _bmm_chunk_fwd(a,
-b,
-chunk_size,
-seq_idx=None,
-causal=False,
-output_dtype=None):
+def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
 """
 Argument:
-a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
-b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
-seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
+a: (seqlen, ngroups, k)
+b: (seqlen, ngroups, k)
+seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
 causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
 guaranteed to be correct.
 Return:
-out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
+out: (nchunks, ngroups, chunk_size, chunk_size)
 """
-# Check constraints.
-has_groups = a.dim() == 4
-if not has_groups:
-batch, seqlen, k = a.shape
-else:
-batch, seqlen, ngroups, k = a.shape
+seqlen, ngroups, k = a.shape
 assert b.shape == a.shape
-if seq_idx is not None:
-assert seq_idx.shape == (batch, seqlen)
-if a.stride(-1) != 1 and a.stride(1) != 1:
+assert seq_idx is not None
+assert seq_idx.shape == (seqlen, )
+if a.stride(-1) != 1 and a.stride(0) != 1:
 a = a.contiguous()
-if b.stride(-1) != 1 and b.stride(1) != 1:
+if b.stride(-1) != 1 and b.stride(0) != 1:
 b = b.contiguous()

 nchunks = math.ceil(seqlen / chunk_size)
 # Allocates output.
 out_dtype = a.dtype if output_dtype is None else output_dtype
-out = torch.empty(
-(batch, nchunks, chunk_size, chunk_size) if not has_groups else
-(batch, nchunks, ngroups, chunk_size, chunk_size),
-device=a.device,
-dtype=out_dtype)
+out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
+device=a.device,
+dtype=out_dtype)
 dot_dtype = (tl.bfloat16
 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
 (tl.float16 if a.dtype == torch.float16
 or b.dtype == torch.float16 else tl.float32))
 grid = lambda META: (triton.cdiv(
 chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
-chunk_size, META['BLOCK_SIZE_N']), batch, nchunks
-if not has_groups else nchunks * ngroups)
+chunk_size, META['BLOCK_SIZE_N']), nchunks * ngroups)
 with torch.cuda.device(a.device.index):
 _bmm_chunk_fwd_kernel[grid](
-a,
-b,
-out,
-seq_idx,
-seqlen,
-chunk_size,
-k,
-ngroups if has_groups else 1,
-a.stride(0),
-a.stride(1),
-0 if not has_groups else a.stride(2),
-a.stride(-1),
-b.stride(0),
-b.stride(1),
-0 if not has_groups else b.stride(2),
-b.stride(-1),
-out.stride(0),
-out.stride(1),
-0 if not has_groups else out.stride(2),
-out.stride(-2),
-out.stride(-1),
-*((seq_idx.stride(0),
-seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
-causal,
-dot_dtype,
-HAS_SEQ_IDX=seq_idx is not None,
+a_ptr=a,
+b_ptr=b,
+out_ptr=out,
+seq_idx_ptr=seq_idx,
+seqlen=seqlen,
+chunk_size=chunk_size,
+K=k,
+ngroups=ngroups,
+stride_a_seqlen=a.stride(0),
+stride_a_head=a.stride(1),
+stride_ak=a.stride(2),
+stride_b_seqlen=b.stride(0),
+stride_b_head=b.stride(1),
+stride_bk=b.stride(2),
+stride_out_chunk=out.stride(0),
+stride_out_head=out.stride(1),
+stride_outm=out.stride(-2),
+stride_outn=out.stride(-1),
+stride_seq_idx_seqlen=seq_idx.stride(0),
+IS_CAUSAL=causal,
+dot_dtype=dot_dtype,
 )
 return out
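The masking that the kernel above now performs unconditionally (previously guarded by HAS_SEQ_IDX) zeroes out[i, j] whenever token positions i and j belong to different requests of the varlen batch. The following dense-PyTorch sketch only illustrates that semantics; it is not code from this kernel and the function name is hypothetical.

import torch

def mask_cross_request_products(scores: torch.Tensor,
                                seq_idx: torch.Tensor) -> torch.Tensor:
    # scores: (chunk_size, chunk_size) block of pairwise products for one chunk
    # seq_idx: (chunk_size,) request index of each token position in that chunk
    same_request = seq_idx[:, None] == seq_idx[None, :]
    return torch.where(same_request, scores, torch.zeros_like(scores))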
@@ -6,7 +6,6 @@

 # ruff: noqa: E501,SIM102

-import torch
 from packaging import version

 from vllm.triton_utils import tl, triton

@@ -114,7 +113,6 @@ def _chunk_scan_fwd_kernel(
 x_ptr,
 z_ptr,
 out_ptr,
-out_x_ptr,
 dt_ptr,
 dA_cumsum_ptr,
 seq_idx_ptr,

@@ -126,60 +124,49 @@ def _chunk_scan_fwd_kernel(
 chunk_offsets_ptr,
 chunk_meta_num,
 # Matrix dimensions
-chunk_size,
-hdim,
-dstate,
-batch,
+chunk_size: tl.constexpr,
+hdim: tl.constexpr,
+dstate: tl.constexpr,
 seqlen,
-nheads_ngroups_ratio,
+nheads_ngroups_ratio: tl.constexpr,
 # Strides
-stride_cb_batch,
-stride_cb_chunk,
-stride_cb_head,
-stride_cb_csize_m,
-stride_cb_csize_k,
-stride_x_batch,
-stride_x_seqlen,
-stride_x_head,
-stride_x_hdim,
-stride_z_batch,
-stride_z_seqlen,
-stride_z_head,
-stride_z_hdim,
-stride_out_batch,
-stride_out_seqlen,
-stride_out_head,
-stride_out_hdim,
-stride_dt_batch,
-stride_dt_chunk,
-stride_dt_head,
-stride_dt_csize,
-stride_dA_cs_batch,
-stride_dA_cs_chunk,
-stride_dA_cs_head,
-stride_dA_cs_csize,
-stride_seq_idx_batch,
-stride_seq_idx_seqlen,
-stride_C_batch,
-stride_C_seqlen,
-stride_C_head,
-stride_C_dstate,
-stride_states_batch,
-stride_states_chunk,
-stride_states_head,
-stride_states_hdim,
-stride_states_dstate,
-stride_init_states_batch,
-stride_init_states_head,
-stride_init_states_hdim,
-stride_init_states_dstate,
-stride_D_head,
+stride_cb_chunk: tl.int64,
+stride_cb_head: tl.int64,
+stride_cb_csize_m: tl.int64,
+stride_cb_csize_k: tl.constexpr,
+stride_x_seqlen: tl.int64,
+stride_x_head: tl.int64,
+stride_x_hdim: tl.constexpr,
+stride_z_seqlen: tl.int64,
+stride_z_head: tl.int64,
+stride_z_hdim: tl.constexpr,
+stride_out_seqlen: tl.int64,
+stride_out_head: tl.int64,
+stride_out_hdim: tl.constexpr,
+stride_dt_chunk: tl.int64,
+stride_dt_head: tl.int64,
+stride_dt_csize: tl.constexpr,
+stride_dA_cs_chunk: tl.int64,
+stride_dA_cs_head: tl.int64,
+stride_dA_cs_csize: tl.constexpr,
+stride_seq_idx_seqlen: tl.constexpr,
+stride_C_seqlen: tl.int64,
+stride_C_head: tl.int64,
+stride_C_dstate: tl.constexpr,
+stride_states_chunk: tl.int64,
+stride_states_head: tl.int64,
+stride_states_hdim: tl.int64,
+stride_states_dstate: tl.constexpr,
+stride_init_states_batch: tl.int64,
+stride_init_states_head: tl.int64,
+stride_init_states_hdim: tl.int64,
+stride_init_states_dstate: tl.constexpr,
+stride_D_head: tl.constexpr,
 # Meta-parameters
 IS_CAUSAL: tl.constexpr,
 HAS_D: tl.constexpr,
 D_HAS_HDIM: tl.constexpr,
 HAS_Z: tl.constexpr,
-HAS_SEQ_IDX: tl.constexpr,
 BLOCK_SIZE_M: tl.constexpr,
 BLOCK_SIZE_N: tl.constexpr,
 BLOCK_SIZE_K: tl.constexpr,

@@ -187,9 +174,7 @@ def _chunk_scan_fwd_kernel(
 IS_TRITON_22: tl.constexpr,
 HAS_INITSTATES: tl.constexpr,
 ):
-pid_bc = tl.program_id(axis=1).to(tl.int64)
-pid_c = pid_bc // batch
-pid_b = pid_bc - pid_c * batch
+pid_c = tl.program_id(axis=1).to(tl.int64)
 if not HAS_INITSTATES:
 c_idx = pid_c
 c_off = 0

@@ -201,53 +186,51 @@ def _chunk_scan_fwd_kernel(
 num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
 pid_m = tl.program_id(axis=0) // num_pid_n
 pid_n = tl.program_id(axis=0) % num_pid_n
-cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + (
-pid_h // nheads_ngroups_ratio) * stride_cb_head
-x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
-dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head
-dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
-C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + (
+cb_ptr += c_idx * stride_cb_chunk + (pid_h //
+nheads_ngroups_ratio) * stride_cb_head
+x_ptr += c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
+dt_ptr += c_idx * stride_dt_chunk + pid_h * stride_dt_head
+dA_cumsum_ptr += c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
+C_ptr += c_idx * chunk_size * stride_C_seqlen + (
 pid_h // nheads_ngroups_ratio) * stride_C_head

 # M-block offsets and prev states
 # - logic in next block may override these if there is an active offset
 offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
-prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head
+prev_states_ptr = states_ptr + c_idx * stride_states_chunk + pid_h * stride_states_head
 prev_states_hdim = stride_states_hdim
 prev_states_dstate = stride_states_dstate

 chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
-if HAS_SEQ_IDX:
-seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen
-# - we only need seq_idx_prev to be aligned to chunk boundary
-seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
-mask=c_idx >= 1,
-other=0)
+seq_idx_ptr += c_idx * chunk_size * stride_seq_idx_seqlen
+# - we only need seq_idx_prev to be aligned to chunk boundary
+seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
+mask=c_idx >= 1,
+other=0)

 if HAS_INITSTATES:
 # if there are init states, we only need seq_idx_m to point
 # what is the current seq_idx

 # get current seq idx
 if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
 seq_idx_m = tl.load(
 seq_idx_ptr +
 (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, )

 # - recall that in ssd_state_passing, for the case c_off == 0
 # i.e., the very first sequence, we made states_ptr hold its initial state
 # so this edge case is taken care of
-if ((c_off == 0) and
-(seq_idx_prev != seq_idx_m
-)  # if a seq is changed exactly on boundary
-or (c_off > 0)  # implies a new example (pseudo chunk)
-):
+if ((c_off == 0) and (seq_idx_prev != seq_idx_m
+)  # if a seq is changed exactly on boundary
+or (c_off > 0)  # implies a new example (pseudo chunk)
+):

 # - replace prev_states_ptr with init_states
 prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
 prev_states_hdim = stride_init_states_hdim  # override strides
 prev_states_dstate = stride_init_states_dstate

 offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
 dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,

@@ -256,7 +239,6 @@ def _chunk_scan_fwd_kernel(

 # - handle chunk state limit
 if HAS_INITSTATES:
-
 # have to split this if otherwise compilation will have problems
 dA_cs_m_boundary = 0.0

@@ -296,13 +278,11 @@ def _chunk_scan_fwd_kernel(
 dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
 mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
 other=0.0).to(tl.float32)
-if HAS_SEQ_IDX:
+else:
 # - handle seq idx when HAS_INITSTATES==False
-if not HAS_INITSTATES:
-seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
-mask=offs_m < chunk_size_limit,
-other=-1)
+seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
+mask=offs_m < chunk_size_limit,
+other=-1)

 acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

@@ -319,18 +299,15 @@ def _chunk_scan_fwd_kernel(
 prev_states_ptrs = prev_states_ptr + (
 offs_n[None, :] * prev_states_hdim +
 offs_k_dstate[:, None] * prev_states_dstate)
-if HAS_SEQ_IDX:

 if not HAS_INITSTATES:
 # - this is for continuous batching where there is no init states
-scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m),
-0.0)
-else:
-# - if there is initstates, we will rely on prev_states, no zeroing
-# required.
-scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
+scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
 else:
-scale_m = tl.exp(dA_cs_m)
+# - if there is initstates, we will rely on prev_states, no zeroing
+# required.
+scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)

 if BLOCK_SIZE_DSTATE <= 128:
 C = tl.load(C_ptrs,
 mask=(offs_m[:, None] < chunk_size_limit) &

@@ -416,15 +393,7 @@ def _chunk_scan_fwd_kernel(
 acc += x_residual * D

 if HAS_Z:
-out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
-out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] +
-offs_out_n[None, :])
-tl.store(out_x_ptrs,
-acc,
-mask=(offs_out_m[:, None] < chunk_size_limit) &
-(offs_out_n[None, :] < hdim))
-
-z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
+z_ptr += c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
 z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] +
 stride_z_hdim * offs_out_n[None, :])
 z = tl.load(z_ptrs,

@@ -433,7 +402,7 @@ def _chunk_scan_fwd_kernel(
 other=0.0).to(tl.float32)
 acc *= z * tl.sigmoid(z)

-out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
+out_ptr += c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
 out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] +
 offs_out_n[None, :] * stride_out_hdim)
 tl.store(out_ptrs,

@@ -449,126 +418,110 @@ def _chunk_scan_fwd(
 dA_cumsum,
 C,
 states,
+out,
+seq_idx,
 D=None,
 z=None,
-seq_idx=None,
 chunk_indices=None,
 chunk_offsets=None,
 initial_states=None,
-out=None,
 ):
-batch, seqlen, nheads, headdim = x.shape
-_, _, nchunks, chunk_size = dt.shape
-_, _, ngroups, dstate = C.shape
+assert seq_idx is not None, "this implementation requires seq_idx"
+seqlen, nheads, headdim = x.shape
+_, nchunks, chunk_size = dt.shape
+_, ngroups, dstate = C.shape
 assert nheads % ngroups == 0
-assert C.shape == (batch, seqlen, ngroups, dstate)
-assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
-if z is not None:
-assert z.shape == x.shape
+assert C.shape == (seqlen, ngroups, dstate)
+assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
 if D is not None:
 assert D.shape == (nheads, headdim) or D.shape == (nheads, )
-assert dt.shape == (batch, nheads, nchunks, chunk_size)
-assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
-assert states.shape == (batch, nchunks, nheads, headdim, dstate)
+if z is not None:
+assert z.shape == x.shape
+assert dt.shape == (nheads, nchunks, chunk_size)
+assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
+assert states.shape == (nchunks, nheads, headdim, dstate)
+assert seq_idx.shape == (seqlen, )

-if seq_idx is not None:
-assert seq_idx.shape == (batch, seqlen)
-if initial_states is not None:
-# with initial states, we need to take care of how
-# seq_idx crosses the boundaries
-assert batch == 1, "chunk scan only supports initial states with batch 1"
-assert chunk_indices is not None and chunk_offsets is not None, \
-"chunk_indices and chunk_offsets should have been set"
-else:
-chunk_indices, chunk_offsets = None, None
+if initial_states is not None:
+# with initial states, we need to take care of how
+# seq_idx crosses the boundaries
+assert chunk_indices is not None and chunk_offsets is not None, \
+"chunk_indices and chunk_offsets should have been set"
 else:
 chunk_indices, chunk_offsets = None, None

-assert out.shape == x.shape
-
-if z is not None:
-out_x = torch.empty_like(x)
-assert out_x.stride() == out.stride()
-else:
-out_x = None
-
 grid = lambda META: (
 triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
-headdim, META['BLOCK_SIZE_N']), batch * nchunks
+headdim, META['BLOCK_SIZE_N']), nchunks
 if chunk_offsets is None else len(chunk_offsets), nheads)
-z_strides = ((z.stride(0), z.stride(1), z.stride(2),
-z.stride(3)) if z is not None else (0, 0, 0, 0))
+z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
+(0, 0, 0))
+initial_states_strides = ((initial_states.stride(0),
+initial_states.stride(1),
+initial_states.stride(2),
+initial_states.stride(3))
+if initial_states is not None else (0, 0, 0, 0))

 _chunk_scan_fwd_kernel[grid](
-cb,
-x,
-z,
-out,
-out_x,
-dt,
-dA_cumsum,
-seq_idx,
-C,
-states,
-D,
-initial_states,
-chunk_indices,
-chunk_offsets,
-len(chunk_indices) if chunk_indices is not None else 0,
-chunk_size,
-headdim,
-dstate,
-batch,
-seqlen,
-nheads // ngroups,
-cb.stride(0),
-cb.stride(1),
-cb.stride(2),
-cb.stride(3),
-cb.stride(4),
-x.stride(0),
-x.stride(1),
-x.stride(2),
-x.stride(3),
-z_strides[0],
-z_strides[1],
-z_strides[2],
-z_strides[3],
-out.stride(0),
-out.stride(1),
-out.stride(2),
-out.stride(3),
-dt.stride(0),
-dt.stride(2),
-dt.stride(1),
-dt.stride(3),
-dA_cumsum.stride(0),
-dA_cumsum.stride(2),
-dA_cumsum.stride(1),
-dA_cumsum.stride(3),
-*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else
-(0, 0)),
-C.stride(0),
-C.stride(1),
-C.stride(2),
-C.stride(3),
-states.stride(0),
-states.stride(1),
-states.stride(2),
-states.stride(3),
-states.stride(4),
-*((initial_states.stride(0), initial_states.stride(1),
-initial_states.stride(2),
-initial_states.stride(3)) if initial_states is not None else
-(0, 0, 0, 0)),
-D.stride(0) if D is not None else 0,
-True,
-D is not None,
-D.dim() == 2 if D is not None else True,
-BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
+cb_ptr=cb,
+x_ptr=x,
+z_ptr=z,
+out_ptr=out,
+dt_ptr=dt,
+dA_cumsum_ptr=dA_cumsum,
+seq_idx_ptr=seq_idx,
+C_ptr=C,
+states_ptr=states,
+D_ptr=D,
+initstates_ptr=initial_states,
+chunk_indices_ptr=chunk_indices,
+chunk_offsets_ptr=chunk_offsets,
+chunk_meta_num=len(chunk_indices) if chunk_indices is not None else 0,
+chunk_size=chunk_size,
+hdim=headdim,
+dstate=dstate,
+seqlen=seqlen,
+nheads_ngroups_ratio=nheads // ngroups,
+stride_cb_chunk=cb.stride(0),
+stride_cb_head=cb.stride(1),
+stride_cb_csize_m=cb.stride(2),
+stride_cb_csize_k=cb.stride(3),
+stride_x_seqlen=x.stride(0),
+stride_x_head=x.stride(1),
+stride_x_hdim=x.stride(2),
+stride_z_seqlen=z_strides[0],
+stride_z_head=z_strides[1],
+stride_z_hdim=z_strides[2],
+stride_out_seqlen=out.stride(0),
+stride_out_head=out.stride(1),
+stride_out_hdim=out.stride(2),
+stride_dt_chunk=dt.stride(1),
+stride_dt_head=dt.stride(0),
+stride_dt_csize=dt.stride(2),
+stride_dA_cs_chunk=dA_cumsum.stride(1),
+stride_dA_cs_head=dA_cumsum.stride(0),
+stride_dA_cs_csize=dA_cumsum.stride(2),
+stride_seq_idx_seqlen=seq_idx.stride(0),
+stride_C_seqlen=C.stride(0),
+stride_C_head=C.stride(1),
+stride_C_dstate=C.stride(2),
+stride_states_chunk=states.stride(0),
+stride_states_head=states.stride(1),
+stride_states_hdim=states.stride(2),
+stride_states_dstate=states.stride(3),
+stride_init_states_batch=initial_states_strides[0],
+stride_init_states_head=initial_states_strides[1],
+stride_init_states_hdim=initial_states_strides[2],
+stride_init_states_dstate=initial_states_strides[3],
+stride_D_head=D.stride(0) if D is not None else 0,
+IS_CAUSAL=True,
+HAS_D=D is not None,
+D_HAS_HDIM=D.dim() == 2 if D is not None else True,
 HAS_Z=z is not None,
-HAS_SEQ_IDX=seq_idx is not None,
+BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
 IS_TRITON_22=TRITON_22,
 HAS_INITSTATES=initial_states is not None,
 )
-return out_x
+return
|
|||||||
@ -35,41 +35,35 @@ def _chunk_cumsum_fwd_kernel(
|
|||||||
dt_out_ptr,
|
dt_out_ptr,
|
||||||
dA_cumsum_ptr,
|
dA_cumsum_ptr,
|
||||||
# Matrix dimension
|
# Matrix dimension
|
||||||
batch,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
nheads,
|
nheads: tl.constexpr,
|
||||||
chunk_size,
|
chunk_size: tl.constexpr,
|
||||||
dt_min,
|
dt_min: tl.constexpr,
|
||||||
dt_max,
|
dt_max: tl.constexpr,
|
||||||
# Strides
|
# Strides
|
||||||
stride_dt_batch,
|
stride_dt_seqlen: tl.int64,
|
||||||
stride_dt_seqlen,
|
stride_dt_head: tl.constexpr,
|
||||||
stride_dt_head,
|
stride_A_head: tl.constexpr,
|
||||||
stride_A_head,
|
stride_dt_bias_head: tl.constexpr,
|
||||||
stride_dt_bias_head,
|
stride_dt_out_head: tl.int64,
|
||||||
stride_dt_out_batch,
|
stride_dt_out_chunk: tl.int64,
|
||||||
stride_dt_out_chunk,
|
stride_dt_out_csize: tl.constexpr,
|
||||||
stride_dt_out_head,
|
stride_dA_cs_head: tl.int64,
|
||||||
stride_dt_out_csize,
|
stride_dA_cs_chunk: tl.int64,
|
||||||
stride_dA_cs_batch,
|
stride_dA_cs_csize: tl.constexpr,
|
||||||
stride_dA_cs_chunk,
|
|
||||||
stride_dA_cs_head,
|
|
||||||
stride_dA_cs_csize,
|
|
||||||
# Meta-parameters
|
# Meta-parameters
|
||||||
DT_SOFTPLUS: tl.constexpr,
|
DT_SOFTPLUS: tl.constexpr,
|
||||||
HAS_DT_BIAS: tl.constexpr,
|
HAS_DT_BIAS: tl.constexpr,
|
||||||
BLOCK_SIZE_H: tl.constexpr,
|
BLOCK_SIZE_H: tl.constexpr,
|
||||||
BLOCK_SIZE_CHUNK: tl.constexpr,
|
BLOCK_SIZE_CHUNK: tl.constexpr,
|
||||||
):
|
):
|
||||||
pid_b = tl.program_id(axis=0)
|
|
||||||
|
|
||||||
# if dt is long, may cause problems, so use 64 bit
|
# if dt is long, may cause problems, so use 64 bit
|
||||||
# https://github.com/triton-lang/triton/issues/1058
|
# https://github.com/triton-lang/triton/issues/1058
|
||||||
pid_c = tl.program_id(axis=1).to(tl.int64)
|
pid_c = tl.program_id(axis=0).to(tl.int64)
|
||||||
pid_h = tl.program_id(axis=2)
|
pid_h = tl.program_id(axis=1)
|
||||||
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
dt_ptr += pid_c * chunk_size * stride_dt_seqlen
|
||||||
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
|
dt_out_ptr += pid_c * stride_dt_out_chunk
|
||||||
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
|
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk
|
||||||
|
|
||||||
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
||||||
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
||||||
@@ -93,9 +87,8 @@ def _chunk_cumsum_fwd_kernel(
 dt += dt_bias[:, None]
 if DT_SOFTPLUS:
 dt = tl.where(dt <= 20.0, softplus(dt), dt)
-# As of Triton 2.2.0, tl.clamp is not available yet
-# dt = tl.clamp(dt, dt_min, dt_max)
-dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
+dt = tl.clamp(dt, dt_min, dt_max)
 dt = tl.where(
 (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt,
 0.0)
@@ -197,56 +190,46 @@ def _chunk_state_fwd_kernel(
 dA_cumsum_ptr,
 seq_idx_ptr,
 # Matrix dimensions
-hdim,
-dstate,
-chunk_size,
+hdim: tl.constexpr,
+dstate: tl.constexpr,
+chunk_size: tl.constexpr,
-batch,
 seqlen,
-nheads_ngroups_ratio,
+nheads_ngroups_ratio: tl.constexpr,
 # Strides
-stride_x_batch,
-stride_x_seqlen,
-stride_x_head,
-stride_x_hdim,
-stride_b_batch,
-stride_b_seqlen,
-stride_b_head,
-stride_b_dstate,
-stride_states_batch,
-stride_states_chunk,
-stride_states_head,
-stride_states_hdim,
-stride_states_dstate,
-stride_dt_batch,
-stride_dt_chunk,
-stride_dt_head,
-stride_dt_csize,
-stride_dA_cs_batch,
-stride_dA_cs_chunk,
-stride_dA_cs_head,
-stride_dA_cs_csize,
-stride_seq_idx_batch,
-stride_seq_idx_seqlen,
+stride_x_seqlen: tl.int64,
+stride_x_head: tl.int64,
+stride_x_hdim: tl.constexpr,
+stride_b_seqlen: tl.int64,
+stride_b_head: tl.int64,
+stride_b_dstate: tl.constexpr,
+stride_states_chunk: tl.int64,
+stride_states_head: tl.int64,
+stride_states_hdim: tl.int64,
+stride_states_dstate: tl.constexpr,
+stride_dt_head: tl.int64,
+stride_dt_chunk: tl.int64,
+stride_dt_csize: tl.constexpr,
+stride_dA_cs_head: tl.int64,
+stride_dA_cs_chunk: tl.int64,
+stride_dA_cs_csize: tl.constexpr,
+stride_seq_idx_seqlen: tl.constexpr,
 # Meta-parameters
-HAS_SEQ_IDX: tl.constexpr,
 BLOCK_SIZE_M: tl.constexpr,
 BLOCK_SIZE_N: tl.constexpr,
 BLOCK_SIZE_K: tl.constexpr,
 ):
-pid_bc = tl.program_id(axis=1).to(tl.int64)
-pid_c = pid_bc // batch
-pid_b = pid_bc - pid_c * batch
+pid_c = tl.program_id(axis=1).to(tl.int64)
 pid_h = tl.program_id(axis=2)
 num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
 pid_m = tl.program_id(axis=0) // num_pid_n
 pid_n = tl.program_id(axis=0) % num_pid_n
-b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (
+b_ptr += pid_c * chunk_size * stride_b_seqlen + (
 pid_h // nheads_ngroups_ratio) * stride_b_head
-x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
-dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
-dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
-if HAS_SEQ_IDX:
-seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
+x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
+dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
+dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
+seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen

 offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
 offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -259,13 +242,11 @@ def _chunk_state_fwd_kernel(
 dA_cs_last = tl.load(dA_cumsum_ptr +
 (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
 dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
-if HAS_SEQ_IDX:
-seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
+seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
 chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
-if HAS_SEQ_IDX:
-seq_idx_last = tl.load(seq_idx_ptr +
-(chunk_size_limit - 1) * stride_seq_idx_seqlen)
+seq_idx_last = tl.load(seq_idx_ptr +
+(chunk_size_limit - 1) * stride_seq_idx_seqlen)

 acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
@@ -280,29 +261,28 @@ def _chunk_state_fwd_kernel(
 dA_cs_k = tl.load(dA_cumsum_ptrs,
 mask=offs_k < chunk_size_limit - k,
 other=0.0).to(tl.float32)
-if HAS_SEQ_IDX:
 seq_idx_k = tl.load(seq_idx_ptrs,
 mask=offs_k < chunk_size_limit - k,
 other=-1)
 dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
 other=0.0).to(tl.float32)
-if not HAS_SEQ_IDX:
-scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
-else:
-scale = tl.where(seq_idx_k == seq_idx_last,
-tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
+scale = tl.where(seq_idx_k == seq_idx_last,
+tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
 b *= scale[:, None]
 b = b.to(x_ptr.dtype.element_ty)
 acc += tl.dot(x, b)

 x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
 b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
 dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
 dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
-if HAS_SEQ_IDX:
-seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
+seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen

 states = acc.to(states_ptr.dtype.element_ty)

-states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
+states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
 offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
 offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
 states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
@@ -400,36 +380,35 @@ def _chunk_state_varlen_kernel(
 states_ptr,
 initstates_ptr,
 # Matrix dimensions
-hdim,
-dstate,
-chunk_size,
-seqlen,
-nheads_ngroups_ratio,
+hdim: tl.constexpr,
+dstate: tl.constexpr,
+chunk_size: tl.constexpr,
+nheads_ngroups_ratio: tl.constexpr,
 # Strides
-stride_x_seqlen,
-stride_x_head,
-stride_x_hdim,
-stride_b_seqlen,
-stride_b_head,
-stride_b_dstate,
-stride_dt_chunk,
-stride_dt_head,
-stride_dt_csize,
-stride_dA_cs_chunk,
-stride_dA_cs_head,
-stride_dA_cs_csize,
-stride_chunk_states_chunk,
-stride_chunk_states_head,
-stride_chunk_states_hdim,
-stride_chunk_states_dstate,
-stride_states_batch,
-stride_states_head,
-stride_states_hdim,
-stride_states_dstate,
-stride_init_states_batch,
-stride_init_states_head,
-stride_init_states_hdim,
-stride_init_states_dstate,
+stride_x_seqlen: tl.int64,
+stride_x_head: tl.int64,
+stride_x_hdim: tl.constexpr,
+stride_b_seqlen: tl.int64,
+stride_b_head: tl.int64,
+stride_b_dstate: tl.constexpr,
+stride_dt_head: tl.int64,
+stride_dt_chunk: tl.int64,
+stride_dt_csize: tl.constexpr,
+stride_dA_cs_head: tl.int64,
+stride_dA_cs_chunk: tl.int64,
+stride_dA_cs_csize: tl.constexpr,
+stride_chunk_states_chunk: tl.int64,
+stride_chunk_states_head: tl.int64,
+stride_chunk_states_hdim: tl.int64,
+stride_chunk_states_dstate: tl.constexpr,
+stride_states_batch: tl.int64,
+stride_states_head: tl.int64,
+stride_states_hdim: tl.int64,
+stride_states_dstate: tl.constexpr,
+stride_init_states_batch: tl.int64,
+stride_init_states_head: tl.int64,
+stride_init_states_hdim: tl.int64,
+stride_init_states_dstate: tl.constexpr,
 # Meta-parameters
 BLOCK_SIZE_M: tl.constexpr,
 BLOCK_SIZE_N: tl.constexpr,
@@ -558,52 +537,47 @@ def _chunk_cumsum_fwd(dt,
 dt_bias=None,
 dt_softplus=False,
 dt_limit=(0.0, float("inf"))):
-batch, seqlen, nheads = dt.shape
+seqlen, nheads = dt.shape
 assert A.shape == (nheads, )
 if dt_bias is not None:
 assert dt_bias.shape == (nheads, )
 nchunks = math.ceil(seqlen / chunk_size)
-dt_out = torch.empty(batch,
-nheads,
+dt_out = torch.empty(nheads,
 nchunks,
 chunk_size,
 device=dt.device,
 dtype=torch.float32)
-dA_cumsum = torch.empty(batch,
-nheads,
+dA_cumsum = torch.empty(nheads,
 nchunks,
 chunk_size,
 device=dt.device,
 dtype=torch.float32)
-grid_chunk_cs = lambda META: (batch, nchunks,
+grid_chunk_cs = lambda META: (nchunks,
 triton.cdiv(nheads, META['BLOCK_SIZE_H']))
 with torch.cuda.device(dt.device.index):
 _chunk_cumsum_fwd_kernel[grid_chunk_cs](
-dt,
-A,
-dt_bias,
-dt_out,
-dA_cumsum,
-batch,
-seqlen,
-nheads,
-chunk_size,
-dt_limit[0],
-dt_limit[1],
-dt.stride(0),
-dt.stride(1),
-dt.stride(2),
-A.stride(0),
-dt_bias.stride(0) if dt_bias is not None else 0,
-dt_out.stride(0),
-dt_out.stride(2),
-dt_out.stride(1),
-dt_out.stride(3),
-dA_cumsum.stride(0),
-dA_cumsum.stride(2),
-dA_cumsum.stride(1),
-dA_cumsum.stride(3),
-dt_softplus,
+dt_ptr=dt,
+A_ptr=A,
+dt_bias_ptr=dt_bias,
+dt_out_ptr=dt_out,
+dA_cumsum_ptr=dA_cumsum,
+seqlen=seqlen,
+nheads=nheads,
+chunk_size=chunk_size,
+dt_min=dt_limit[0],
+dt_max=dt_limit[1],
+stride_dt_seqlen=dt.stride(0),
+stride_dt_head=dt.stride(1),
+stride_A_head=A.stride(0),
+stride_dt_bias_head=dt_bias.stride(0)
+if dt_bias is not None else 0,
+stride_dt_out_head=dt_out.stride(0),
+stride_dt_out_chunk=dt_out.stride(1),
+stride_dt_out_csize=dt_out.stride(2),
+stride_dA_cs_head=dA_cumsum.stride(0),
+stride_dA_cs_chunk=dA_cumsum.stride(1),
+stride_dA_cs_csize=dA_cumsum.stride(2),
+DT_SOFTPLUS=dt_softplus,
 HAS_DT_BIAS=dt_bias is not None,
 BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
 )
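With the batch dimension gone, _chunk_cumsum_fwd takes a flat dt of shape (seqlen, nheads) and allocates its per-chunk buffers as (nheads, nchunks, chunk_size), one Triton program per (chunk, head block). A minimal shape sketch in plain PyTorch, with illustrative sizes only (this is not the Triton path itself):

import math
import torch

seqlen, nheads, chunk_size = 1000, 8, 256   # illustrative sizes
dt = torch.randn(seqlen, nheads)            # varlen layout: tokens of all sequences concatenated

nchunks = math.ceil(seqlen / chunk_size)
dt_out = torch.empty(nheads, nchunks, chunk_size, dtype=torch.float32)
dA_cumsum = torch.empty(nheads, nchunks, chunk_size, dtype=torch.float32)
# the kernel above fills these buffers; positions past seqlen in the last chunk are masked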
@@ -617,63 +591,57 @@ def _chunk_state_fwd(B,
 seq_idx=None,
 states=None,
 states_in_fp32=True):
-batch, seqlen, nheads, headdim = x.shape
-_, _, nchunks, chunk_size = dt.shape
-_, _, ngroups, dstate = B.shape
+seqlen, nheads, headdim = x.shape
+_, nchunks, chunk_size = dt.shape
+_, ngroups, dstate = B.shape
 assert nheads % ngroups == 0
-assert B.shape == (batch, seqlen, ngroups, dstate)
-assert dt.shape == (batch, nheads, nchunks, chunk_size)
+assert B.shape == (seqlen, ngroups, dstate)
+assert dt.shape == (nheads, nchunks, chunk_size)
 assert dA_cumsum.shape == dt.shape
-if seq_idx is not None:
-assert seq_idx.shape == (batch, seqlen)
+assert seq_idx is not None
+assert seq_idx.shape == (seqlen, )

 if states is not None:
-assert states.shape == (batch, nchunks, nheads, headdim, dstate)
+assert states.shape == (nchunks, nheads, headdim, dstate)
 else:
 states_dtype = torch.float32 if states_in_fp32 else B.dtype
-states = torch.empty((batch, nchunks, nheads, headdim, dstate),
+states = torch.empty((nchunks, nheads, headdim, dstate),
 device=x.device,
 dtype=states_dtype)
-grid = lambda META: (
-triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(
-dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads)
+grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
+cdiv(dstate, META['BLOCK_SIZE_N']), nchunks, nheads)
 with torch.cuda.device(x.device.index):
 _chunk_state_fwd_kernel[grid](
-x,
-B,
-states,
-dt,
-dA_cumsum,
-seq_idx,
-headdim,
-dstate,
-chunk_size,
-batch,
-seqlen,
-nheads // ngroups,
-x.stride(0),
-x.stride(1),
-x.stride(2),
-x.stride(3),
-B.stride(0),
-B.stride(1),
-B.stride(2),
-B.stride(-1),
-states.stride(0),
-states.stride(1),
-states.stride(2),
-states.stride(3),
-states.stride(4),
-dt.stride(0),
-dt.stride(2),
-dt.stride(1),
-dt.stride(3),
-dA_cumsum.stride(0),
-dA_cumsum.stride(2),
-dA_cumsum.stride(1),
-dA_cumsum.stride(3),
-*((seq_idx.stride(0),
-seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
-HAS_SEQ_IDX=seq_idx is not None,
+x_ptr=x,
+b_ptr=B,
+states_ptr=states,
+dt_ptr=dt,
+dA_cumsum_ptr=dA_cumsum,
+seq_idx_ptr=seq_idx,
+hdim=headdim,
+dstate=dstate,
+chunk_size=chunk_size,
+seqlen=seqlen,
+nheads_ngroups_ratio=nheads // ngroups,
+stride_x_seqlen=x.stride(0),
+stride_x_head=x.stride(1),
+stride_x_hdim=x.stride(2),
+stride_b_seqlen=B.stride(0),
+stride_b_head=B.stride(1),
+stride_b_dstate=B.stride(2),
+stride_states_chunk=states.stride(0),
+stride_states_head=states.stride(1),
+stride_states_hdim=states.stride(2),
+stride_states_dstate=states.stride(3),
+stride_dt_head=dt.stride(0),
+stride_dt_chunk=dt.stride(1),
+stride_dt_csize=dt.stride(2),
+stride_dA_cs_head=dA_cumsum.stride(0),
+stride_dA_cs_chunk=dA_cumsum.stride(1),
+stride_dA_cs_csize=dA_cumsum.stride(2),
+stride_seq_idx_seqlen=seq_idx.stride(0),
 )
 return states
@@ -705,46 +673,52 @@ def chunk_state_varlen(B,
 dstate,
 dtype=chunk_states.dtype,
 device=chunk_states.device)

+initial_states_strides = ((initial_states.stride(0),
+initial_states.stride(1),
+initial_states.stride(2),
+initial_states.stride(3))
+if initial_states is not None else (0, 0, 0, 0))

 grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
 cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads)
 with torch.cuda.device(x.device.index):
 _chunk_state_varlen_kernel[grid](
-x,
-B,
-dt,
-dA_cumsum,
-chunk_states,
-cu_seqlens,
-states,
-initial_states,
-headdim,
-dstate,
-chunk_size,
-total_seqlen,
-nheads // ngroups,
-x.stride(0),
-x.stride(1),
-x.stride(2),
-B.stride(0),
-B.stride(1),
-B.stride(2),
-dt.stride(1),
-dt.stride(0),
-dt.stride(2),
-dA_cumsum.stride(1),
-dA_cumsum.stride(0),
-dA_cumsum.stride(2),
-chunk_states.stride(0),
-chunk_states.stride(1),
-chunk_states.stride(2),
-chunk_states.stride(3),
-states.stride(0),
-states.stride(1),
-states.stride(2),
-states.stride(3),
-*((initial_states.stride(0), initial_states.stride(1),
-initial_states.stride(2),
-initial_states.stride(3)) if initial_states is not None else
-(0, 0, 0, 0)),
+x_ptr=x,
+b_ptr=B,
+dt_ptr=dt,
+dA_cumsum_ptr=dA_cumsum,
+chunk_states_ptr=chunk_states,
+cu_seqlens_ptr=cu_seqlens,
+states_ptr=states,
+initstates_ptr=initial_states,
+hdim=headdim,
+dstate=dstate,
+chunk_size=chunk_size,
+nheads_ngroups_ratio=nheads // ngroups,
+stride_x_seqlen=x.stride(0),
+stride_x_head=x.stride(1),
+stride_x_hdim=x.stride(2),
+stride_b_seqlen=B.stride(0),
+stride_b_head=B.stride(1),
+stride_b_dstate=B.stride(2),
+stride_dt_head=dt.stride(0),
+stride_dt_chunk=dt.stride(1),
+stride_dt_csize=dt.stride(2),
+stride_dA_cs_head=dA_cumsum.stride(0),
+stride_dA_cs_chunk=dA_cumsum.stride(1),
+stride_dA_cs_csize=dA_cumsum.stride(2),
+stride_chunk_states_chunk=chunk_states.stride(0),
+stride_chunk_states_head=chunk_states.stride(1),
+stride_chunk_states_hdim=chunk_states.stride(2),
+stride_chunk_states_dstate=chunk_states.stride(3),
+stride_states_batch=states.stride(0),
+stride_states_head=states.stride(1),
+stride_states_hdim=states.stride(2),
+stride_states_dstate=states.stride(3),
+stride_init_states_batch=initial_states_strides[0],
+stride_init_states_head=initial_states_strides[1],
+stride_init_states_hdim=initial_states_strides[2],
+stride_init_states_dstate=initial_states_strides[3],
 HAS_INITSTATES=initial_states is not None)
 return states
@@ -31,6 +31,7 @@ def _mamba_chunk_scan_combined_fwd(x,
 B,
 C,
 chunk_size,
+out,
 D=None,
 z=None,
 dt_bias=None,
@@ -41,14 +42,13 @@ def _mamba_chunk_scan_combined_fwd(x,
 cu_seqlens=None,
 dt_softplus=False,
 dt_limit=(0.0, float("inf")),
-state_dtype=None,
-out=None):
+state_dtype=None):
 assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
-batch, seqlen, nheads, headdim = x.shape
-_, _, ngroups, dstate = B.shape
+seqlen, nheads, headdim = x.shape
+_, ngroups, dstate = B.shape
 assert nheads % ngroups == 0
-assert B.shape == (batch, seqlen, ngroups, dstate)
-assert dt.shape == (batch, seqlen, nheads)
+assert B.shape == (seqlen, ngroups, dstate)
+assert dt.shape == (seqlen, nheads)
 assert A.shape == (nheads, )
 assert C.shape == B.shape
 if z is not None:
@@ -56,25 +56,24 @@ def _mamba_chunk_scan_combined_fwd(x,
 if D is not None:
 assert D.shape == (nheads, headdim) or D.shape == (nheads, )
 if seq_idx is not None:
-assert seq_idx.shape == (batch, seqlen)
+assert seq_idx.shape == (seqlen, )
 if B.stride(-1) != 1:
 B = B.contiguous()
 if C.stride(-1) != 1:
 C = C.contiguous()
 if x.stride(-1) != 1 and x.stride(
-1) != 1: # Either M or K dimension should be contiguous
+0) != 1: # Either M or K dimension should be contiguous
 x = x.contiguous()
 if z is not None and z.stride(-1) != 1 and z.stride(
-1) != 1: # Either M or K dimension should be contiguous
+0) != 1: # Either M or K dimension should be contiguous
 z = z.contiguous()
 if D is not None and D.stride(-1) != 1:
 D = D.contiguous()
+assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"

 if initial_states is not None:
-if cu_seqlens is None:
-assert initial_states.shape == (batch, nheads, headdim, dstate)
-else:
-assert initial_states.shape == (len(cu_seqlens) - 1, nheads,
-headdim, dstate)
+assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim,
+dstate)

 # This function executes 5 sub-functions for computing mamba
 # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
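The asserts above pin down the varlen contract: cu_seqlens is mandatory, seq_idx is a per-token vector of shape (seqlen,), and initial_states, when given, carries one state per packed sequence, i.e. (len(cu_seqlens) - 1, nheads, headdim, dstate). A small sketch of how those shapes line up, assuming two packed sequences (illustrative values, plain PyTorch):

import torch

cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)   # two sequences of lengths 3 and 5
num_seqs = len(cu_seqlens) - 1
seqlen = int(cu_seqlens[-1])

seq_idx = torch.repeat_interleave(
    torch.arange(num_seqs, dtype=torch.int32), cu_seqlens.diff())
assert seq_idx.shape == (seqlen, )                        # matches the assert above

nheads, headdim, dstate = 4, 16, 8                        # illustrative
initial_states = torch.zeros(num_seqs, nheads, headdim, dstate)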
@@ -114,18 +113,16 @@ def _mamba_chunk_scan_combined_fwd(x,
 # - this will ensure that states will be updated with the rightmost flushed seq_idx
 # of the previous chunk. This implies that the first chunk of states is either 0
 # or equal to init_states of the first example.
-states, final_states = _state_passing_fwd(
+states = _state_passing_fwd(
 rearrange(states, "... p n -> ... (p n)"),
-dA_cumsum,
+dA_cumsum, # (nheads, nchunks, chunk_size)
 initial_states=rearrange(initial_states, "... p n -> ... (p n)")
-if initial_states is not None else None,
+if initial_states is not None else
+None, # (batch, nheads, headdim*dstate)
 seq_idx=seq_idx,
-chunk_size=chunk_size,
 out_dtype=state_dtype if state_dtype is not None else C.dtype,
-is_cont_batched=cu_seqlens is not None,
 chunk_offsets=chunk_offsets)
-states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
-for t in [states, final_states])
+states = rearrange(states, "... (p n) -> ... p n", n=dstate)

 # 4. Compute batched matrix multiply for C_j^T B_i terms
 CB = _bmm_chunk_fwd(C,
@@ -144,87 +141,88 @@ def _mamba_chunk_scan_combined_fwd(x,
 # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
 # a seq_idx change, in which case we take states information from
 # init_states.
-out_x = _chunk_scan_fwd(
+_chunk_scan_fwd(
 CB,
 x,
 dt,
 dA_cumsum,
 C,
 states,
+out, # in-place update
+seq_idx,
 D=D,
 z=z,
-seq_idx=seq_idx,
 chunk_indices=chunk_indices,
 chunk_offsets=chunk_offsets,
 initial_states=initial_states,
-out=out,
 )
-if cu_seqlens is None:
-return out_x, dt, dA_cumsum, states, final_states
-else:
-assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
-varlen_states = chunk_state_varlen(
-B.squeeze(0),
-x.squeeze(0),
-dt.squeeze(0),
-dA_cumsum.squeeze(0),
-cu_seqlens,
-states.squeeze(0),
-initial_states=initial_states,
-)
-return out_x, dt, dA_cumsum, states, final_states, varlen_states
+varlen_states = chunk_state_varlen(
+B,
+x,
+dt,
+dA_cumsum,
+cu_seqlens,
+states,
+initial_states=initial_states,
+)
+return varlen_states


-def mamba_chunk_scan_combined(x,
+def mamba_chunk_scan_combined_varlen(
-dt,
-A,
-B,
-C,
-chunk_size,
-D=None,
-z=None,
-dt_bias=None,
-initial_states=None,
-seq_idx=None,
-chunk_indices=None,
-chunk_offsets=None,
-cu_seqlens=None,
-dt_softplus=False,
-dt_limit=(0.0, float("inf")),
-out=None,
-return_final_states=False,
-return_varlen_states=False,
-state_dtype=None):
-"""
-Argument:
-x: (batch, seqlen, nheads, headdim)
-dt: (batch, seqlen, nheads)
-A: (nheads)
-B: (batch, seqlen, ngroups, dstate)
-C: (batch, seqlen, ngroups, dstate)
-chunk_size: int
-D: (nheads, headdim) or (nheads,)
-z: (batch, seqlen, nheads, headdim)
-dt_bias: (nheads,)
-initial_states: (batch, nheads, headdim, dstate)
-seq_idx: (batch, seqlen)
-cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
-dt_softplus: Whether to apply softplus to dt
-out: Preallocated output tensor
-state_dtype: The data type of the ssm state
-"""

-if not return_varlen_states:
-cu_seqlens = None
-else:
-assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
-out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
 x,
 dt,
 A,
 B,
 C,
 chunk_size,
+cu_seqlens,
+seq_idx,
+out,
+D=None,
+z=None,
+dt_bias=None,
+initial_states=None,
+chunk_indices=None,
+chunk_offsets=None,
+dt_softplus=False,
+dt_limit=(0.0, float("inf")),
+state_dtype=None,
+):
+"""
+Argument:
+x: (seqlen, nheads, headdim)
+dt: (seqlen, nheads)
+A: (nheads)
+B: (seqlen, ngroups, dstate)
+C: (seqlen, ngroups, dstate)
+chunk_size: int
+seq_idx: (seqlen)
+cu_seqlens: (batch + 1)
+out: (seqlen, nheads, headdim) preallocated output tensor
+D: (nheads, headdim) or (nheads,)
+z: (seqlen, nheads, headdim)
+dt_bias: (nheads,)
+initial_states: (batch, nheads, headdim, dstate)
+dt_softplus: Whether to apply softplus to dt
+out: (seqlen, nheads, headdim) preallocated output tensor
+state_dtype: The data type of the ssm state
+Return:
+varlen_states: (batch, nheads, headdim, dstate)
+"""

+assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"
+assert seq_idx is not None

+varlen_states = _mamba_chunk_scan_combined_fwd(
+x,
+dt,
+A,
+B,
+C,
+chunk_size,
+out,
 D=D,
 z=z,
 dt_bias=dt_bias,
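Putting the new signature together, the entry point is called once on the flattened prefill tokens instead of a padded batch. A hedged usage sketch with random data and illustrative sizes (it assumes a CUDA device with vLLM's Triton kernels available, and leaves chunk_indices/chunk_offsets at their defaults, which real callers precompute when sequences with initial states span chunk boundaries):

import torch
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined_varlen)

seqlen, nheads, headdim, ngroups, dstate, chunk_size = 8, 4, 64, 1, 16, 8
device = "cuda"

x = torch.randn(seqlen, nheads, headdim, device=device)
dt = torch.rand(seqlen, nheads, device=device)
A = -torch.rand(nheads, device=device)
B = torch.randn(seqlen, ngroups, dstate, device=device)
C = torch.randn(seqlen, ngroups, dstate, device=device)

cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)  # two packed sequences
seq_idx = torch.repeat_interleave(
    torch.arange(len(cu_seqlens) - 1, dtype=torch.int32, device=device),
    cu_seqlens.diff())
out = torch.empty_like(x)                                               # written in place

varlen_states = mamba_chunk_scan_combined_varlen(
    x, dt, A, B, C, chunk_size, cu_seqlens, seq_idx, out)
# out: (seqlen, nheads, headdim); varlen_states: (batch, nheads, headdim, dstate)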
@@ -235,14 +233,6 @@ def mamba_chunk_scan_combined(x,
 cu_seqlens=cu_seqlens,
 dt_softplus=dt_softplus,
 dt_limit=dt_limit,
-out=out,
 state_dtype=state_dtype)
-if not return_varlen_states:
-if not return_final_states:
-return
-else:
-return final_states
-else:
-varlen_states = rest[0]
-return (varlen_states) if not return_final_states else (final_states,
-varlen_states)
+return varlen_states
@@ -27,64 +27,46 @@ def _state_passing_fwd_kernel(
 # Pointers to matrices
 states_ptr,
 out_ptr,
-final_states_ptr,
 dA_cs_ptr,
 initstates_ptr,
 seq_idx_ptr,
 chunk_offsets_ptr,
 chunk_meta_num,
 # Matrix dimensions
-dim,
+dim: tl.constexpr,
 nchunks,
 seqlen,
-chunk_size,
+chunk_size: tl.constexpr,
 # Strides
-stride_states_batch,
-stride_states_chunk,
-stride_states_head,
-stride_states_dim,
-stride_out_batch,
-stride_out_chunk,
-stride_out_head,
-stride_out_dim,
-stride_final_states_batch,
-stride_final_states_head,
-stride_final_states_dim,
-stride_dA_cs_batch,
-stride_dA_cs_chunk,
-stride_dA_cs_head,
-stride_dA_cs_csize,
-stride_initstates_batch,
-stride_initstates_head,
-stride_initstates_dim,
-stride_seq_idx_batch,
-stride_seq_idx_seqlen,
+stride_states_chunk: tl.int64,
+stride_states_head: tl.int64,
+stride_states_dim: tl.constexpr,
+stride_out_chunk: tl.int64,
+stride_out_head: tl.int64,
+stride_out_dim: tl.constexpr,
+stride_dA_cs_head: tl.int64,
+stride_dA_cs_chunk: tl.int64,
+stride_dA_cs_csize: tl.constexpr,
+stride_initstates_batch: tl.int64,
+stride_initstates_head: tl.int64,
+stride_initstates_dim: tl.constexpr,
+stride_seq_idx_seqlen: tl.constexpr,
 # Meta-parameters
 HAS_INITSTATES: tl.constexpr,
-HAS_SEQ_IDX: tl.constexpr,
-IS_CONT_BATCHED: tl.constexpr,
 BLOCK_SIZE: tl.constexpr,
 ):
-pid_b = tl.program_id(axis=1)
-pid_h = tl.program_id(axis=2)
+pid_h = tl.program_id(axis=1)
 pid_m = tl.program_id(axis=0)
-states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
-dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (
-chunk_size - 1) * stride_dA_cs_csize
-out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
-final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
+states_ptr += pid_h * stride_states_head
+dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size -
+1) * stride_dA_cs_csize
+out_ptr += pid_h * stride_out_head
 if HAS_INITSTATES:
 initstates_ptr += pid_h * stride_initstates_head
-if not IS_CONT_BATCHED:
-initstates_ptr += pid_b * stride_initstates_batch

-if HAS_SEQ_IDX:
-seq_idx_ptr += pid_b * stride_seq_idx_batch

 offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
 states_ptrs = states_ptr + offs_m * stride_states_dim
 out_ptrs = out_ptr + offs_m * stride_out_dim
-final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim

 # - states will be the past state of the sequence that continues on the current check
 if not HAS_INITSTATES:
@@ -101,65 +83,63 @@ def _state_passing_fwd_kernel(
 out_ptrs += stride_out_chunk
 prev_seq_idx_chunk_end = 0
 logical_chunk_idx = 0
-for c in range(nchunks):
+for c in range(nchunks - 1):
 new_states = tl.load(states_ptrs, mask=offs_m < dim,
 other=0.0).to(tl.float32)
 dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
 scale_mask = True
-if HAS_SEQ_IDX:
-# - the seq to pass forward is the one that is flushed to the right
-# boundary.
-# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
-seq_idx_chunk_end = tl.load(seq_idx_ptr + (min(
-(c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
-if HAS_INITSTATES:
-if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
-# this means in the current chunk the rightmost flushed seq
-# has changed.
-# - so we do not propagate the state from previous chunk
-# - but rather we load that sequence's init state
-initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
+# - the seq to pass forward is the one that is flushed to the right
+# boundary.
+# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
+seq_idx_chunk_end = tl.load(seq_idx_ptr +
+(min((c + 1) * chunk_size, seqlen) - 1) *
+stride_seq_idx_seqlen)

-# - update state with seq_idx_new's init state
-states = tl.load(initstates_ptrs,
-mask=offs_m < dim,
-other=0.0).to(tl.float32)
+if HAS_INITSTATES:
+if prev_seq_idx_chunk_end != seq_idx_chunk_end:
+# this means in the current chunk the rightmost flushed seq
+# has changed.
+# - so we do not propagate the state from previous chunk
+# - but rather we load that sequence's init state
+initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch

-# - we need to consider the cumsum only of the last sequence in the chunk
-# - find its starting position (given by c_off of the logical chunk index)
-# - and subtract the cumsum just before that position from the total cumsum
-# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
-# sequence index at the start of the current chunk
-seq_idx_chunk_start = tl.load(seq_idx_ptr +
-min(c * chunk_size, seqlen) *
-stride_seq_idx_seqlen)
-logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
-# - load the chunk offset:
-c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
-mask=logical_chunk_idx < chunk_meta_num,
-other=0)
-# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
-if c_off > 0:
-# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
-dA_cs_boundary = tl.load(
-dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
-(c_off - 1) * stride_dA_cs_csize,
-mask=(c_off - 1) > -1 and c_off < chunk_size,
-other=0.0)
-dA_cs -= dA_cs_boundary
+# - update state with seq_idx_new's init state
+states = tl.load(initstates_ptrs, mask=offs_m < dim,
+other=0.0).to(tl.float32)

-# - increment logical chunk index for every physical chunk
-logical_chunk_idx += 1
-else:
-scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
-prev_seq_idx_chunk_end = seq_idx_chunk_end
+# - we need to consider the cumsum only of the last sequence in the chunk
+# - find its starting position (given by c_off of the logical chunk index)
+# - and subtract the cumsum just before that position from the total cumsum
+# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
+# sequence index at the start of the current chunk
+seq_idx_chunk_start = tl.load(seq_idx_ptr +
+min(c * chunk_size, seqlen) *
+stride_seq_idx_seqlen)
+logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
+# - load the chunk offset:
+c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
+mask=logical_chunk_idx < chunk_meta_num,
+other=0)
+# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
+if c_off > 0:
+# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
+dA_cs_boundary = tl.load(
+dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
+(c_off - 1) * stride_dA_cs_csize,
+mask=(c_off - 1) > -1 and c_off < chunk_size,
+other=0.0)
+dA_cs -= dA_cs_boundary

+# - increment logical chunk index for every physical chunk
+logical_chunk_idx += 1
+else:
+scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
+prev_seq_idx_chunk_end = seq_idx_chunk_end

 scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
 states = scale * states + new_states
-if c < nchunks - 1:
-tl.store(out_ptrs, states, mask=offs_m < dim)
-else:
-tl.store(final_states_ptrs, states, mask=offs_m < dim)
+tl.store(out_ptrs, states, mask=offs_m < dim)
 states_ptrs += stride_states_chunk
 dA_cs_ptr += stride_dA_cs_chunk
 out_ptrs += stride_out_chunk
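The rewritten loop keeps only the continuous-batching path: the running state is decayed by exp(dA_cumsum at the chunk end) and the chunk's local state is added, and when the sequence index at the chunk boundary changes, the state is re-seeded from that sequence's initial state instead of being carried over. A conceptual reference in plain PyTorch for the case where initial states are provided (it mirrors the recurrence only and does not reproduce the kernel's pointer arithmetic or the chunk_offsets boundary correction):

import torch

def state_passing_ref(chunk_states, dA_cs_last, seq_idx_last, init_states):
    # chunk_states: (nchunks, d)  local state produced by each chunk
    # dA_cs_last:   (nchunks,)    dA_cumsum at the right edge of each chunk
    # seq_idx_last: (nchunks,)    sequence id at each chunk's right edge
    # init_states:  (nseqs, d)    one initial state per packed sequence
    nchunks = chunk_states.shape[0]
    out = torch.zeros_like(chunk_states)
    state = init_states[0].clone()
    prev_seq = 0
    for c in range(nchunks):
        out[c] = state                             # state passed into chunk c
        if int(seq_idx_last[c]) != prev_seq:       # a new sequence is flushed to this boundary
            prev_seq = int(seq_idx_last[c])
            state = init_states[prev_seq].clone()  # re-seed instead of propagating
        state = torch.exp(dA_cs_last[c]) * state + chunk_states[c]
    return out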
@@ -168,81 +148,53 @@ def _state_passing_fwd(
 def _state_passing_fwd(
 states,
 dA_cumsum,
+seq_idx,
+chunk_offsets,
 initial_states=None,
-seq_idx=None,
-chunk_size=None,
 out_dtype=None,
-is_cont_batched=False,
-chunk_offsets=None,
 ):
-batch, nchunks, nheads, dim = states.shape
-if chunk_size is None:
-chunk_size = dA_cumsum.shape[-1]
-else:
-assert chunk_size == dA_cumsum.shape[-1]
-assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
-if initial_states is not None:
-if is_cont_batched:
-# - if cu_seqlens is provided, then the initial states
-# are used for continuous batching. In which case we
-# require seq_idx to be provided
-assert seq_idx is not None, "seq_idx must be provided for continuous batching"
-# - we also need chunk_offsets to be provided, to account
-# for computation of dA_cumsum from the start of the
-# sequence
-assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching"
-else:
-# - this is the regular batching case, where initial
-# states are used are for each example of the batch.
-assert initial_states.shape == (batch, nheads, dim)
+nchunks, nheads, dim = states.shape
+chunk_size = dA_cumsum.shape[-1]
+assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
+seqlen = seq_idx.shape[-1]

-if seq_idx is not None:
-seqlen = seq_idx.shape[-1]
-assert seq_idx.shape == (batch, seqlen)
 out_dtype = states.dtype if out_dtype is None else out_dtype
-out = torch.empty((batch, nchunks, nheads, dim),
+out = torch.empty((nchunks, nheads, dim),
 device=states.device,
 dtype=out_dtype)
-final_states = torch.empty((batch, nheads, dim),
-device=states.device,
-dtype=torch.float32)
-grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
+initial_states_strides = ((initial_states.stride(0),
+initial_states.stride(1),
+initial_states.stride(2))
+if initial_states is not None else (0, 0, 0))

+grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), nheads)
 with torch.cuda.device(states.device.index):
 _state_passing_fwd_kernel[grid](
-states,
-out,
-final_states,
-dA_cumsum,
-initial_states,
-seq_idx,
-chunk_offsets,
-len(chunk_offsets) if chunk_offsets is not None else 0,
-dim,
-nchunks,
-seqlen if seq_idx is not None else 0,
-chunk_size,
-states.stride(0),
-states.stride(1),
-states.stride(2),
-states.stride(3),
-out.stride(0),
-out.stride(1),
-out.stride(2),
-out.stride(3),
-final_states.stride(0),
-final_states.stride(1),
-final_states.stride(2),
-dA_cumsum.stride(0),
-dA_cumsum.stride(2),
-dA_cumsum.stride(1),
-dA_cumsum.stride(3),
-*((initial_states.stride(0), initial_states.stride(1),
-initial_states.stride(2)) if initial_states is not None else
-(0, 0, 0)),
-*((seq_idx.stride(0),
-seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
+states_ptr=states,
+out_ptr=out,
+dA_cs_ptr=dA_cumsum,
+initstates_ptr=initial_states,
+seq_idx_ptr=seq_idx,
+chunk_offsets_ptr=chunk_offsets,
+chunk_meta_num=len(chunk_offsets)
+if chunk_offsets is not None else 0,
+dim=dim,
+nchunks=nchunks,
+seqlen=seqlen if seq_idx is not None else 0,
+chunk_size=chunk_size if seq_idx is not None else 0,
+stride_states_chunk=states.stride(0),
+stride_states_head=states.stride(1),
+stride_states_dim=states.stride(2),
+stride_out_chunk=out.stride(0),
+stride_out_head=out.stride(1),
+stride_out_dim=out.stride(2),
+stride_dA_cs_head=dA_cumsum.stride(0),
+stride_dA_cs_chunk=dA_cumsum.stride(1),
+stride_dA_cs_csize=dA_cumsum.stride(2),
+stride_initstates_batch=initial_states_strides[0],
+stride_initstates_head=initial_states_strides[1],
+stride_initstates_dim=initial_states_strides[2],
+stride_seq_idx_seqlen=seq_idx.stride(0),
 HAS_INITSTATES=initial_states is not None,
-HAS_SEQ_IDX=seq_idx is not None,
-IS_CONT_BATCHED=is_cont_batched,
 )
-return out, final_states
+return out
@@ -35,7 +35,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
 selective_state_update)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
-mamba_chunk_scan_combined)
+mamba_chunk_scan_combined_varlen)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -262,6 +262,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
 seq_idx_p = attn_metadata.seq_idx_p
 chunk_indices_p = attn_metadata.chunk_indices_p
 chunk_offsets_p = attn_metadata.chunk_offsets_p
+query_start_loc_p = attn_metadata.query_start_loc_p

 # 1. Gated MLP's linear projection
 projected_states = self.in_proj(hidden_states)
@@ -302,9 +303,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
 [num_decodes, num_prefills],
 dim=0,
 )
-query_start_loc_p = (
-attn_metadata.query_start_loc[-num_prefills - 1:] -
-num_decodes if has_prefill else None)

 # Preallocate output tensor to avoid memcpy cost for merging prefill
 # and decode outputs
@@ -356,17 +354,17 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
 has_initial_states_p[:, None, None, None],
 ssm_state[state_indices_tensor_p], 0)

-varlen_state = mamba_chunk_scan_combined(
-hidden_states_p.view(1, num_prefill_tokens,
+varlen_state = mamba_chunk_scan_combined_varlen(
+hidden_states_p.view(num_prefill_tokens,
 self.num_heads // self.tp_size,
 self.head_dim),
-dt.unsqueeze(0),
+dt,
 self.A,
-B.view(1, num_prefill_tokens, 1, -1),
-C.view(1, num_prefill_tokens, 1, -1),
+B.view(num_prefill_tokens, 1, -1),
+C.view(num_prefill_tokens, 1, -1),
 chunk_size=chunk_size,
 D=self.D,
-z=gate_p.view(1, num_prefill_tokens,
+z=gate_p.view(num_prefill_tokens,
 self.num_heads // self.tp_size, self.head_dim),
 dt_bias=self.dt_bias,
 seq_idx=seq_idx_p,
@@ -374,11 +372,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
 chunk_offsets=chunk_offsets_p,
 cu_seqlens=query_start_loc_p,
 initial_states=initial_states,
-return_varlen_states=True,
-return_final_states=False,
 dt_softplus=True,
 dt_limit=(0.0, float("inf")),
-out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
+out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
 self.head_dim),
 state_dtype=ssm_state.dtype,
 )
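Since the varlen kernels consume flat (num_tokens, ...) tensors, the call site above only drops the leading batch=1 dimension from each view; no data movement changes. A small illustrative equivalence check (hypothetical shapes, plain PyTorch):

import torch

num_prefill_tokens, nheads, head_dim = 10, 4, 16          # illustrative
hidden_states_p = torch.randn(num_prefill_tokens, nheads * head_dim)

old_view = hidden_states_p.view(1, num_prefill_tokens, nheads, head_dim)
new_view = hidden_states_p.view(num_prefill_tokens, nheads, head_dim)
assert torch.equal(old_view.squeeze(0), new_view)         # same data, one fewer dimension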
@@ -115,7 +115,7 @@ class Mamba2AttentionMetadata:
 num_prefill_tokens: int
 num_decodes: int
 num_decode_tokens: int
-query_start_loc: torch.Tensor
+query_start_loc_p: torch.Tensor
 seq_lens: torch.Tensor

 prep_initial_states: bool
@@ -151,7 +151,7 @@ class Mamba2AttentionMetadataBuilder(
 common_attn_metadata: CommonAttentionMetadata,
 fast_build: bool = False) -> Mamba2AttentionMetadata:
 num_reqs = common_attn_metadata.num_reqs
-query_start_loc = common_attn_metadata.query_start_loc
+query_start_loc_p = None
 seq_lens = common_attn_metadata.seq_lens

 seq_idx_p = None
@@ -179,7 +179,7 @@ class Mamba2AttentionMetadataBuilder(
 num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
 prep_initial_states = torch.any(has_initial_states_cpu).item()
 has_initial_states_p = has_initial_states_cpu.to(
-query_start_loc.device)
+common_attn_metadata.query_start_loc.device)

 query_start_loc_p = common_attn_metadata.query_start_loc[
 -num_prefills - 1:] - num_decode_tokens
@@ -190,7 +190,6 @@ class Mamba2AttentionMetadataBuilder(
 device=query_start_loc_p.device),
 query_start_loc_p.diff(),
 output_size=num_prefill_tokens)
-seq_idx_p.unsqueeze_(0)

 # We compute metadata for chunked prefill once at the top level
 # model forward and reuse them in mamba layers. If not needed,
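With the batch dimension removed end to end, seq_idx_p stays one-dimensional: the repeat_interleave above already yields shape (num_prefill_tokens,), so the former unsqueeze_(0) is no longer needed. A sketch of that construction on CPU, with illustrative prefill lengths:

import torch

query_start_loc_p = torch.tensor([0, 4, 9], dtype=torch.int32)   # per-prefill boundaries, illustrative
num_prefills = len(query_start_loc_p) - 1
num_prefill_tokens = int(query_start_loc_p[-1])

seq_idx_p = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens)
assert seq_idx_p.shape == (num_prefill_tokens, )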
@@ -217,7 +216,7 @@ class Mamba2AttentionMetadataBuilder(
 num_prefill_tokens=num_prefill_tokens,
 num_decodes=num_decodes,
 num_decode_tokens=num_decode_tokens,
-query_start_loc=query_start_loc,
+query_start_loc_p=query_start_loc_p,
 seq_lens=seq_lens,
 prep_initial_states=prep_initial_states,
 chunk_size=self.chunk_size,